Skip to content

Example Application

# You'll have to create and import
# your own user models, functions, db
# and other dependencies.
# Personal recommendation is to use
# FastAPI Users and SQLAlchemy if you
# don't need a fully customized user
# management system.
from database import async_session_maker
from users import (
    User,
    UserRead,
    UserCreate,
    fastapi_users,
    auth_backend,
    current_active_user
)

# Import FastAPI and Two-Fast-Auth
from fastapi import (
    FastAPI,
    Body,
    Depends,
    Header,
    HTTPException,
    status
)
from fastapi.responses import StreamingResponse
from two_fast_auth import (
    TwoFactorMiddleware,
    TwoFactorAuth
)


# Create FastAPI instance
app = FastAPI()


# Get user secret callback for middleware
async def get_user_secret_callback(
    user_id: str
) -> str:
    async with async_session_maker() as session:
        user = await session.get(
            User,
            user_id
        )
        return user.two_fa_secret if user else None


# Add 2FA middleware
app.add_middleware(
    TwoFactorMiddleware,
    get_user_secret_callback=get_user_secret_callback,
    excluded_paths=[
        "/docs", # Swagger UI
        "/openapi.json", # OpenAPI JSON
        "/redoc", # Redoc UI
        "/auth/jwt/login", # FastAPI Users login
        "/auth/register", # FastAPI Users register
        "/setup-2fa", # Example endpoint
        "/verify-2fa", # Example endpoint
        "/recovery-codes" # Example endpoint
    ]
)


# FastAPI Users Routers
## Auth Router
app.include_router(
    fastapi_users.get_auth_router(
        auth_backend
    ),
    prefix="/auth/jwt",
    tags=["Auth"],
)


## Register Router
app.include_router(
    fastapi_users.get_register_router(
        UserRead,
        UserCreate
    ),
    prefix="/auth",
    tags=["Auth"],
)


# Endpoints examples
## Setup 2FA
@app.post(
    "/setup-2fa",
    tags=["Auth"]
)
async def setup_2fa(
    user: User = Depends(current_active_user)
):
    # Generate new 2FA secret
    two_fa = TwoFactorAuth()

    # Generate QR code
    qr_code = two_fa.generate_qr_code(
        user.email or user.username
    )

    # Generate recovery codes
    recovery_codes = TwoFactorAuth.generate_recovery_codes()

    # Store secret and recovery codes in DB
    async with async_session_maker() as session:
        db_user = await session.get(
            User,
            user.id
        )
        db_user.two_fa_secret = two_fa.secret
        db_user.recovery_codes = recovery_codes
        session.add(db_user)
        await session.commit()
        await session.refresh(db_user)

    # Return QR code as image stream and other data in headers
    return StreamingResponse(
        qr_code,
        media_type="image/png",
        headers={
            "X-2FA-Secret": two_fa.secret,
            "X-Recovery-Codes": ",".join(recovery_codes)
        }
    )


## Verify 2FA code
@app.post(
    "/verify-2fa",
    tags=["Auth"]
)
async def verify_2fa(
    code: str = Body(
        ...,
        embed=True
    ),
    user: User = Depends(current_active_user)
):
    if not user.two_fa_secret:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="2FA not set up"
        )

    two_fa = TwoFactorAuth(user.two_fa_secret)
    if two_fa.verify_code(code):
        return {"success": True}
    return {"success": False}


## Use recovery code
@app.post(
    "/recovery-codes",
    tags=["Auth"]
)
async def use_recovery_code(
    code: str = Body(
        ...,
        embed=True
    ),
    user: User = Depends(current_active_user)
):
    async with async_session_maker() as session:
        # Get fresh user instance from database
        db_user = await session.get(
            User,
            user.id
        )

        if not db_user.recovery_codes or code not in db_user.recovery_codes:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid recovery code"
            )

        # Remove used recovery code
        updated_codes = [
            c
            for c in db_user.recovery_codes
            if c != code
        ]
        db_user.recovery_codes = updated_codes
        session.add(db_user)
        await session.commit()

    return {"success": True}


## 2FA protected route example
@app.get(
    "/protected-route",
    tags=["Protected"]
)
async def protected_route(
    user: User = Depends(current_active_user),
    x_2fa_code: str = Header(
        ...,
        alias="X-2FA-Code"
    )
):
    # Explicit 2FA verification for demonstration
    if not user.two_fa_secret:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="2FA not configured"
        )

    if not TwoFactorAuth(user.two_fa_secret).verify_code(x_2fa_code):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid 2FA code"
        )

    return {
        "message": "You've accessed a protected route with valid 2FA!",
        "user_id": str(user.id)
    }


# Run the API server
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000
    )