diff --git a/cookiecutter.json b/cookiecutter.json index 4a73c333e6ea068e0601682b6ca8fe6484e47227..52c68cb190d76fc184368aaa877a4ba1d2225eb8 100644 --- a/cookiecutter.json +++ b/cookiecutter.json @@ -15,10 +15,14 @@ true, false ], - "add_scheduler":[ + "add_scheduler": [ true, false ], + "pg_driver": [ + "aiopg", + "asyncpgsa" + ], "_extensions": [ "cookiecutter.extensions.RandomStringExtension" ] diff --git a/{{cookiecutter.project_name}}/pyproject.toml b/{{cookiecutter.project_name}}/pyproject.toml index f91f30dbcdfa9cf428800b9d9249e75c98b15a6c..ed6319a320f1433c50ea9695a5d29f9eb8fe4ef5 100644 --- a/{{cookiecutter.project_name}}/pyproject.toml +++ b/{{cookiecutter.project_name}}/pyproject.toml @@ -12,8 +12,12 @@ sqlalchemy = "^1.3.19" loguru = "^0.5.2" alembic = "^1.4.3" httpx = "^0.14.3" -psycopg2 = "^2.8.6" +{% if cookiecutter.pg_driver == "aiopg" -%} aiopg = "^1.0.0" +{% else %} +psycopg2 = "^2.8.6" +asyncpgsa = "^0.26.3" +{% endif %} {% if cookiecutter.add_redis == "True" -%} aioredis = "^1.3.1" {% endif %} diff --git a/{{cookiecutter.project_name}}/scheduler.py b/{{cookiecutter.project_name}}/scheduler.py index daec45feaa8b08ab5c43ab88f4406fd49aac2029..513db02465a29158445c7224cafacdcaeb6aa8e8 100644 --- a/{{cookiecutter.project_name}}/scheduler.py +++ b/{{cookiecutter.project_name}}/scheduler.py @@ -1,5 +1,4 @@ import asyncio -import time import aioschedule as schedule from loguru import logger diff --git a/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py b/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py index b54ba054364310f64eec64e801d8826395c65ba0..67ed75c61c0454955cc57c944290b1956eeaed8c 100644 --- a/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py +++ b/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py @@ -1,7 +1,10 @@ import uuid from typing import Optional -from pydantic import Field, BaseConfig +from pydantic import Field +{% if cookiecutter.pg_driver == "aiopg" -%} +from pydantic import BaseConfig +{% endif %} from pydantic.main import BaseModel @@ -13,8 +16,10 @@ class BaseDummyModel(BaseModel): class ReturnDummyModel(BaseDummyModel): id: uuid.UUID + {% if cookiecutter.pg_driver == "aiopg" -%} class Config(BaseConfig): orm_mode = True + {% endif %} class UpdateDummyModel(BaseModel): diff --git a/{{cookiecutter.project_name}}/src/models/dummy_db_model.py b/{{cookiecutter.project_name}}/src/models/dummy_db_model.py index 4b51e02acd1169ac599f38f5d9506fda947837b2..27677ee7f9f7935c68d61970f58f66b2585e396c 100644 --- a/{{cookiecutter.project_name}}/src/models/dummy_db_model.py +++ b/{{cookiecutter.project_name}}/src/models/dummy_db_model.py @@ -3,9 +3,13 @@ from typing import Optional, List from sqlalchemy import Column, String from sqlalchemy.dialects.postgresql import UUID - +{% if cookiecutter.pg_driver == "aiopg" -%} from src.services.db import Base, db_engine - +{% else %} +from src.services.db import Base +from asyncpgsa import pg +from asyncpg import Record +{% endif %} class DummyDBModel(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -19,19 +23,28 @@ class DummyDBModel(Base): name: str, surname: str, ) -> None: + query = cls.insert_query( + name=name, + surname=surname, + ) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: - await conn.execute( - cls.insert_query( - name=name, - surname=surname, - ) - ) + await conn.execute(query) + {% else %} + await pg.execute(query) + {% endif %} + + @classmethod async def delete(cls, dummy_id: uuid.UUID) -> None: query = cls.delete_query().where(cls.id == dummy_id) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: await conn.execute(query) + {% else %} + await pg.fetchrow(query) + {% endif %} @classmethod async def update(cls, @@ -46,15 +59,19 @@ class DummyDBModel(Base): if surname: new_values[cls.surname] = surname query = cls.update_query().where(cls.id == dummy_id).values(new_values) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: await conn.execute(query) + {% else %} + await pg.fetchrow(query) + {% endif %} @classmethod async def filter(cls, *, dummy_id: Optional[uuid.UUID] = None, name: Optional[str] = None, surname: Optional[str] = None - ) -> List["DummyDBModel"]: + ) -> {% if cookiecutter.pg_driver == "aiopg" -%}List["DummyDBModel"]{% else %}List[Record]{% endif %}: query = cls.select_query() if dummy_id: query = query.where(cls.id == dummy_id) @@ -62,6 +79,10 @@ class DummyDBModel(Base): query = query.where(cls.name == name) if surname: query = query.where(cls.surname == surname) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: cursor = await conn.execute(query) return await cursor.fetchall() + {% else %} + return await pg.fetch(query) + {% endif %} diff --git a/{{cookiecutter.project_name}}/src/server.py b/{{cookiecutter.project_name}}/src/server.py index c56950255b9a11e339794ccdb78719b64b74198d..03c59a42051811cb34a9810d9eac83aaf1c54faf 100644 --- a/{{cookiecutter.project_name}}/src/server.py +++ b/{{cookiecutter.project_name}}/src/server.py @@ -7,7 +7,12 @@ from loguru import logger from starlette.requests import Request from src.api import api_router +{% if cookiecutter.pg_driver == "aiopg" -%} from src.services.db import db_engine +{% else %} +from src.services.db import db_url +from asyncpgsa import pg +{% endif %} {% if cookiecutter.add_redis == "True" -%} from src.services.redis import redis {% endif %} @@ -36,15 +41,24 @@ async def startup() -> None: {% if cookiecutter.add_redis == "True" -%} await redis.create_pool() {% endif %} + {% if cookiecutter.pg_driver == "aiopg" -%} await db_engine.connect() + {% else %} + await pg.init(str(db_url)) + {% endif %} @app.on_event("shutdown") async def shutdown() -> None: - await db_engine.close() {% if cookiecutter.add_redis == "True" -%} await redis.shutdown() {% endif %} + {% if cookiecutter.pg_driver == "aiopg" -%} + await db_engine.close() + {% else %} + await pg.pool.close() + {% endif %} + @app.middleware("http") diff --git a/{{cookiecutter.project_name}}/src/services/db.py b/{{cookiecutter.project_name}}/src/services/db.py index 96a87c73473e7490ca81d5ab4ffca7612272a883..b44fe507d976c7106a16dc461922d00e263b7b37 100644 --- a/{{cookiecutter.project_name}}/src/services/db.py +++ b/{{cookiecutter.project_name}}/src/services/db.py @@ -1,7 +1,10 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict import sqlalchemy as sa +{% if cookiecutter.pg_driver == "aiopg" -%} +from typing import Optional from aiopg.sa import Engine, create_engine +{% endif %} from sqlalchemy import MetaData, Table from sqlalchemy.engine.url import URL, make_url from sqlalchemy.ext.declarative import as_declarative, declared_attr @@ -22,30 +25,6 @@ db_url = make_url( meta = MetaData() - -class DBEngine: - def __init__(self, connection_url: str) -> None: - self.dsn = connection_url - self.engine: Optional[Engine] = None - - async def connect(self) -> None: - self.engine = await create_engine(dsn=self.dsn, maxsize=100) - - @property - def client(self) -> Engine: - if self.engine: - return self.engine - raise Exception("Not connected to database") - - async def close(self) -> None: - if self.engine: - self.engine.close() - await self.engine.wait_closed() - - -db_engine = DBEngine(str(db_url)) - - @as_declarative(metadata=meta) class Base: """Base class for all models""" @@ -92,3 +71,27 @@ class Base: def as_dict(self) -> Dict[str, Any]: return {c.name: getattr(self, c.key) for c in self.__table__.columns} + +{% if cookiecutter.pg_driver == "aiopg" -%} +class DBEngine: + def __init__(self, connection_url: str) -> None: + self.dsn = connection_url + self.engine: Optional[Engine] = None + + async def connect(self) -> None: + self.engine = await create_engine(dsn=self.dsn, maxsize=100) + + @property + def client(self) -> Engine: + if self.engine: + return self.engine + raise Exception("Not connected to database") + + async def close(self) -> None: + if self.engine: + self.engine.close() + await self.engine.wait_closed() + + +db_engine = DBEngine(str(db_url)) +{% endif %}