From b162eea9d6c32b9f13df85ba8793850d67976ae9 Mon Sep 17 00:00:00 2001 From: Yasser Tahiri Date: Sun, 13 Nov 2022 17:49:36 +0400 Subject: [PATCH 1/2] Improve Codebase --- .github/dependabot.yaml | 10 +++ .github/workflows/build.yaml | 35 ++++++++ .gitignore | 134 ++++++++++++++++++++++++++++-- .pre-commit-config.yaml | 44 ++++++++++ README.md | 2 +- jsql/__init__.py | 139 ++++++++++++++++++++++++-------- pyproject.toml | 57 +++++++++++++ scripts/clean.sh | 17 ++++ scripts/format.sh | 6 ++ scripts/test.sh | 9 +++ setup.py | 22 ----- tests/__init__.py | 0 tests/test_list_param.py | 48 ++++++++--- tests/test_render.py | 107 +++++++++++++++--------- tests/test_sql.py | 34 +++++--- tests/test_sql_proxy_factory.py | 48 +++++++++++ tests/test_sqlproxy.py | 50 +++++++----- tests/test_version.py | 5 ++ 18 files changed, 626 insertions(+), 141 deletions(-) create mode 100644 .github/dependabot.yaml create mode 100644 .github/workflows/build.yaml create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml create mode 100644 scripts/clean.sh create mode 100644 scripts/format.sh create mode 100644 scripts/test.sh delete mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/test_sql_proxy_factory.py create mode 100644 tests/test_version.py diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..078bcdc --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,10 @@ +version: 2 + +updates: + # GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + commit-message: + prefix: ⬆ \ No newline at end of file diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..4b862bf --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,35 @@ +name: Test Suite + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.7","3.8","3.9","3.10", "3.11"] + fail-fast: false + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - uses: actions/cache@v3 + id: cache + with: + path: ${{ env.pythonLocation }} + key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-test + - name: Install Dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: pip install -e ."[test, lint]" + - name: Test + run: bash scripts/test.sh + - name: Lint + run: bash scripts/format.sh \ No newline at end of file diff --git a/.gitignore b/.gitignore index d91b7ff..b6e4761 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,129 @@ -.*.swp -*.egg-info -*.pyc +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* .cache -.py* -dist -venv +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4ea8459 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,44 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-merge-conflict + - id: check-added-large-files + - id: check-ast + - id: check-symlinks + - id: trailing-whitespace + - id: check-json + - id: debug-statements + - id: pretty-format-json + args: ["--autofix", "--allow-missing-credentials"] + - repo: https://github.com/PyCQA/isort + rev: 5.10.1 + hooks: + - id: isort + args: ["--profile", "black"] + - repo: https://github.com/PyCQA/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + additional_dependencies: [flake8-print] + files: '\.py$' + exclude: docs/ + args: + - --select=F403,F406,F821,T003 + - repo: https://github.com/humitos/mirrors-autoflake + rev: v1.3 + hooks: + - id: autoflake + files: '\.py$' + exclude: '^\..*' + args: ["--in-place"] + - repo: https://github.com/psf/black + rev: 22.8.0 + hooks: + - id: black + args: ["--target-version", "py39"] + - repo: https://github.com/asottile/pyupgrade + rev: v2.37.3 + hooks: + - id: pyupgrade + args: [--py37-plus] diff --git a/README.md b/README.md index f8334e1..a5ee949 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ pip install jsql==0.7 ## Usage -Check tests for examples. +Check tests for examples. ```python from jsql import sql diff --git a/jsql/__init__.py b/jsql/__init__.py index bb9f7fd..c9f4348 100644 --- a/jsql/__init__.py +++ b/jsql/__init__.py @@ -1,141 +1,174 @@ +import collections +import itertools +import logging +import re + import jinja2 import jinja2.ext -from jinja2.lexer import Token -from jinja2.utils import Markup -import re -import logging import six -import itertools, collections +from jinja2.lexer import Token + +__version__ = "0.7" + class UnsafeSqlException(Exception): pass -NOT_DANGEROUS_RE = re.compile('^[A-Za-z0-9_]*$') + +NOT_DANGEROUS_RE = re.compile("^[A-Za-z0-9_]*$") + + def is_safe(value): return NOT_DANGEROUS_RE.match(value) -@six.python_2_unicode_compatible -class DangerouslyInjectedSql(object): + +class DangerouslyInjectedSql: def __init__(self, value): self.value = value def __str__(self): return self.value + def sql(engine, template, **params): return sql_inner(engine, template, params) + def sql_inner(engine, template, params): query = render(template, params) query, params = format_query_with_list_params(query, params) return SqlProxy(execute_sql(engine, query, params)) + sql_inner_original = sql_inner + def render(template, params): - params['bindparam'] = params.get('bindparam', gen_bindparam(params)) + params["bindparam"] = params.get("bindparam", gen_bindparam(params)) return jenv.from_string(template).render(**params) -logger = logging.getLogger('jsql') + +logger = logging.getLogger("jsql") + def assert_safe_filter(value): if value is None: return None if isinstance(value, DangerouslyInjectedSql): return value - value = six.text_type(value) + value = str(value) if not is_safe(value): - raise UnsafeSqlException('unsafe sql param "{}"'.format(value)) + raise UnsafeSqlException(f'unsafe sql param "{value}"') return value + class AssertSafeExtension(jinja2.ext.Extension): # based on https://github.com/pallets/jinja/issues/503 def filter_stream(self, stream): for token in stream: - if token.type == 'variable_end': - yield Token(token.lineno, 'rparen', ')') - yield Token(token.lineno, 'pipe', '|') - yield Token(token.lineno, 'name', 'assert_safe') + if token.type == "variable_end": + yield Token(token.lineno, "rparen", ")") + yield Token(token.lineno, "pipe", "|") + yield Token(token.lineno, "name", "assert_safe") yield token - if token.type == 'variable_begin': - yield Token(token.lineno, 'lparen', '(') + if token.type == "variable_begin": + yield Token(token.lineno, "lparen", "(") + -jenv = jinja2.Environment(autoescape=False, - extensions=(AssertSafeExtension,)) +jenv = jinja2.Environment(autoescape=False, extensions=(AssertSafeExtension,)) jenv.filters["assert_safe"] = assert_safe_filter + def dangerously_inject_sql(value): return DangerouslyInjectedSql(value) + jenv.filters["dangerously_inject_sql"] = dangerously_inject_sql jenv.globals["comma"] = DangerouslyInjectedSql(",") def execute_sql(engine, query, params): from sqlalchemy.sql import text + q = text(query) - is_session = 'session' in repr(engine.__class__).lower() - return engine.execute(q, params=params) if is_session else engine.execute(q, **params) + is_session = "session" in repr(engine.__class__).lower() + return ( + engine.execute(q, params=params) if is_session else engine.execute(q, **params) + ) + + +BINDPARAM_PREFIX = "bp" + -BINDPARAM_PREFIX = 'bp' def gen_bindparam(params): keygen = key_generator() + def bindparam(val): key = keygen(BINDPARAM_PREFIX) while key in params: key = keygen(BINDPARAM_PREFIX) params[key] = val return key + return bindparam + def key_generator(): keycnt = collections.defaultdict(itertools.count) + def gen_key(key): return key + str(next(keycnt[key])) + return gen_key + def get_param_keys(query): import re + return set(re.findall("(?P:[a-zA-Z_]+_list)", query)) + def format_query_with_list_params(query, params): keys = get_param_keys(query) for key in keys: - if key.endswith('_tuple_list'): + if key.endswith("_tuple_list"): query, params = _format_query_tuple_list_key(key, query, params) else: query, params = _format_query_list_key(key, query, params) return query, params + def _format_query_list_key(key, query, params): values = params.pop(key[1:]) new_keys = [] for i, value in enumerate(values): - new_key = '{}_{}'.format(key, i) + new_key = f"{key}_{i}" new_keys.append(new_key) params[new_key[1:]] = value new_keys_str = ", ".join(new_keys) or "null" - query = query.replace(key, "({})".format(new_keys_str)) + query = query.replace(key, f"({new_keys_str})") return query, params + def _format_query_tuple_list_key(key, query, params): values = params.pop(key[1:]) new_keys = [] for i, value in enumerate(values): - new_key = '{}_{}'.format(key, i) + new_key = f"{key}_{i}" assert isinstance(value, tuple) new_keys2 = [] for i, tuple_val in enumerate(value): - new_key2 = '{}_{}'.format(new_key, i) + new_key2 = f"{new_key}_{i}" new_keys2.append(new_key2) params[new_key2[1:]] = tuple_val - new_keys.append("({})".format(", ".join(new_keys2))) + new_keys.append(f'({", ".join(new_keys2)})') new_keys_str = ", ".join(new_keys) or "null" - query = query.replace(key, "({})".format(new_keys_str)) + query = query.replace(key, f"({new_keys_str})") return query, params -class ObjProxy(object): + +class ObjProxy: def __init__(self, proxied): self._proxied = proxied @@ -147,18 +180,19 @@ def __getattr__(self, attr): return getattr(self, attr) return getattr(self._proxied, attr) + class SqlProxy(ObjProxy): def dicts_iter(self, dict=dict): result = self._proxied keys = result.keys() for r in result: - yield dict((k, v) for k, v in zip(keys, r)) + yield dict(zip(keys, r)) def pk_map_iter(self, dict=dict): result = self._proxied keys = result.keys() for r in result: - yield (r[0], dict((k, v) for k, v in zip(keys, r))) + yield (r[0], dict(zip(keys, r))) def kv_map_iter(self): result = self._proxied @@ -191,3 +225,42 @@ def dict(self, dict=dict): except IndexError: return None + +class SqlProxyFactory: + """ + This class is used to create a SqlProxy object that wraps the result of a query. + """ + + def __init__(self, engine): + self.engine = engine + + def __call__(self, query, params): + query, params = format_query_with_list_params(query, params) + bindparam = gen_bindparam(params) + query = jenv.from_string(query).render(bindparam=bindparam) + result = execute_sql(self.engine, query, params) + return SqlProxy(result) + + +def get_sql_proxy_factory(engine): + """ + This function is used to get a SqlProxyFactory object that can be used to create + SqlProxy objects that wrap the result of a query. + """ + return SqlProxyFactory(engine) + + +def get_sql_proxy(engine, query, params): + """ + This function is used to get a SqlProxy object that wraps the result of a query. + """ + return get_sql_proxy_factory(engine)(query, params) + + +def get_sql_proxy_from_template(engine, template, params): + """ + This function is used to get a SqlProxy object that wraps the result of a query + from a template. + """ + query = jenv.get_template(template).render(**params) + return get_sql_proxy(engine, query, params) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..357d457 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "jsql" +description = 'Lightweight wrapper around sqlalchemy + jinja2.' +readme = "README.md" +requires-python = ">=3.7" +license = "MIT" +keywords = [ + "sql", + "sqlalchemy", + "jinja2", + "sql-template", +] +authors = [ + { name = "Hisham Zarka", email = "hzarka@gmail.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "six", + "jinja2", + "sqlalchemy", +] +dynamic = ["version"] + +[project.optional-dependencies] +lint = [ + "pre-commit", +] +test = [ + "pytest", + "pytest-asyncio", + "pytest-cov", +] + + +[project.urls] +Documentation = "https://github.com/hzarka/python-jsql#readme" +Issues = "https://github.com/hzarka/python-jsql/issues" +Source = "https://github.com/hzarka/python-jsql" + +[tool.hatch.version] +path = "jsql/__init__.py" + + diff --git a/scripts/clean.sh b/scripts/clean.sh new file mode 100644 index 0000000..68df2c7 --- /dev/null +++ b/scripts/clean.sh @@ -0,0 +1,17 @@ +#!/bin/sh -e + +rm -f `find . -type f -name '*.py[co]' ` +rm -f `find . -type f -name '*~' ` +rm -f `find . -type f -name '.*~' ` +rm -f `find . -type f -name .coverage` +rm -f `find . -type f -name ".coverage.*"` +rm -rf `find . -name __pycache__` +rm -rf `find . -type d -name '*.egg-info' ` +rm -rf `find . -type d -name 'pip-wheel-metadata' ` +rm -rf `find . -type d -name .pytest_cache` +rm -rf `find . -type d -name .cache` +rm -rf `find . -type d -name .mypy_cache` +rm -rf `find . -type d -name htmlcov` +rm -rf `find . -type d -name "*.egg-info"` +rm -rf `find . -type d -name build` +rm -rf `find . -type d -name dist` \ No newline at end of file diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100644 index 0000000..12df097 --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e +set -x + +pre-commit run --all-files --verbose --show-diff-on-failure \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 0000000..3a7d630 --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +set -e +set -x + +echo "ENV=${ENV}" + +export PYTHONPATH=. +pytest --cov=jsql --cov=tests \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 7918b68..0000000 --- a/setup.py +++ /dev/null @@ -1,22 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name='jsql', - version='0.7', - author='Hisham Zarka', - author_email='hzarka@gmail.com', - packages = find_packages(), - package_dir = {'': '.'}, - requires = ["six"], - install_requires = ["six"], - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - ], - zip_safe=True, -) - diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_list_param.py b/tests/test_list_param.py index 63c7fd6..5776817 100644 --- a/tests/test_list_param.py +++ b/tests/test_list_param.py @@ -1,24 +1,46 @@ -import jsql import pytest +import jsql + + def test_simple_list_param(): - ids = [1, '2', 3, '100'] - query, params = jsql.format_query_with_list_params('id IN :id_list', dict(id_list=ids)) - assert query == 'id IN (:id_list_0, :id_list_1, :id_list_2, :id_list_3)' - assert params == {'id_list_0': 1, 'id_list_1': '2', 'id_list_2': 3, 'id_list_3': '100'} + ids = [1, "2", 3, "100"] + query, params = jsql.format_query_with_list_params( + "id IN :id_list", dict(id_list=ids) + ) + assert query == "id IN (:id_list_0, :id_list_1, :id_list_2, :id_list_3)" + assert params == { + "id_list_0": 1, + "id_list_1": "2", + "id_list_2": 3, + "id_list_3": "100", + } + def test_empty_list_param(): ids = [] - query, params = jsql.format_query_with_list_params('id IN :id_list', dict(id_list=ids)) - assert query == 'id IN (null)' + query, params = jsql.format_query_with_list_params( + "id IN :id_list", dict(id_list=ids) + ) + assert query == "id IN (null)" assert params == {} + def test_simple_tuple_list(): tuples = [ - (123, 'val1'), - (456, 'val2'), + (123, "val1"), + (456, "val2"), ] - query, params = jsql.format_query_with_list_params('(key1, key2) IN :key_tuple_list', dict(key_tuple_list=tuples)) - assert query == '(key1, key2) IN ((:key_tuple_list_0_0, :key_tuple_list_0_1), (:key_tuple_list_1_0, :key_tuple_list_1_1))' - assert params == {'key_tuple_list_0_0': 123, 'key_tuple_list_0_1': 'val1', 'key_tuple_list_1_0': 456, 'key_tuple_list_1_1': 'val2'} - + query, params = jsql.format_query_with_list_params( + "(key1, key2) IN :key_tuple_list", dict(key_tuple_list=tuples) + ) + assert ( + query + == "(key1, key2) IN ((:key_tuple_list_0_0, :key_tuple_list_0_1), (:key_tuple_list_1_0, :key_tuple_list_1_1))" + ) + assert params == { + "key_tuple_list_0_0": 123, + "key_tuple_list_0_1": "val1", + "key_tuple_list_1_0": 456, + "key_tuple_list_1_1": "val2", + } diff --git a/tests/test_render.py b/tests/test_render.py index 17ea6a9..19d79a1 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -1,7 +1,8 @@ -import jsql -import pytest import logging -import copy + +import pytest + +import jsql logging.basicConfig(level=logging.DEBUG) @@ -9,82 +10,114 @@ @pytest.fixture def params(): return { - 'safe': 'safe', - 'unsafe': '" unsafe', - 'tbl': [{'a': 1, 'b': 2}, {'a': 1, 'c': 2}], - 'hex_list': ('abc123', 'deadbeef'), + "safe": "safe", + "unsafe": '" unsafe', + "tbl": [{"a": 1, "b": 2}, {"a": 1, "c": 2}], + "hex_list": ("abc123", "deadbeef"), } def test_simple_bind(params): - tmpl = jsql.render('table_{{safe}}', params) - assert tmpl == 'table_safe' + tmpl = jsql.render("table_{{safe}}", params) + assert tmpl == "table_safe" + def test_unsafe_bind(params): with pytest.raises(jsql.UnsafeSqlException): - tmpl = jsql.render('table_{{unsafe}}', params) + tmpl = jsql.render("table_{{unsafe}}", params) + def test_dangerous_bind(params): - tmpl = jsql.render('`{{unsafe | dangerously_inject_sql }}`', params) + tmpl = jsql.render("`{{unsafe | dangerously_inject_sql }}`", params) assert tmpl == '`" unsafe`' + def test_missing_param(params): - tmpl = jsql.render('{% if xyz %}AND xyz=:xyz{% endif %}', params) - assert tmpl == '' + tmpl = jsql.render("{% if xyz %}AND xyz=:xyz{% endif %}", params) + assert tmpl == "" + def test_raw_comma_bind(params): tmpl = jsql.render('{{ "," if False }}', {}) - assert tmpl == '' + assert tmpl == "" with pytest.raises(jsql.UnsafeSqlException): tmpl = jsql.render('{{ "," if True }}', {}) + def test_safe_comma_bind(params): - tmpl = jsql.render('{{ comma if False }}', {}) - assert tmpl == '' - tmpl = jsql.render('{{ comma if True }}', {}) - assert tmpl == ',' + tmpl = jsql.render("{{ comma if False }}", {}) + assert tmpl == "" + tmpl = jsql.render("{{ comma if True }}", {}) + assert tmpl == "," + def test_comma_loop(params): - tmpl = jsql.render('{% for i in range(3) %} SQL {{ i + 1 }}{{ comma if not loop.last }}{% endfor %}', {}) - assert tmpl == ' SQL 1, SQL 2, SQL 3' + tmpl = jsql.render( + "{% for i in range(3) %} SQL {{ i + 1 }}{{ comma if not loop.last }}{% endfor %}", + {}, + ) + assert tmpl == " SQL 1, SQL 2, SQL 3" + def test_check_if_else_if(params): - tmpl = jsql.render('{{ safe if True else unsafe }}', params) - assert tmpl == 'safe' + tmpl = jsql.render("{{ safe if True else unsafe }}", params) + assert tmpl == "safe" with pytest.raises(jsql.UnsafeSqlException): - tmpl = jsql.render('{{ safe if False else unsafe }}', params) + tmpl = jsql.render("{{ safe if False else unsafe }}", params) + def test_check_if_else_else(params): - tmpl = jsql.render('{{ unsafe if False else safe }}', params) - assert tmpl == 'safe' + tmpl = jsql.render("{{ unsafe if False else safe }}", params) + assert tmpl == "safe" with pytest.raises(jsql.UnsafeSqlException): - tmpl = jsql.render('{{ unsafe if True else safe }}', params) + tmpl = jsql.render("{{ unsafe if True else safe }}", params) + def test_bindparam_union(params): params = params - tmpl = jsql.render(''' + tmpl = jsql.render( + """ SELECT * FROM ( {%- for row in tbl -%} {% set outer_loop = loop %} ({%- for key in tbl[0].keys() -%}:{{ bindparam(row.get(key)) }}{% if outer_loop.first %} as {{ key }}{% endif %}{{ comma if not loop.last }} {% endfor %}) {% if not loop.last %}UNION ALL {%- endif -%} {%- endfor -%} ) - ''', params) - assert tmpl == ''' + """, + params, + ) + assert ( + tmpl + == """ SELECT * FROM ( (:bp0 as a, :bp1 as b ) UNION ALL (:bp2, :bp3 ) ) - ''' - assert (params['bp0'], params['bp1'], params['bp2'], params['bp3']) == (params['tbl'][0]['a'], params['tbl'][0]['b'], params['tbl'][1]['a'], params['tbl'][1].get('b')) + """ + ) + assert (params["bp0"], params["bp1"], params["bp2"], params["bp3"]) == ( + params["tbl"][0]["a"], + params["tbl"][0]["b"], + params["tbl"][1]["a"], + params["tbl"][1].get("b"), + ) + def test_bindparam_unhex(params): - tmpl = jsql.render(''' + tmpl = jsql.render( + """ SELECT * FROM tbl WHERE binary IN ({% for hexval in hex_list %}UNHEX(:{{ bindparam(hexval) }}){{ comma if not loop.last }}{% endfor %}) - ''', params) - assert tmpl == ''' + """, + params, + ) + assert ( + tmpl + == """ SELECT * FROM tbl WHERE binary IN (UNHEX(:bp0),UNHEX(:bp1)) - ''' - assert (params['bp0'], params['bp1']) == (params['hex_list'][0], params['hex_list'][1]) - + """ + ) + assert (params["bp0"], params["bp1"]) == ( + params["hex_list"][0], + params["hex_list"][1], + ) diff --git a/tests/test_sql.py b/tests/test_sql.py index 880d58d..bc4c7ed 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,32 +1,46 @@ import jsql + def mock_execute_sql(engine, query, params): class ResultProxy: def keys(self): - return ['engine', 'query', 'params'] + return ["engine", "query", "params"] + def __iter__(self): return iter([[engine, query, params]]) + return ResultProxy() + + jsql.execute_sql = mock_execute_sql + def logparams_decorator(fn): import functools + @functools.wraps(fn) def inner(engine, template, params): - template = template.replace('HOOKPARAMS', '/* pod=xyz */') + template = template.replace("HOOKPARAMS", "/* pod=xyz */") return fn(engine, template, params) + return inner + engine = object() + + def test_sql(): - ret = jsql.sql(engine, 'SELECT 1 FROM tbl {% if customer %}WHERE customer=:customer{% endif %}', customer='abc').dict() - assert ret['engine'] == engine - assert ret['query'] == 'SELECT 1 FROM tbl WHERE customer=:customer' - assert ret['params']['customer'] == 'abc' + ret = jsql.sql( + engine, + "SELECT 1 FROM tbl {% if customer %}WHERE customer=:customer{% endif %}", + customer="abc", + ).dict() + assert ret["engine"] == engine + assert ret["query"] == "SELECT 1 FROM tbl WHERE customer=:customer" + assert ret["params"]["customer"] == "abc" + def test_sqlhook(): jsql.sql_inner = logparams_decorator(jsql.sql_inner) - ret = jsql.sql(engine, 'SELECT 1 FROM tbl WHERE 1=1 HOOKPARAMS').dict() - assert ret['query'] == 'SELECT 1 FROM tbl WHERE 1=1 /* pod=xyz */' - - + ret = jsql.sql(engine, "SELECT 1 FROM tbl WHERE 1=1 HOOKPARAMS").dict() + assert ret["query"] == "SELECT 1 FROM tbl WHERE 1=1 /* pod=xyz */" diff --git a/tests/test_sql_proxy_factory.py b/tests/test_sql_proxy_factory.py new file mode 100644 index 0000000..da992cb --- /dev/null +++ b/tests/test_sql_proxy_factory.py @@ -0,0 +1,48 @@ +import pytest + +import jsql + + +# Test SQL proxy factory +def test_sql_proxy_factory(): + def mock_execute_sql(engine, query, params): + class ResultProxy: + def keys(self): + return ["engine", "query", "params"] + + def __iter__(self): + return iter([[engine, query, params]]) + + return ResultProxy() + + jsql.execute_sql = mock_execute_sql + + engine = object() + query = "SELECT 1 FROM tbl {% if customer %}WHERE customer=:customer{% endif %}" + params = {"customer": "abc"} + ret = jsql.get_sql_proxy_factory(engine)(query, params).dict() + assert ret["engine"] == engine + assert ret["query"] == "SELECT 1 FROM tbl " + assert ret["params"]["customer"] == "abc" + + +def test_get_sql_proxy(): + def mock_execute_sql(engine, query, params): + class ResultProxy: + def keys(self): + return ["engine", "query", "params"] + + def __iter__(self): + return iter([[engine, query, params]]) + + return ResultProxy() + + jsql.execute_sql = mock_execute_sql + + engine = object() + query = "SELECT 1 FROM tbl {% if customer %}WHERE customer=:customer{% endif %}" + params = {"customer": "abc"} + ret = jsql.get_sql_proxy(engine, query, params).dict() + assert ret["engine"] == engine + assert ret["query"] == "SELECT 1 FROM tbl " + assert ret["params"]["customer"] == "abc" diff --git a/tests/test_sqlproxy.py b/tests/test_sqlproxy.py index 5cacbe1..d9e49b6 100644 --- a/tests/test_sqlproxy.py +++ b/tests/test_sqlproxy.py @@ -1,7 +1,9 @@ -import jsql import pytest -class FakeResult(): +import jsql + + +class FakeResult: def __init__(self, keys, rows): self._keys = keys self._rows = rows @@ -12,51 +14,61 @@ def keys(self): def __iter__(self): return iter(self._rows) -example = FakeResult(['id', 'name'], [ - (1, 'A'), - (2, 'B'), - (3, 'C'), - ]) -empty_example = FakeResult(['id', 'name'], []) +example = FakeResult( + ["id", "name"], + [ + (1, "A"), + (2, "B"), + (3, "C"), + ], +) + +empty_example = FakeResult(["id", "name"], []) + def test_dict(): res = jsql.SqlProxy(example) - assert res.dict() == {'id': 1, 'name': 'A'} + assert res.dict() == {"id": 1, "name": "A"} + def test_empty_dict(): res = jsql.SqlProxy(empty_example) assert res.dict() is None + def test_dicts(): res = jsql.SqlProxy(example) assert res.dicts() == [ - {'id': 1, 'name': 'A'}, - {'id': 2, 'name': 'B'}, - {'id': 3, 'name': 'C'} + {"id": 1, "name": "A"}, + {"id": 2, "name": "B"}, + {"id": 3, "name": "C"}, ] + def test_kv_map(): res = jsql.SqlProxy(example) assert res.kv_map() == { - 1: 'A', - 2: 'B', - 3: 'C', + 1: "A", + 2: "B", + 3: "C", } + def test_pk_map(): res = jsql.SqlProxy(example) assert res.pk_map() == { - 1: {'id': 1, 'name': 'A'}, - 2: {'id': 2, 'name': 'B'}, - 3: {'id': 3, 'name': 'C'} + 1: {"id": 1, "name": "A"}, + 2: {"id": 2, "name": "B"}, + 3: {"id": 3, "name": "C"}, } + def test_scalars(): res = jsql.SqlProxy(example) assert res.scalars() == [1, 2, 3] + def test_scalar_set(): res = jsql.SqlProxy(example) assert res.scalar_set() == {1, 2, 3} - diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 0000000..14ee29c --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,5 @@ +import jsql + + +def test_version() -> None: + assert jsql.__version__ == "0.7" From 70b0dc859512905756fad17fe30b2bcd48a6ec08 Mon Sep 17 00:00:00 2001 From: Yasser Tahiri Date: Sun, 13 Nov 2022 17:50:50 +0400 Subject: [PATCH 2/2] fix Branch --- .github/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 4b862bf..49ac3ad 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -3,7 +3,7 @@ name: Test Suite on: push: branches: - - main + - master pull_request: types: [opened, synchronize]