async def test_create_user(client, get_user_from_database):
    """Create a user through the API and verify both the response and the stored row."""
    user_data = {
        "name": "Nikolai",
        "email": "lol@gmail.com",
        "password": "SamplePass1!",
    }
    # `json=` lets the test client serialize the payload and mark it as JSON,
    # instead of sending a pre-serialized string through the form-data argument.
    resp = client.post("/api/v1/user/create_user", json=user_data)
    # Check the status before parsing the body so that a failed request is
    # reported with the server's response text rather than a JSON decode error.
    assert resp.status_code == 201, resp.text
    data_from_resp = resp.json()
    assert data_from_resp["name"] == user_data["name"]
    assert data_from_resp["email"] == user_data["email"]

    # Exactly one row must exist for the returned id, and it must match the input.
    users_from_db = await get_user_from_database(data_from_resp["user_id"])
    assert len(users_from_db) == 1
    user_from_db = dict(users_from_db[0])
    assert user_from_db["name"] == user_data["name"]
    assert user_from_db["email"] == user_data["email"]
    assert str(user_from_db["id"]) == data_from_resp["user_id"]
# Names of the tables that the `clean_tables` fixture truncates before each test.
CLEAN_TABLES = ["users"]
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop for the whole test session and close it afterwards."""
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    session_loop.close()
@pytest.fixture(scope="session", autouse=True)
async def run_migrations():
    """Run the alembic shell commands once per test session.

    The commands are executed in order through ``os.system``; their exit
    codes are not inspected.
    """
    alembic_commands = (
        "alembic init migrations",
        'alembic revision --autogenerate -m "test running migrations"',
        "alembic upgrade heads",
    )
    for shell_command in alembic_commands:
        os.system(shell_command)
@pytest.fixture(scope="session")
async def async_session_test():
    """Yield a session factory bound to an async engine for the test database.

    The engine is created once per test session from ``FULL_TEST_DB_URL`` and
    disposed when the session ends.
    """
    engine = create_async_engine(FULL_TEST_DB_URL, future=True, echo=True)
    async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
    yield async_session
    # Release the engine's pooled connections instead of leaving them open
    # until interpreter shutdown.
    await engine.dispose()
@pytest.fixture(scope="function", autouse=True)
async def clean_tables(async_session_test):
    """Truncate every table listed in CLEAN_TABLES before each test function."""
    # One session and one transaction cover all TRUNCATE statements.
    async with async_session_test() as session, session.begin():
        for table_name in CLEAN_TABLES:
            truncate_stmt = text(f"""TRUNCATE TABLE {table_name} CASCADE;""")
            await session.execute(truncate_stmt)
async def _get_test_db():
    """Yield an ``AsyncSession`` connected to the test database.

    Used as the replacement for the ``get_db`` dependency. A fresh engine and
    session are created for every call and released once the consumer is done.
    """
    # Engine for talking to the test database.
    test_engine = create_async_engine(FULL_TEST_DB_URL, future=True, echo=True)
    # Session factory bound to that engine.
    test_async_session = sessionmaker(
        test_engine, expire_on_commit=False, class_=AsyncSession
    )
    session = test_async_session()
    try:
        yield session
    finally:
        # Close the session and drop the engine's pooled connections; without
        # this every call would leave a session and a connection pool behind.
        await session.close()
        await test_engine.dispose()
@pytest.fixture(scope="function")
async def client() -> AsyncGenerator[TestClient, None]:
    """
    Create a FastAPI TestClient whose `get_db` dependency is overridden with
    `_get_test_db`, so that routes work against the test database.

    The previous state of the `get_db` override is restored on teardown.
    """
    # Sentinel distinguishing "no override was registered" from any real value.
    not_overridden = object()
    previous_override = app.dependency_overrides.get(get_db, not_overridden)
    app.dependency_overrides[get_db] = _get_test_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # `dependency_overrides` lives on the shared app object; undo only the
        # entry this fixture touched so it does not leak past the test.
        if previous_override is not_overridden:
            app.dependency_overrides.pop(get_db, None)
        else:
            app.dependency_overrides[get_db] = previous_override
@pytest.fixture(scope="session")
async def asyncpg_pool():
    """Yield an asyncpg connection pool for the test database, closed at session end."""
    # Remove the "+asyncpg" marker from the URL before handing it to asyncpg
    # (presumably FULL_TEST_DB_URL is a SQLAlchemy-style URL - confirm).
    dsn = FULL_TEST_DB_URL.replace("+asyncpg", "")
    pool = await asyncpg.create_pool(dsn)
    yield pool
    # Pool.close() is a coroutine: it has to be awaited, otherwise the pool is
    # never actually closed.
    await pool.close()
@pytest.fixture
async def get_user_from_database(asyncpg_pool):
    """Provide a coroutine that fetches rows of ``users`` by id via the asyncpg pool."""

    async def get_user_from_database_by_uuid(id: str):
        """Return all rows of ``users`` whose ``id`` equals the given value."""
        query = """SELECT * FROM users WHERE id = $1;"""
        async with asyncpg_pool.acquire() as conn:
            rows = await conn.fetch(query, id)
        return rows

    return get_user_from_database_by_uuid
@pytest.fixture
async def create_user_in_database(asyncpg_pool):
    """Provide a coroutine that inserts a row into ``users`` via the asyncpg pool."""

    async def create_user_in_database(
        id: str,
        name: str,
        email: str,
        hashed_password: str,
    ):
        """Insert one user row and return the result of ``connection.execute``."""
        async with asyncpg_pool.acquire() as connection:
            # Name the target columns explicitly so the insert does not depend
            # on the physical column order of the table.
            return await connection.execute(
                """INSERT INTO users (id, name, email, hashed_password) VALUES ($1, $2, $3, $4)""",
                id,
                name,
                email,
                hashed_password,
            )

    return create_user_in_database
def create_test_auth_headers_for_user(email: str) -> dict[str, str]:
    """Build an ``Authorization: Bearer`` header for the given email.

    The token is created by ``create_access_token`` with the email as the
    ``sub`` claim and a 15-minute ``expires_delta``.
    """
    token_lifetime = timedelta(minutes=15)
    token = create_access_token(data={"sub": email}, expires_delta=token_lifetime)
    return {"Authorization": f"Bearer {token}"}
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models."""
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    # Primary key; a new UUID4 is generated by default.
    id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True,
        default=uuid.uuid4,
        unique=True,
    )
    # Scalar (uselist=False) relationship, paired with ``Balance.user``.
    balance: Mapped["Balance"] = relationship(
        back_populates="user",
        uselist=False,
    )
    name: Mapped[str] = mapped_column(nullable=False)
    # Email must be unique across users.
    email: Mapped[str] = mapped_column(nullable=False, unique=True)
    hashed_password: Mapped[str] = mapped_column(nullable=False)
