diff --git a/fastapi_template/cli.py b/fastapi_template/cli.py index a54d4c600a132d23c9160a922c548e8a686a1eb2..b4e17716b81c410a534d65bbf2278be3d83c0167 100644 --- a/fastapi_template/cli.py +++ b/fastapi_template/cli.py @@ -1,6 +1,7 @@ import re from argparse import ArgumentParser from operator import attrgetter +from termcolor import cprint from prompt_toolkit import prompt from prompt_toolkit.document import Document @@ -8,6 +9,8 @@ from prompt_toolkit.shortcuts import checkboxlist_dialog, radiolist_dialog from prompt_toolkit.validation import ValidationError, Validator from fastapi_template.input_model import ( + SUPPORTED_ORMS, + ORMS_WITHOUT_MIGRATIONS, ORM, BuilderContext, DB_INFO, @@ -199,10 +202,19 @@ def read_user_input(current_context: BuilderContext) -> BuilderContext: current_context.orm = radiolist_dialog( "ORM", text="Which ORM do you want?", - values=[(orm, orm.value) for orm in list(ORM) if orm != ORM.none], + values=[(orm, orm.value) for orm in SUPPORTED_ORMS[current_context.db]], ).run() if current_context.orm is None: raise KeyboardInterrupt() + if ( + current_context.orm is not None + and current_context.orm != ORM.none + and current_context.orm not in SUPPORTED_ORMS.get(current_context.db, []) + ): + cprint("This ORM is not supported by chosen database.", "red") + raise KeyboardInterrupt() + if current_context.orm in ORMS_WITHOUT_MIGRATIONS: + current_context.enable_migrations = False if current_context.ci_type is None: current_context.ci_type = radiolist_dialog( "CI", diff --git a/fastapi_template/input_model.py b/fastapi_template/input_model.py index da11f2e489bd1ac3275424626c596307759c8bb7..7d9fc2adc23324dfddaeed87bc933b3499e0c9dc 100644 --- a/fastapi_template/input_model.py +++ b/fastapi_template/input_model.py @@ -24,6 +24,7 @@ class ORM(enum.Enum): ormar = "ormar" sqlalchemy = "sqlalchemy" tortoise = "tortoise" + psycopg = "psycopg" class Database(BaseModel): @@ -69,6 +70,28 @@ DB_INFO = { ), } +SUPPORTED_ORMS = { + DatabaseType.postgresql: [ + ORM.ormar, + ORM.psycopg, + ORM.tortoise, + ORM.sqlalchemy, + ], + DatabaseType.sqlite: [ + ORM.ormar, + ORM.tortoise, + ORM.sqlalchemy, + ], + DatabaseType.mysql: [ + ORM.ormar, + ORM.tortoise, + ORM.sqlalchemy, + ] +} + +ORMS_WITHOUT_MIGRATIONS = [ + ORM.psycopg, +] class BuilderContext(BaseModel): """Options for project generation.""" diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/conditional_files.json b/fastapi_template/template/{{cookiecutter.project_name}}/conditional_files.json index b875a9bc84d6df0520f25a2eef5e9228a80ff1d9..478fdc7fa13e827f2fc100d16b65279a983e7f04 100644 --- a/fastapi_template/template/{{cookiecutter.project_name}}/conditional_files.json +++ b/fastapi_template/template/{{cookiecutter.project_name}}/conditional_files.json @@ -80,6 +80,8 @@ "{{cookiecutter.project_name}}/db_ormar/models/dummy_model.py", "{{cookiecutter.project_name}}/db_tortoise/dao", "{{cookiecutter.project_name}}/db_tortoise/models/dummy_model.py", + "{{cookiecutter.project_name}}/db_psycopg/dao", + "{{cookiecutter.project_name}}/db_psycopg/models/dummy_model.py", "{{cookiecutter.project_name}}/tests/test_dummy.py", "{{cookiecutter.project_name}}/db_sa/migrations/versions/2021-08-16-16-55_2b7380507a71.py", "{{cookiecutter.project_name}}/db_ormar/migrations/versions/2021-08-16-16-55_2b7380507a71.py", @@ -114,6 +116,12 @@ "{{cookiecutter.project_name}}/db_ormar" ] }, + "PsycoPG": { + "enabled": "{{cookiecutter.orm == 'psycopg'}}", + "resources": [ + "{{cookiecutter.project_name}}/db_psycopg" + ] + }, "Postgresql DB": { "enabled": "{{cookiecutter.db_info.name == 'postgresql'}}", "resources": [ diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/pyproject.toml b/fastapi_template/template/{{cookiecutter.project_name}}/pyproject.toml index 43fb893f0e2eaa36cf35d4ddc05c47e6a75024ac..65d6b302014f353af4a440534f22c35343c3fda1 100644 --- a/fastapi_template/template/{{cookiecutter.project_name}}/pyproject.toml +++ b/fastapi_template/template/{{cookiecutter.project_name}}/pyproject.toml @@ -67,6 +67,9 @@ aioredis = {version = "^2.0.1", extras = ["hiredis"]} {%- if cookiecutter.self_hosted_swagger == 'True' %} aiofiles = "^0.8.0" {%- endif %} +{%- if cookiecutter.orm == "psycopg" %} +psycopg = { version = "^3.0.11", extras = ["binary", "pool"] } +{%- endif %} httptools = "^0.3.0" [tool.poetry.dev-dependencies] diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/replaceable_files.json b/fastapi_template/template/{{cookiecutter.project_name}}/replaceable_files.json index c0d903c64fcfa5367ccfc82a6a285c641369eadb..981129c03448afaaa06f03013d0b3d3e019405d9 100644 --- a/fastapi_template/template/{{cookiecutter.project_name}}/replaceable_files.json +++ b/fastapi_template/template/{{cookiecutter.project_name}}/replaceable_files.json @@ -2,6 +2,7 @@ "{{cookiecutter.project_name}}/db": [ "{{cookiecutter.project_name}}/db_sa", "{{cookiecutter.project_name}}/db_ormar", - "{{cookiecutter.project_name}}/db_tortoise" + "{{cookiecutter.project_name}}/db_tortoise", + "{{cookiecutter.project_name}}/db_psycopg" ] } diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/conftest.py b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/conftest.py index 4ca480c6b040f206334c339468fc95531603ff2b..a95af5f83ea1da516851824fab15135ca0a291cd 100644 --- a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/conftest.py +++ b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/conftest.py @@ -31,6 +31,12 @@ nest_asyncio.apply() from sqlalchemy.engine import create_engine from {{cookiecutter.project_name}}.db.config import database from {{cookiecutter.project_name}}.db.utils import create_database, drop_database +{%- elif cookiecutter.orm == "psycopg" %} +from psycopg import AsyncConnection +from psycopg_pool import AsyncConnectionPool + +from {{cookiecutter.project_name}}.db.dependencies import get_db_session + {%- endif %} @@ -146,6 +152,99 @@ async def initialize_db() -> AsyncGenerator[None, None]: await database.disconnect() drop_database() +{%- elif cookiecutter.orm == "psycopg" %} + +async def drop_db() -> None: + """Drops database after tests.""" + pool = AsyncConnectionPool(conninfo=str(settings.db_url.with_path("/postgres"))) + await pool.wait() + async with pool.connection() as conn: + await conn.set_autocommit(True) + await conn.execute( + "SELECT pg_terminate_backend(pg_stat_activity.pid) " # noqa: S608 + "FROM pg_stat_activity " + "WHERE pg_stat_activity.datname = %(dbname)s " + "AND pid <> pg_backend_pid();", + params={ + "dbname": settings.db_base, + } + ) + await conn.execute( + f"DROP DATABASE {settings.db_base}", + ) + await pool.close() + + +async def create_db() -> None: # noqa: WPS217 + """Creates database for tests.""" + pool = AsyncConnectionPool(conninfo=str(settings.db_url.with_path("/postgres"))) + await pool.wait() + async with pool.connection() as conn_check: + res = await conn_check.execute( + "SELECT 1 FROM pg_database WHERE datname=%(dbname)s", + params={ + "dbname": settings.db_base, + } + ) + db_exists = False + row = await res.fetchone() + if row is not None: + db_exists = row[0] + + if db_exists: + await drop_db() + + async with pool.connection() as conn_create: + await conn_create.set_autocommit(True) + await conn_create.execute( + f"CREATE DATABASE {settings.db_base};", + ) + await pool.close() + + +async def create_tables(connection: AsyncConnection[Any]) -> None: + """ + Create tables for your database. + + Since psycopg doesn't have migration tool, + you must create your tables for tests. + + :param connection: connection to database. + """ + {%- if cookiecutter.add_dummy == 'True' %} + await connection.execute( + "CREATE TABLE dummy (" + "id SERIAL primary key," + "name VARCHAR(200)" + ");" + ) + {%- endif %} + pass # noqa: WPS420 + + +@pytest.fixture +async def dbsession() -> AsyncGenerator[AsyncConnection[Any], None]: + """ + Creates connection to some test database. + + This connection must be used in tests and for application. + + :yield: connection to database. + """ + await create_db() + pool = AsyncConnectionPool(conninfo=str(settings.db_url)) + await pool.wait() + + async with pool.connection() as create_conn: + await create_tables(create_conn) + + try: + async with pool.connection() as conn: + yield conn + finally: + await pool.close() + await drop_db() + {%- endif %} @@ -167,6 +266,8 @@ async def fake_redis() -> AsyncGenerator[FakeRedis, None]: def fastapi_app( {%- if cookiecutter.orm == "sqlalchemy" %} dbsession: AsyncSession, + {%- elif cookiecutter.orm == "psycopg" %} + dbsession: AsyncConnection[Any], {%- endif %} {% if cookiecutter.enable_redis == "True" -%} fake_redis: FakeRedis, @@ -178,7 +279,7 @@ def fastapi_app( :return: fastapi app with mocked dependencies. """ application = get_app() - {% if cookiecutter.orm == "sqlalchemy" -%} + {% if cookiecutter.orm in ["sqlalchemy", "psycopg"] -%} application.dependency_overrides[get_db_session] = lambda: dbsession {%- endif %} {%- if cookiecutter.enable_redis == "True" %} diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/dao/dummy_dao.py b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/dao/dummy_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..02a4c09c308c66e829227ada6c6b19c9a33a7f3b --- /dev/null +++ b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/dao/dummy_dao.py @@ -0,0 +1,81 @@ +from termios import OFDEL +from {{cookiecutter.project_name}}.db.models.dummy_model import DummyModel +from typing import Any + +from fastapi import Depends +from psycopg import AsyncConnection +from psycopg.rows import class_row +from {{cookiecutter.project_name}}.db.dependencies import get_db_session +from typing import List, Optional + +class DummyDAO: + """Class for accessing dummy table.""" + + def __init__( + self, + connection: AsyncConnection[Any] = Depends(get_db_session), + ): + self.connection = connection + + + async def create_dummy_model(self, name: str) -> None: + """ + Creates new dummy in a database. + + :param name: name of a dummy. + """ + async with self.connection.cursor( + binary=True, + ) as cur: + await cur.execute( + "INSERT INTO dummy (name) VALUES (%(name)s);", + params={ + "name": name, + } + ) + + async def get_all_dummies(self, limit: int, offset: int) -> List[DummyModel]: + """ + Get all dummy models with limit/offset pagination. + + :param limit: limit of dummies. + :param offset: offset of dummies. + :return: stream of dummies. + """ + async with self.connection.cursor( + binary=True, + row_factory=class_row(DummyModel) + ) as cur: + res = await cur.execute( + "SELECT id, name FROM dummy LIMIT %(limit)s OFFSET %(offset)s;", + params={ + "limit": limit, + "offset": offset, + } + ) + return await res.fetchall() + + async def filter( + self, + name: Optional[str] = None, + ) -> List[DummyModel]: + """ + Get specific dummy model. + + :param name: name of dummy instance. + :return: dummy models. + """ + async with self.connection.cursor( + binary=True, + row_factory=class_row(DummyModel) + ) as cur: + if name is not None: + res = await cur.execute( + "SELECT id, name FROM dummy WHERE name=%(name)s;", + params={ + "name": name, + } + ) + else: + res = await cur.execute("SELECT id, name FROM dummy;") + return await res.fetchall() diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/dependencies.py b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..6cbf5d959a275af79ebd667f890b145873ba4c5e --- /dev/null +++ b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/dependencies.py @@ -0,0 +1,20 @@ +from typing import Any, AsyncGenerator + +from psycopg import AsyncConnection +from starlette.requests import Request + + +async def get_db_session( + request: Request, +) -> AsyncGenerator[AsyncConnection[Any], None]: + """ + Create and get database connection. + + :param request: current request. + :yield: database connection. + """ + async with request.app.state.db_pool.connection() as conn: + try: + yield conn + except Exception: # noqa: S110 + pass # noqa: WPS420 diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/models/dummy_model.py b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/models/dummy_model.py new file mode 100644 index 0000000000000000000000000000000000000000..93bcd833fc0a31b29d3b045017e08bc2b69fd123 --- /dev/null +++ b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/db_psycopg/models/dummy_model.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + +class DummyModel(BaseModel): + """Dummy model for database.""" + + id: int + name: str diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/tests/test_dummy.py b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/tests/test_dummy.py index 460f909e82652cab125aa071b8e16577a793804b..3948489a69918c11e948c498fe22f9220dd6ebf0 100644 --- a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/tests/test_dummy.py +++ b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/tests/test_dummy.py @@ -2,7 +2,12 @@ import uuid import pytest from httpx import AsyncClient from fastapi import FastAPI +from typing import Any +{%- if cookiecutter.orm == 'sqlalchemy' %} from sqlalchemy.ext.asyncio import AsyncSession +{%- elif cookiecutter.orm == 'psycopg' %} +from psycopg.connection_async import AsyncConnection +{%- endif %} from starlette import status from {{cookiecutter.project_name}}.db.models.dummy_model import DummyModel from {{cookiecutter.project_name}}.db.dao.dummy_dao import DummyDAO @@ -13,6 +18,8 @@ async def test_creation( client: AsyncClient, {%- if cookiecutter.orm == "sqlalchemy" %} dbsession: AsyncSession, + {%- elif cookiecutter.orm == "psycopg" %} + dbsession: AsyncConnection[Any], {%- endif %} ) -> None: """Tests dummy instance creation.""" @@ -22,7 +29,7 @@ async def test_creation( "name": test_name }) assert response.status_code == status.HTTP_200_OK - {%- if cookiecutter.orm == "sqlalchemy" %} + {%- if cookiecutter.orm in ["sqlalchemy", "psycopg"] %} dao = DummyDAO(dbsession) {%- elif cookiecutter.orm in ["tortoise", "ormar"] %} dao = DummyDAO() @@ -37,10 +44,12 @@ async def test_getting( client: AsyncClient, {%- if cookiecutter.orm == "sqlalchemy" %} dbsession: AsyncSession, + {%- elif cookiecutter.orm == "psycopg" %} + dbsession: AsyncConnection[Any], {%- endif %} ) -> None: """Tests dummy instance retrieval.""" - {%- if cookiecutter.orm == "sqlalchemy" %} + {%- if cookiecutter.orm in ["sqlalchemy", "psycopg"] %} dao = DummyDAO(dbsession) {%- elif cookiecutter.orm in ["tortoise", "ormar"] %} dao = DummyDAO() diff --git a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/lifetime.py b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/lifetime.py index 958c8e8e9f8bbb30e327da2c89e70787c0e301a2..4d67207b4608f122bd7d7855d2c2ed5279701b0b 100644 --- a/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/lifetime.py +++ b/fastapi_template/template/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/web/lifetime.py @@ -17,6 +17,20 @@ from {{cookiecutter.project_name}}.db.models import load_all_models {%- endif %} {%- endif %} + +{%- if cookiecutter.orm == "psycopg" %} +import psycopg_pool + +async def _setup_db(app: FastAPI) -> None: + """ + Creates connection pool for timescaledb. + + :param app: current FastAPI app. + """ + app.state.db_pool = psycopg_pool.AsyncConnectionPool(conninfo=str(settings.db_url)) + await app.state.db_pool.wait() +{%- endif %} + {%- if cookiecutter.orm == "sqlalchemy" %} from asyncio import current_task from sqlalchemy.ext.asyncio import ( @@ -26,7 +40,7 @@ from sqlalchemy.ext.asyncio import ( ) from sqlalchemy.orm import sessionmaker -{%- if cookiecutter.db_info.name != "none" and cookiecutter.enable_migrations == "False" %} +{%- if cookiecutter.enable_migrations == "False" %} from {{cookiecutter.project_name}}.db.meta import meta from {{cookiecutter.project_name}}.db.models import load_all_models {%- endif %} @@ -67,7 +81,7 @@ def _setup_redis(app: FastAPI) -> None: ) {%- endif %} -{%- if cookiecutter.db_info.name != "none" and cookiecutter.enable_migrations == "False" %} +{%- if cookiecutter.enable_migrations == "False" %} {%- if cookiecutter.orm in ["ormar", "sqlalchemy"] %} async def _create_tables() -> None: """Populates tables in the database.""" @@ -102,6 +116,8 @@ def startup(app: FastAPI) -> Callable[[], Awaitable[None]]: _setup_db(app) {%- elif cookiecutter.orm == "ormar" %} await database.connect() + {%- elif cookiecutter.orm == "psycopg" %} + await _setup_db(app) {%- endif %} {%- if cookiecutter.db_info.name != "none" and cookiecutter.enable_migrations == "False" %} {%- if cookiecutter.orm in ["ormar", "sqlalchemy"] %} @@ -129,6 +145,8 @@ def shutdown(app: FastAPI) -> Callable[[], Awaitable[None]]: await app.state.db_engine.dispose() {% elif cookiecutter.orm == "ormar" %} await database.disconnect() + {%- elif cookiecutter.orm == "psycopg" %} + await app.state.db_pool.close() {%- endif %} {%- if cookiecutter.enable_redis == "True" %} await app.state.redis_pool.disconnect() diff --git a/fastapi_template/tests/test_generator.py b/fastapi_template/tests/test_generator.py index 5492392d4db67e8de1ba5bb60ab7a9daa257a42e..1d1f6c64330a2a7751762ddb975cee9df722c868 100644 --- a/fastapi_template/tests/test_generator.py +++ b/fastapi_template/tests/test_generator.py @@ -39,7 +39,16 @@ def test_default_with_db(default_context: BuilderContext, db: DatabaseType, orm: run_default_check(init_context(default_context, db, orm)) -@pytest.mark.parametrize("orm", [ORM.sqlalchemy, ORM.tortoise, ORM.ormar]) +@pytest.mark.parametrize( + "orm", + [ + ORM.psycopg, + ] +) +def test_pg_drivers(default_context: BuilderContext, orm: ORM): + run_default_check(init_context(default_context, DatabaseType.postgresql, orm)) + +@pytest.mark.parametrize("orm", [ORM.sqlalchemy, ORM.tortoise, ORM.ormar, ORM.psycopg]) def test_without_routers(default_context: BuilderContext, orm: ORM): context = init_context(default_context, DatabaseType.postgresql, orm) context.enable_routers = False @@ -58,7 +67,7 @@ def test_with_selfhosted_swagger(default_context: BuilderContext): run_default_check(default_context) -@pytest.mark.parametrize("orm", [ORM.sqlalchemy, ORM.tortoise, ORM.ormar]) +@pytest.mark.parametrize("orm", [ORM.sqlalchemy, ORM.tortoise, ORM.ormar, ORM.psycopg]) def test_without_dummy(default_context: BuilderContext, orm: ORM): context = init_context(default_context, DatabaseType.postgresql, orm) context.add_dummy = False