import logging
import typing
from asyncio import current_task
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator
from fastapi_utils.camelcase import camel2snake
from sqlalchemy import inspect, literal, orm, select
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_scoped_session,
create_async_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, declared_attr
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class Base:
    """Shared base for declarative models.

    Presumably passed as ``cls=`` to ``declarative_base`` elsewhere — confirm.
    Every subclass gets a snake_case ``__tablename__`` derived from its class name.
    """

    id: typing.Any
    __name__: str

    @classmethod
    def get_real_column_name(cls, attr_name: str) -> str:
        """Return the database column name behind mapped attribute ``attr_name``."""
        mapped_columns = inspect(cls).c
        column = getattr(mapped_columns, attr_name)
        return column.name

    @declared_attr
    def __tablename__(cls) -> str:
        """Table name: the class name converted from CamelCase to snake_case."""
        class_name = cls.__name__
        return camel2snake(class_name)
class Database:
    """Async SQLAlchemy engine plus a task-scoped ``AsyncSession`` registry.

    Sessions are scoped with ``scopefunc=current_task``, so calls to the
    registry from the same asyncio task share one session until it is removed
    (which :meth:`session` does on exit).
    """

    # Default engine options, used when ``connect_kwargs`` is not supplied.
    CONNECT_KWARGS: typing.ClassVar[dict[str, Any]] = {
        "max_overflow": 10,
        "pool_pre_ping": True,
        "pool_recycle": 3600,
        "echo_pool": True,
    }

    def __init__(
        self,
        db_connect_url: str,
        db_alias: str,
        connect_kwargs: typing.Optional[dict[str, Any]] = None,
        debug: bool = True,
    ) -> None:
        """Create the async engine and the task-scoped session registry.

        Args:
            db_connect_url: SQLAlchemy database URL passed to ``create_async_engine``.
            db_alias: Name reported as ``"name"`` by :meth:`get_status`.
            connect_kwargs: Extra keyword arguments for ``create_async_engine``.
                Defaults to a copy of :attr:`CONNECT_KWARGS`. Must not contain
                ``"echo"``: that key is supplied from ``debug``, and a duplicate
                keyword raises ``TypeError``.
            debug: Forwarded to the engine as ``echo``.
        """
        if connect_kwargs is None:
            # Copy so the shared class-level dict is never handed out directly.
            connect_kwargs = dict(self.CONNECT_KWARGS)
        self._engine = create_async_engine(url=db_connect_url, **connect_kwargs, echo=debug)
        self._db_alias = db_alias
        self._async_session = async_scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                class_=AsyncSession,
                expire_on_commit=False,
                bind=self._engine,
            ),
            scopefunc=current_task,
        )

    async def create_tables_by_base(self, sqlalchemy_base: Any) -> None:
        """Create the tables in ``sqlalchemy_base.metadata``.

        Runs inside a single ``engine.begin()`` transaction.
        """
        async with self._engine.begin() as conn:
            # ``MetaData.create_all`` is synchronous; run_sync adapts it to the async connection.
            await conn.run_sync(sqlalchemy_base.metadata.create_all)

    async def get_status(self) -> dict[str, str]:
        """Round-trip a trivial query and report readiness.

        Returns:
            ``{"status": "ready", "name": <db_alias>}``. Connection or query
            errors propagate to the caller.
        """
        async with self.session() as session:
            # Columns are passed positionally: the legacy ``select([...])`` list
            # form is deprecated in SQLAlchemy 1.4 and removed in 2.0.
            db_status = await session.execute(
                select(
                    literal("ready").label("status"),
                    literal(self._db_alias).label("name"),
                ),
            )
            return db_status.first()._asdict()

    @asynccontextmanager
    async def session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield the current task's ``AsyncSession`` and always clean it up.

        If the ``async with`` body raises an ``Exception``, it is logged, the
        session is rolled back, and the exception is re-raised. In every case
        the session is closed and removed from the task-scoped registry.
        """
        session: AsyncSession = self._async_session()
        try:
            yield session
        except Exception:
            logger.exception("Session rollback because of exception")
            await session.rollback()
            raise
        finally:
            await session.close()
            await self._async_session.remove()

    async def disconnect(self) -> None:
        """Dispose of the engine and its connection pool."""
        await self._engine.dispose()