class Balance(Base):
    """ORM model mapped to the ``balances`` table."""

    __tablename__ = "balances"

    # Primary key; a new UUID4 is generated by default.
    balance_id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True,
        default=uuid.uuid4,
        unique=True,
    )
    # Foreign key to ``users.id``; declared with ON DELETE CASCADE.
    user_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("users.id", ondelete="CASCADE"),
    )
    balance_amount: Mapped[float] = mapped_column(default=0, nullable=False)
    balance_currency: Mapped[str] = mapped_column(default="RUB", nullable=False)
    # Scalar (uselist=False) relationship, paired with ``User.balance``.
    user: Mapped["User"] = relationship(
        back_populates="balance",
        uselist=False,
    )
    # Scalar (uselist=False) relationship, paired with ``Transaction.balance``.
    transaction: Mapped["Transaction"] = relationship(
        back_populates="balance",
        uselist=False,
    )
class Transaction(Base):
    """ORM model mapped to the ``transactions`` table."""

    __tablename__ = "transactions"

    # Primary key; a new UUID4 is generated by default.
    transaction_id: Mapped[uuid.UUID] = mapped_column(
        primary_key=True, default=uuid.uuid4, unique=True
    )
    # Foreign key to ``balances.balance_id``; declared with ON DELETE CASCADE.
    balance_id: Mapped[uuid.UUID] = mapped_column(
        ForeignKey("balances.balance_id", ondelete="CASCADE")
    )
    # Scalar (uselist=False) relationship, paired with ``Balance.transaction``.
    balance: Mapped["Balance"] = relationship(
        back_populates="transaction", uselist=False
    )
    transaction_amount: Mapped[float] = mapped_column()
    # Pass the callable itself, not its result: calling ``datetime.utcnow()``
    # here would be evaluated once at import time and stamp every row with
    # that same moment. With the callable, the time is taken per INSERT/UPDATE.
    transaction_date: Mapped[datetime] = mapped_column(
        default=datetime.utcnow, onupdate=datetime.utcnow
    )
    transaction_comment: Mapped[str] = mapped_column()