From d2761278550cb46627afe3b4d7562654842d3af0 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin <win10@list.ru> Date: Fri, 9 Oct 2020 01:33:59 +0400 Subject: [PATCH] Added database driver selector. Signed-off-by: Pavel Kirilin <win10@list.ru> --- cookiecutter.json | 6 ++- {{cookiecutter.project_name}}/pyproject.toml | 6 ++- {{cookiecutter.project_name}}/scheduler.py | 1 - .../src/api/dummy_db/schema.py | 7 ++- .../src/models/dummy_db_model.py | 39 ++++++++++---- {{cookiecutter.project_name}}/src/server.py | 16 +++++- .../src/services/db.py | 53 ++++++++++--------- 7 files changed, 89 insertions(+), 39 deletions(-) diff --git a/cookiecutter.json b/cookiecutter.json index 4a73c33..52c68cb 100644 --- a/cookiecutter.json +++ b/cookiecutter.json @@ -15,10 +15,14 @@ true, false ], - "add_scheduler":[ + "add_scheduler": [ true, false ], + "pg_driver": [ + "aiopg", + "asyncpgsa" + ], "_extensions": [ "cookiecutter.extensions.RandomStringExtension" ] diff --git a/{{cookiecutter.project_name}}/pyproject.toml b/{{cookiecutter.project_name}}/pyproject.toml index f91f30d..ed6319a 100644 --- a/{{cookiecutter.project_name}}/pyproject.toml +++ b/{{cookiecutter.project_name}}/pyproject.toml @@ -12,8 +12,12 @@ sqlalchemy = "^1.3.19" loguru = "^0.5.2" alembic = "^1.4.3" httpx = "^0.14.3" -psycopg2 = "^2.8.6" +{% if cookiecutter.pg_driver == "aiopg" -%} aiopg = "^1.0.0" +{% else %} +psycopg2 = "^2.8.6" +asyncpgsa = "^0.26.3" +{% endif %} {% if cookiecutter.add_redis == "True" -%} aioredis = "^1.3.1" {% endif %} diff --git a/{{cookiecutter.project_name}}/scheduler.py b/{{cookiecutter.project_name}}/scheduler.py index daec45f..513db02 100644 --- a/{{cookiecutter.project_name}}/scheduler.py +++ b/{{cookiecutter.project_name}}/scheduler.py @@ -1,5 +1,4 @@ import asyncio -import time import aioschedule as schedule from loguru import logger diff --git a/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py b/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py index b54ba05..67ed75c 100644 --- a/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py +++ b/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py @@ -1,7 +1,10 @@ import uuid from typing import Optional -from pydantic import Field, BaseConfig +from pydantic import Field +{% if cookiecutter.pg_driver == "aiopg" -%} +from pydantic import BaseConfig +{% endif %} from pydantic.main import BaseModel @@ -13,8 +16,10 @@ class BaseDummyModel(BaseModel): class ReturnDummyModel(BaseDummyModel): id: uuid.UUID + {% if cookiecutter.pg_driver == "aiopg" -%} class Config(BaseConfig): orm_mode = True + {% endif %} class UpdateDummyModel(BaseModel): diff --git a/{{cookiecutter.project_name}}/src/models/dummy_db_model.py b/{{cookiecutter.project_name}}/src/models/dummy_db_model.py index 4b51e02..27677ee 100644 --- a/{{cookiecutter.project_name}}/src/models/dummy_db_model.py +++ b/{{cookiecutter.project_name}}/src/models/dummy_db_model.py @@ -3,9 +3,13 @@ from typing import Optional, List from sqlalchemy import Column, String from sqlalchemy.dialects.postgresql import UUID - +{% if cookiecutter.pg_driver == "aiopg" -%} from src.services.db import Base, db_engine - +{% else %} +from src.services.db import Base +from asyncpgsa import pg +from asyncpg import Record +{% endif %} class DummyDBModel(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -19,19 +23,28 @@ class DummyDBModel(Base): name: str, surname: str, ) -> None: + query = cls.insert_query( + name=name, + surname=surname, + ) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: - await conn.execute( - cls.insert_query( - name=name, - surname=surname, - ) - ) + await conn.execute(query) + {% else %} + await pg.execute(query) + {% endif %} + + @classmethod async def delete(cls, dummy_id: uuid.UUID) -> None: query = cls.delete_query().where(cls.id == dummy_id) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: await conn.execute(query) + {% else %} + await pg.fetchrow(query) + {% endif %} @classmethod async def update(cls, @@ -46,15 +59,19 @@ class DummyDBModel(Base): if surname: new_values[cls.surname] = surname query = cls.update_query().where(cls.id == dummy_id).values(new_values) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: await conn.execute(query) + {% else %} + await pg.fetchrow(query) + {% endif %} @classmethod async def filter(cls, *, dummy_id: Optional[uuid.UUID] = None, name: Optional[str] = None, surname: Optional[str] = None - ) -> List["DummyDBModel"]: + ) -> {% if cookiecutter.pg_driver == "aiopg" -%}List["DummyDBModel"]{% else %}List[Record]{% endif %}: query = cls.select_query() if dummy_id: query = query.where(cls.id == dummy_id) @@ -62,6 +79,10 @@ class DummyDBModel(Base): query = query.where(cls.name == name) if surname: query = query.where(cls.surname == surname) + {% if cookiecutter.pg_driver == "aiopg" -%} async with db_engine.client.acquire() as conn: cursor = await conn.execute(query) return await cursor.fetchall() + {% else %} + return await pg.fetch(query) + {% endif %} diff --git a/{{cookiecutter.project_name}}/src/server.py b/{{cookiecutter.project_name}}/src/server.py index c569502..03c59a4 100644 --- a/{{cookiecutter.project_name}}/src/server.py +++ b/{{cookiecutter.project_name}}/src/server.py @@ -7,7 +7,12 @@ from loguru import logger from starlette.requests import Request from src.api import api_router +{% if cookiecutter.pg_driver == "aiopg" -%} from src.services.db import db_engine +{% else %} +from src.services.db import db_url +from asyncpgsa import pg +{% endif %} {% if cookiecutter.add_redis == "True" -%} from src.services.redis import redis {% endif %} @@ -36,15 +41,24 @@ async def startup() -> None: {% if cookiecutter.add_redis == "True" -%} await redis.create_pool() {% endif %} + {% if cookiecutter.pg_driver == "aiopg" -%} await db_engine.connect() + {% else %} + await pg.init(str(db_url)) + {% endif %} @app.on_event("shutdown") async def shutdown() -> None: - await db_engine.close() {% if cookiecutter.add_redis == "True" -%} await redis.shutdown() {% endif %} + {% if cookiecutter.pg_driver == "aiopg" -%} + await db_engine.close() + {% else %} + await pg.pool.close() + {% endif %} + @app.middleware("http") diff --git a/{{cookiecutter.project_name}}/src/services/db.py b/{{cookiecutter.project_name}}/src/services/db.py index 96a87c7..b44fe50 100644 --- a/{{cookiecutter.project_name}}/src/services/db.py +++ b/{{cookiecutter.project_name}}/src/services/db.py @@ -1,7 +1,10 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict import sqlalchemy as sa +{% if cookiecutter.pg_driver == "aiopg" -%} +from typing import Optional from aiopg.sa import Engine, create_engine +{% endif %} from sqlalchemy import MetaData, Table from sqlalchemy.engine.url import URL, make_url from sqlalchemy.ext.declarative import as_declarative, declared_attr @@ -22,30 +25,6 @@ db_url = make_url( meta = MetaData() - -class DBEngine: - def __init__(self, connection_url: str) -> None: - self.dsn = connection_url - self.engine: Optional[Engine] = None - - async def connect(self) -> None: - self.engine = await create_engine(dsn=self.dsn, maxsize=100) - - @property - def client(self) -> Engine: - if self.engine: - return self.engine - raise Exception("Not connected to database") - - async def close(self) -> None: - if self.engine: - self.engine.close() - await self.engine.wait_closed() - - -db_engine = DBEngine(str(db_url)) - - @as_declarative(metadata=meta) class Base: """Base class for all models""" @@ -92,3 +71,27 @@ class Base: def as_dict(self) -> Dict[str, Any]: return {c.name: getattr(self, c.key) for c in self.__table__.columns} + +{% if cookiecutter.pg_driver == "aiopg" -%} +class DBEngine: + def __init__(self, connection_url: str) -> None: + self.dsn = connection_url + self.engine: Optional[Engine] = None + + async def connect(self) -> None: + self.engine = await create_engine(dsn=self.dsn, maxsize=100) + + @property + def client(self) -> Engine: + if self.engine: + return self.engine + raise Exception("Not connected to database") + + async def close(self) -> None: + if self.engine: + self.engine.close() + await self.engine.wait_closed() + + +db_engine = DBEngine(str(db_url)) +{% endif %} -- GitLab