diff --git a/cookiecutter.json b/cookiecutter.json index 52c68cb190d76fc184368aaa877a4ba1d2225eb8..c1e99d6d8e3385788e8114d90dba8e83d0fb7056 100644 --- a/cookiecutter.json +++ b/cookiecutter.json @@ -19,10 +19,6 @@ true, false ], - "pg_driver": [ - "aiopg", - "asyncpgsa" - ], "_extensions": [ "cookiecutter.extensions.RandomStringExtension" ] diff --git a/{{cookiecutter.project_name}}/migrations/env.py b/{{cookiecutter.project_name}}/migrations/env.py index 0216029d3c2404f34f07777286de4c5fef8d72df..9df59ac693fc42d43df76687484dcf8022784e65 100644 --- a/{{cookiecutter.project_name}}/migrations/env.py +++ b/{{cookiecutter.project_name}}/migrations/env.py @@ -23,7 +23,7 @@ fileConfig(config.config_file_name) # from myapp import mymodel # target_metadata = mymodel.Base.metadata # noqa -from src.services.db import meta as target_metadata +from src.services.db import db_meta as target_metadata from src.models import * # isort:skip diff --git a/{{cookiecutter.project_name}}/pyproject.toml b/{{cookiecutter.project_name}}/pyproject.toml index cccff169f0fbd8b7c362e775b5b3efc14ca78383..dfc579c2123a9867eae42e4e27cdc70e1d297b85 100644 --- a/{{cookiecutter.project_name}}/pyproject.toml +++ b/{{cookiecutter.project_name}}/pyproject.toml @@ -15,12 +15,7 @@ httpx = "^0.14.3" ujson = "^4.0.1" gunicorn = "^20.0.4" httptools = "^0.1.1" -{% if cookiecutter.pg_driver == "aiopg" -%} aiopg = "^1.0.0" -{% else %} -psycopg2 = "^2.8.6" -asyncpgsa = "^0.26.3" -{% endif %} {% if cookiecutter.add_redis == "True" -%} aioredis = "^1.3.1" {% endif %} diff --git a/{{cookiecutter.project_name}}/src/api/dummy_db/routes.py b/{{cookiecutter.project_name}}/src/api/dummy_db/routes.py index dd525823eeed7bb5a26a89508c5c0ccef6fdc9f1..0efaf8c36bf826e72a87625b60d5ed800c18fd8d 100644 --- a/{{cookiecutter.project_name}}/src/api/dummy_db/routes.py +++ b/{{cookiecutter.project_name}}/src/api/dummy_db/routes.py @@ -1,42 +1,52 @@ import uuid -from typing import List, Optional +from typing import Optional -from fastapi import APIRouter +from fastapi import APIRouter, Depends from src.api.dummy_db.schema import ( BaseDummyModel, UpdateDummyModel, - ReturnDummyModel + GetDummyResponse ) from src.models import DummyDBModel +from src.services.db.session import Session, db_session router = APIRouter() URL_PREFIX = "/dummy_db_obj" @router.put("/") -async def create_dummy(dummy_obj: BaseDummyModel) -> None: - await DummyDBModel.create(**dummy_obj.dict()) +async def create_dummy(dummy_obj: BaseDummyModel, session: Session = Depends(db_session)) -> None: + await session.execute(DummyDBModel.create(**dummy_obj.dict())) @router.post("/{dummy_id}") -async def update_dummy_model(dummy_id: uuid.UUID, new_values: UpdateDummyModel) -> None: - await DummyDBModel.update(dummy_id, **new_values.dict()) +async def update_dummy_model( + dummy_id: uuid.UUID, + new_values: UpdateDummyModel, + session: Session = Depends(db_session) +) -> None: + await session.execute(DummyDBModel.update(dummy_id, **new_values.dict())) @router.delete("/{dummy_id}") -async def delete_dummy_model(dummy_id: uuid.UUID) -> None: - await DummyDBModel.delete(dummy_id) +async def delete_dummy_model(dummy_id: uuid.UUID, session: Session = Depends(db_session)) -> None: + await session.execute(DummyDBModel.delete(dummy_id)) -@router.get("/", response_model=List[ReturnDummyModel]) +@router.get("/", response_model=GetDummyResponse) async def filter_dummy_models( dummy_id: Optional[uuid.UUID] = None, name: Optional[str] = None, - surname: Optional[str] = None -) -> List[DummyDBModel]: - return await DummyDBModel.filter( + surname: Optional[str] = None, + session: Session = Depends(db_session) +) -> GetDummyResponse: + filter_query = DummyDBModel.filter( dummy_id=dummy_id, name=name, surname=surname ) + results = await session.fetchall(filter_query) + return GetDummyResponse( + results=results + ) diff --git a/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py b/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py index 67ed75c61c0454955cc57c944290b1956eeaed8c..08778b64e84ef8c1580f3818cb49f1358988894c 100644 --- a/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py +++ b/{{cookiecutter.project_name}}/src/api/dummy_db/schema.py @@ -1,30 +1,32 @@ import uuid -from typing import Optional +from datetime import datetime +from typing import Optional, List -from pydantic import Field -{% if cookiecutter.pg_driver == "aiopg" -%} -from pydantic import BaseConfig -{% endif %} +from pydantic import Field, BaseConfig from pydantic.main import BaseModel class BaseDummyModel(BaseModel): - name: str - surname: str + name: str = Field(example="Dummy name") + surname: str = Field(example="Dummy surname") class ReturnDummyModel(BaseDummyModel): id: uuid.UUID + created_at: datetime + updated_at: datetime - {% if cookiecutter.pg_driver == "aiopg" -%} class Config(BaseConfig): orm_mode = True - {% endif %} + + +class GetDummyResponse(BaseModel): + results: List[ReturnDummyModel] class UpdateDummyModel(BaseModel): - name: Optional[str] = Field(default=None) - surname: Optional[str] = Field(default=None) + name: Optional[str] = Field(default=None, example="New name") + surname: Optional[str] = Field(default=None, example="New surname") class DummyFiltersModel(BaseModel): diff --git a/{{cookiecutter.project_name}}/src/models/dummy_db_model.py b/{{cookiecutter.project_name}}/src/models/dummy_db_model.py index 27677ee7f9f7935c68d61970f58f66b2585e396c..2e188cf270118ffac94cf2b12143d23913ab2055 100644 --- a/{{cookiecutter.project_name}}/src/models/dummy_db_model.py +++ b/{{cookiecutter.project_name}}/src/models/dummy_db_model.py @@ -1,15 +1,11 @@ import uuid -from typing import Optional, List +from typing import Optional -from sqlalchemy import Column, String +from sqlalchemy import Column, String, sql from sqlalchemy.dialects.postgresql import UUID -{% if cookiecutter.pg_driver == "aiopg" -%} -from src.services.db import Base, db_engine -{% else %} + from src.services.db import Base -from asyncpgsa import pg -from asyncpg import Record -{% endif %} + class DummyDBModel(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -17,61 +13,41 @@ class DummyDBModel(Base): surname = Column(String, nullable=False, index=True) @classmethod - async def create( + def create( cls, *, name: str, surname: str, - ) -> None: - query = cls.insert_query( + ) -> sql.Insert: + return cls.insert_query( name=name, surname=surname, ) - {% if cookiecutter.pg_driver == "aiopg" -%} - async with db_engine.client.acquire() as conn: - await conn.execute(query) - {% else %} - await pg.execute(query) - {% endif %} - - @classmethod - async def delete(cls, dummy_id: uuid.UUID) -> None: - query = cls.delete_query().where(cls.id == dummy_id) - {% if cookiecutter.pg_driver == "aiopg" -%} - async with db_engine.client.acquire() as conn: - await conn.execute(query) - {% else %} - await pg.fetchrow(query) - {% endif %} + def delete(cls, dummy_id: uuid.UUID) -> sql.Delete: + return cls.delete_query().where(cls.id == dummy_id) @classmethod - async def update(cls, - dummy_id: uuid.UUID, - *, - name: Optional[str] = None, - surname: Optional[str] = None - ) -> None: + def update(cls, + dummy_id: uuid.UUID, + *, + name: Optional[str] = None, + surname: Optional[str] = None + ) -> sql.Update: new_values = {} if name: new_values[cls.name] = name if surname: new_values[cls.surname] = surname - query = cls.update_query().where(cls.id == dummy_id).values(new_values) - {% if cookiecutter.pg_driver == "aiopg" -%} - async with db_engine.client.acquire() as conn: - await conn.execute(query) - {% else %} - await pg.fetchrow(query) - {% endif %} + return cls.update_query().where(cls.id == dummy_id).values(new_values) @classmethod - async def filter(cls, *, - dummy_id: Optional[uuid.UUID] = None, - name: Optional[str] = None, - surname: Optional[str] = None - ) -> {% if cookiecutter.pg_driver == "aiopg" -%}List["DummyDBModel"]{% else %}List[Record]{% endif %}: + def filter(cls, *, + dummy_id: Optional[uuid.UUID] = None, + name: Optional[str] = None, + surname: Optional[str] = None + ) -> sql.Select: query = cls.select_query() if dummy_id: query = query.where(cls.id == dummy_id) @@ -79,10 +55,4 @@ class DummyDBModel(Base): query = query.where(cls.name == name) if surname: query = query.where(cls.surname == surname) - {% if cookiecutter.pg_driver == "aiopg" -%} - async with db_engine.client.acquire() as conn: - cursor = await conn.execute(query) - return await cursor.fetchall() - {% else %} - return await pg.fetch(query) - {% endif %} + return query diff --git a/{{cookiecutter.project_name}}/src/server.py b/{{cookiecutter.project_name}}/src/server.py index 8eb633a7c3dadc68c9c0bf56e752730152997d88..f451117314cf4ebd220adef39f6efbec8f7bd007 100644 --- a/{{cookiecutter.project_name}}/src/server.py +++ b/{{cookiecutter.project_name}}/src/server.py @@ -8,12 +8,7 @@ from starlette.requests import Request from starlette.responses import UJSONResponse from src.api import api_router -{% if cookiecutter.pg_driver == "aiopg" -%} from src.services.db import db_engine -{% else %} -from src.services.db import db_url -from asyncpgsa import pg -{% endif %} {% if cookiecutter.add_redis == "True" -%} from src.services.redis import redis {% endif %} @@ -43,11 +38,7 @@ async def startup() -> None: {% if cookiecutter.add_redis == "True" -%} await redis.create_pool() {% endif %} - {% if cookiecutter.pg_driver == "aiopg" -%} await db_engine.connect() - {% else %} - await pg.init(str(db_url)) - {% endif %} @app.on_event("shutdown") @@ -55,11 +46,7 @@ async def shutdown() -> None: {% if cookiecutter.add_redis == "True" -%} await redis.shutdown() {% endif %} - {% if cookiecutter.pg_driver == "aiopg" -%} await db_engine.close() - {% else %} - await pg.pool.close() - {% endif %} diff --git a/{{cookiecutter.project_name}}/src/services/db/__init__.py b/{{cookiecutter.project_name}}/src/services/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11f3cf1b6ab272a548263d0c15f56b493618a20d --- /dev/null +++ b/{{cookiecutter.project_name}}/src/services/db/__init__.py @@ -0,0 +1,5 @@ +from src.services.db.base import Base +from src.services.db.db_meta import meta +from src.services.db.engine import db_engine, db_url + +__all__ = ["db_engine", "db_url", "meta", "Base"] diff --git a/{{cookiecutter.project_name}}/src/services/db.py b/{{cookiecutter.project_name}}/src/services/db/base.py similarity index 50% rename from {{cookiecutter.project_name}}/src/services/db.py rename to {{cookiecutter.project_name}}/src/services/db/base.py index b44fe507d976c7106a16dc461922d00e263b7b37..1af33bef3b172c7c9df9d81b1f35fcdec2759cda 100644 --- a/{{cookiecutter.project_name}}/src/services/db.py +++ b/{{cookiecutter.project_name}}/src/services/db/base.py @@ -1,29 +1,13 @@ -from typing import Any, Dict +import uuid +from typing import Any, Dict, Optional, Tuple, Type, Union import sqlalchemy as sa -{% if cookiecutter.pg_driver == "aiopg" -%} -from typing import Optional -from aiopg.sa import Engine, create_engine -{% endif %} -from sqlalchemy import MetaData, Table -from sqlalchemy.engine.url import URL, make_url +from sqlalchemy import Column, Table +from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.declarative import as_declarative, declared_attr from sqlalchemy.orm.attributes import InstrumentedAttribute +from src.services.db.db_meta import meta -from src.settings import settings - -db_url = make_url( - URL( - drivername=settings.db_driver, - username=settings.postgres_user, - password=settings.postgres_password, - host=settings.postgres_host, - port=settings.postgres_port, - database=settings.postgres_db, - ) -) - -meta = MetaData() @as_declarative(metadata=meta) class Base: @@ -31,11 +15,20 @@ class Base: __name__: str __table__: Table + __table_args__: Tuple[Any, ...] @declared_attr def __tablename__(self) -> str: return self.__name__.lower() + @declared_attr + def id(self) -> sa.Column: + return Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4 + ) + @declared_attr def created_at(self) -> sa.Column: return sa.Column( @@ -54,8 +47,26 @@ class Base: ) @classmethod - def select_query(cls, *columns: InstrumentedAttribute) -> sa.sql.Select: - return sa.select(columns or [cls]) + async def get( + cls, pk: Union[uuid.UUID, str], *fields: InstrumentedAttribute + ) -> Optional[Any]: + return cls.select_query(*fields).where(cls.id == pk) + + @classmethod + async def exists(cls, pk: Union[uuid.UUID, str]) -> sa.sql.Select: + return sa.exists().where(cls.id == pk) + + @classmethod + async def delete(cls, pk: uuid.UUID) -> sa.sql.Delete: + return cls.delete_query().where(cls.id == pk) + + @classmethod + def select_query( + cls, + *columns: Union[InstrumentedAttribute, Type["Base"]], + use_labels: bool = False, + ) -> sa.sql.Select: + return sa.select(columns or [cls], use_labels=use_labels) @classmethod def insert_query(cls, **values: Any) -> sa.sql.Insert: @@ -71,27 +82,3 @@ class Base: def as_dict(self) -> Dict[str, Any]: return {c.name: getattr(self, c.key) for c in self.__table__.columns} - -{% if cookiecutter.pg_driver == "aiopg" -%} -class DBEngine: - def __init__(self, connection_url: str) -> None: - self.dsn = connection_url - self.engine: Optional[Engine] = None - - async def connect(self) -> None: - self.engine = await create_engine(dsn=self.dsn, maxsize=100) - - @property - def client(self) -> Engine: - if self.engine: - return self.engine - raise Exception("Not connected to database") - - async def close(self) -> None: - if self.engine: - self.engine.close() - await self.engine.wait_closed() - - -db_engine = DBEngine(str(db_url)) -{% endif %} diff --git a/{{cookiecutter.project_name}}/src/services/db/db_meta.py b/{{cookiecutter.project_name}}/src/services/db/db_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..5d24aa14477612afbd5eae9d614842d30fa274c5 --- /dev/null +++ b/{{cookiecutter.project_name}}/src/services/db/db_meta.py @@ -0,0 +1,3 @@ +from sqlalchemy import MetaData + +meta = MetaData() diff --git a/{{cookiecutter.project_name}}/src/services/db/engine.py b/{{cookiecutter.project_name}}/src/services/db/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..83ee83ab32711e8388d7fdc8b095bdc856dfa5ed --- /dev/null +++ b/{{cookiecutter.project_name}}/src/services/db/engine.py @@ -0,0 +1,41 @@ +from typing import Optional + +from aiopg.sa import Engine, create_engine +from sqlalchemy.engine.url import URL, make_url + +from src.settings import settings + +db_url = make_url( + URL( + drivername=settings.db_driver, + username=settings.postgres_user, + password=settings.postgres_password, + host=settings.postgres_host, + port=settings.postgres_port, + database=settings.postgres_db, + ) +) + + +class DBEngine: + def __init__(self, connection_url: str) -> None: + self.dsn = connection_url + self.engine: Optional[Engine] = None + + async def connect(self) -> None: + self.engine = await create_engine(dsn=self.dsn) + + @property + def client(self) -> Engine: + if self.engine is None: + raise ValueError("Not connected to database") + return self.engine + + async def close(self) -> None: + if self.engine: + self.engine.close() + await self.engine.wait_closed() + raise Exception("Not connected to database") + + +db_engine = DBEngine(str(db_url)) diff --git a/{{cookiecutter.project_name}}/src/services/db/session.py b/{{cookiecutter.project_name}}/src/services/db/session.py new file mode 100644 index 0000000000000000000000000000000000000000..34d15ab4faa4457d0c2fb29ff9d1ed98d0aa96a3 --- /dev/null +++ b/{{cookiecutter.project_name}}/src/services/db/session.py @@ -0,0 +1,34 @@ +from typing import Any, List, AsyncGenerator + +from aiopg.sa import SAConnection + +from src.services.db import db_engine + + +class Session: + def __init__(self, connection: SAConnection): + self.connection = connection + + async def execute(self, query: Any) -> Any: + return await self.connection.execute(query) + + async def fetchone(self, query: Any) -> Any: + cursor = await self.connection.execute(query) + return cursor.fetchone() + + async def scalar(self, query: Any) -> Any: + result = await self.fetchone(query) + return result[0] + + async def fetchall(self, query: Any) -> List[Any]: + cursor = await self.connection.execute(query) + return await cursor.fetchall() + + +async def db_session() -> AsyncGenerator[Session, None]: + connection = await db_engine.client.acquire() + session = Session(connection) + try: + yield session + finally: + await connection.close() diff --git a/{{cookiecutter.project_name}}/src/services/elastic/__init__.py b/{{cookiecutter.project_name}}/src/services/elastic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/{{cookiecutter.project_name}}/src/services/elastic/client.py b/{{cookiecutter.project_name}}/src/services/elastic/client.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/{{cookiecutter.project_name}}/src/services/elastic/mixin.py b/{{cookiecutter.project_name}}/src/services/elastic/mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/{{cookiecutter.project_name}}/src/services/elastic/schema.py b/{{cookiecutter.project_name}}/src/services/elastic/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391