implement automatic login refresh using OIDC refresh tokens

This commit is contained in:
lilly 2026-05-04 11:01:04 +02:00
commit 94fac19546
Signed by: lilly
SSH key fingerprint: SHA256:y9T5GFw2A20WVklhetIxG1+kcg/Ce0shnQmbu1LQ37g
3 changed files with 65 additions and 26 deletions

View file

@ -3,6 +3,6 @@ from dooris_api.app import app
def main(): def main():
import uvicorn import uvicorn
config = uvicorn.Config(app, port=8000, log_level="info") config = uvicorn.Config(app, port=8000, log_level="debug")
server = uvicorn.Server(config) server = uvicorn.Server(config)
server.run() server.run()

View file

@ -1,8 +1,8 @@
from typing import Optional from typing import Optional
import logging import logging
import secrets import secrets
import math import sys
from datetime import datetime, UTC, timedelta from datetime import datetime, UTC
from fastapi import FastAPI, Request, Response from fastapi import FastAPI, Request, Response
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
root_logger = logging.getLogger("")
root_logger.setLevel(logging.INFO)
root_logger.addHandler(logging.StreamHandler(sys.stderr))
app_logger = logging.getLogger("dooris_api")
app_logger.setLevel(logging.DEBUG)
app.extra["oidc_client"] = OpenidClient.from_issuer_url( app.extra["oidc_client"] = OpenidClient.from_issuer_url(
url="https://id.hamburg.ccc.de/realms/test/", url="https://id.hamburg.ccc.de/realms/test/",
authentication_redirect_uri="http://localhost:8000/auth/login-callback", authentication_redirect_uri="http://localhost:8000/auth/login-callback",
@ -84,25 +90,13 @@ async def login_callback(req: Request, resp: Response, oidc_client: deps.OpenidC
# save the authentication result for later reuse # save the authentication result for later reuse
if isinstance(auth_result, TokenSuccessResponse): if isinstance(auth_result, TokenSuccessResponse):
now = datetime.now(UTC)
auth_start_time = datetime.fromtimestamp(float(req.cookies["auth_start_time"]), UTC) auth_start_time = datetime.fromtimestamp(float(req.cookies["auth_start_time"]), UTC)
deps.persist_auth_state(oidc_client, resp, auth_result, auth_start_time, req.cookies["auth_nonce"])
# extract the ID token now to validate its authenticity and properly set the cookie lifetime
id_token = oidc_client.decode_id_token(auth_result.id_token, nonce=req.cookies["auth_nonce"])
# calculate how long each token is valid
at_max_age = auth_start_time - now + timedelta(seconds=auth_result.expires_in)
rt_max_age = auth_start_time - now + timedelta(seconds=auth_result.refresh_expires_in)
id_max_age = datetime.fromtimestamp(id_token.exp, UTC) - now
# update cookies
resp.set_cookie("access_token", auth_result.access_token, max_age=int(at_max_age.total_seconds()), httponly=True, secure=True)
resp.set_cookie("refresh_token", auth_result.refresh_token, max_age=int(rt_max_age.total_seconds()), httponly=True, secure=True)
resp.set_cookie("id_token", auth_result.id_token, max_age=int(id_max_age.total_seconds()), httponly=True, secure=True)
resp.set_cookie("auth_nonce", req.cookies["auth_nonce"], max_age=int(id_max_age.total_seconds()), httponly=True, secure=True)
# redirect the user to the page they wanted to visit # redirect the user to the page they wanted to visit
return {"authenticated": True} return {"authenticated": True}
else: else:
return {"authenticated": False, "error": auth_result} return {"authenticated": False, "error": auth_result}

View file

@ -1,10 +1,16 @@
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import Request, Depends import logging
from datetime import datetime, UTC, timedelta
from fastapi import Request, Depends, Response
from simple_openid_connect.data import TokenSuccessResponse
from simple_openid_connect.client import OpenidClient from simple_openid_connect.client import OpenidClient
from dooris_api import models from dooris_api import models
logger = logging.getLogger(__name__)
async def get_oidc_client(req: Request) -> OpenidClient: async def get_oidc_client(req: Request) -> OpenidClient:
return req.app.extra["oidc_client"] return req.app.extra["oidc_client"]
@ -12,14 +18,53 @@ async def get_oidc_client(req: Request) -> OpenidClient:
OpenidClient = Annotated[OpenidClient, Depends(get_oidc_client)] OpenidClient = Annotated[OpenidClient, Depends(get_oidc_client)]
async def get_current_user(req: Request, oidc_client: OpenidClient) -> Optional[models.CurrentUser]: async def get_current_user(req: Request, resp: Response, oidc_client: OpenidClient) -> Optional[models.CurrentUser]:
# for now we only handle the case of no expired tokens # easiest case: we still have an access token (which is the most fleeting component)
# TODO: automatically use the refresh token to fetch new access tokens # everything else should still be valid so we can just use it
if not all(i in req.cookies for i in ["access_token", "refresh_token", "id_token", "auth_nonce"]): if all(i in req.cookies for i in ("access_token", "id_token", "auth_nonce")):
return None logger.debug("user is fully authenticated, returning current user from existing id_token")
id_token = oidc_client.decode_id_token(req.cookies["id_token"], nonce=req.cookies["auth_nonce"])
return models.CurrentUser(id_token=id_token)
id_token = oidc_client.decode_id_token(req.cookies["id_token"], nonce=req.cookies["auth_nonce"]) # if we have a refresh token, try to get new tokens
return models.CurrentUser(id_token=id_token) if all(i in req.cookies for i in ("refresh_token", "auth_nonce")):
logger.debug("user has been previously authenticated, trying to recover with refresh_token")
auth_start_time = datetime.now(UTC)
token_resp = oidc_client.exchange_refresh_token(req.cookies["refresh_token"])
if isinstance(token_resp, TokenSuccessResponse):
persist_auth_state(oidc_client, resp, token_resp, auth_start_time)
# return the newly gotten info
id_token = oidc_client.decode_id_token(token_resp.id_token)
return models.CurrentUser(id_token=id_token)
# otherwise we can't meaningfully recover any user information or the user is simply not authenticated
logger.debug("no currently authenticated user")
return None
def persist_auth_state(oidc_client: OpenidClient, resp: Response, tokens: TokenSuccessResponse, auth_start_time: datetime, token_nonce: Optional[str] = None):
now = datetime.now(UTC)
# extract the ID token now to validate its authenticity and properly set the cookie lifetime
id_token = oidc_client.decode_id_token(tokens.id_token, nonce=token_nonce)
# calculate how long each token is valid
at_max_age = auth_start_time - now + timedelta(seconds=tokens.expires_in)
id_max_age = datetime.fromtimestamp(id_token.exp, UTC) - now
nonce_max_age = max(at_max_age, id_max_age)
if tokens.refresh_token is not None and tokens.refresh_expires_in is not None:
rt_max_age = auth_start_time - now + timedelta(seconds=tokens.refresh_expires_in)
nonce_max_age = max(at_max_age, rt_max_age, id_max_age)
if token_nonce is None:
nonce_max_age = timedelta(0)
# update cookies
resp.set_cookie("access_token", tokens.access_token, max_age=int(at_max_age.total_seconds()), httponly=True, secure=True)
if tokens.refresh_token is not None and tokens.refresh_expires_in is not None:
resp.set_cookie("refresh_token", tokens.refresh_token, max_age=int(rt_max_age.total_seconds()), httponly=True, secure=True)
resp.set_cookie("id_token", tokens.id_token, max_age=int(id_max_age.total_seconds()), httponly=True, secure=True)
resp.set_cookie("auth_nonce", token_nonce, max_age=int(nonce_max_age.total_seconds()), httponly=True, secure=True)
CurrentUser = Annotated[Optional[models.CurrentUser], Depends(get_current_user)] CurrentUser = Annotated[Optional[models.CurrentUser], Depends(get_current_user)]