implement automatic login refresh using OIDC refresh tokens
This commit is contained in:
parent
d3814a359d
commit
94fac19546
3 changed files with 65 additions and 26 deletions
|
|
@ -3,6 +3,6 @@ from dooris_api.app import app
|
|||
|
||||
def main():
|
||||
import uvicorn
|
||||
config = uvicorn.Config(app, port=8000, log_level="info")
|
||||
config = uvicorn.Config(app, port=8000, log_level="debug")
|
||||
server = uvicorn.Server(config)
|
||||
server.run()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from typing import Optional
|
||||
import logging
|
||||
import secrets
|
||||
import math
|
||||
from datetime import datetime, UTC, timedelta
|
||||
import sys
|
||||
from datetime import datetime, UTC
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.responses import RedirectResponse
|
||||
from contextlib import asynccontextmanager
|
||||
|
|
@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
root_logger = logging.getLogger("")
|
||||
root_logger.setLevel(logging.INFO)
|
||||
root_logger.addHandler(logging.StreamHandler(sys.stderr))
|
||||
app_logger = logging.getLogger("dooris_api")
|
||||
app_logger.setLevel(logging.DEBUG)
|
||||
|
||||
app.extra["oidc_client"] = OpenidClient.from_issuer_url(
|
||||
url="https://id.hamburg.ccc.de/realms/test/",
|
||||
authentication_redirect_uri="http://localhost:8000/auth/login-callback",
|
||||
|
|
@ -84,25 +90,13 @@ async def login_callback(req: Request, resp: Response, oidc_client: deps.OpenidC
|
|||
|
||||
# save the authentication result for later reuse
|
||||
if isinstance(auth_result, TokenSuccessResponse):
|
||||
now = datetime.now(UTC)
|
||||
auth_start_time = datetime.fromtimestamp(float(req.cookies["auth_start_time"]), UTC)
|
||||
|
||||
# extract the ID token now to validate its authenticity and properly set the cookie lifetime
|
||||
id_token = oidc_client.decode_id_token(auth_result.id_token, nonce=req.cookies["auth_nonce"])
|
||||
|
||||
# calculate how long each token is valid
|
||||
at_max_age = auth_start_time - now + timedelta(seconds=auth_result.expires_in)
|
||||
rt_max_age = auth_start_time - now + timedelta(seconds=auth_result.refresh_expires_in)
|
||||
id_max_age = datetime.fromtimestamp(id_token.exp, UTC) - now
|
||||
|
||||
# update cookies
|
||||
resp.set_cookie("access_token", auth_result.access_token, max_age=int(at_max_age.total_seconds()), httponly=True, secure=True)
|
||||
resp.set_cookie("refresh_token", auth_result.refresh_token, max_age=int(rt_max_age.total_seconds()), httponly=True, secure=True)
|
||||
resp.set_cookie("id_token", auth_result.id_token, max_age=int(id_max_age.total_seconds()), httponly=True, secure=True)
|
||||
resp.set_cookie("auth_nonce", req.cookies["auth_nonce"], max_age=int(id_max_age.total_seconds()), httponly=True, secure=True)
|
||||
deps.persist_auth_state(oidc_client, resp, auth_result, auth_start_time, req.cookies["auth_nonce"])
|
||||
|
||||
# redirect the user to the page they wanted to visit
|
||||
return {"authenticated": True}
|
||||
else:
|
||||
return {"authenticated": False, "error": auth_result}
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,16 @@
|
|||
from typing import Annotated, Optional
|
||||
from fastapi import Request, Depends
|
||||
import logging
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from fastapi import Request, Depends, Response
|
||||
from simple_openid_connect.data import TokenSuccessResponse
|
||||
from simple_openid_connect.client import OpenidClient
|
||||
|
||||
from dooris_api import models
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_oidc_client(req: Request) -> OpenidClient:
|
||||
return req.app.extra["oidc_client"]
|
||||
|
||||
|
|
@ -12,14 +18,53 @@ async def get_oidc_client(req: Request) -> OpenidClient:
|
|||
OpenidClient = Annotated[OpenidClient, Depends(get_oidc_client)]
|
||||
|
||||
|
||||
async def get_current_user(req: Request, oidc_client: OpenidClient) -> Optional[models.CurrentUser]:
|
||||
# for now we only handle the case of no expired tokens
|
||||
# TODO: automatically use the refresh token to fetch new access tokens
|
||||
if not all(i in req.cookies for i in ["access_token", "refresh_token", "id_token", "auth_nonce"]):
|
||||
return None
|
||||
async def get_current_user(req: Request, resp: Response, oidc_client: OpenidClient) -> Optional[models.CurrentUser]:
|
||||
# easiest case: we still have an access token (which is the most fleeting component)
|
||||
# everything else should still be valid so we can just use it
|
||||
if all(i in req.cookies for i in ("access_token", "id_token", "auth_nonce")):
|
||||
logger.debug("user is fully authenticated, returning current user from existing id_token")
|
||||
id_token = oidc_client.decode_id_token(req.cookies["id_token"], nonce=req.cookies["auth_nonce"])
|
||||
return models.CurrentUser(id_token=id_token)
|
||||
|
||||
id_token = oidc_client.decode_id_token(req.cookies["id_token"], nonce=req.cookies["auth_nonce"])
|
||||
return models.CurrentUser(id_token=id_token)
|
||||
# if we have a refresh token, try to get new tokens
|
||||
if all(i in req.cookies for i in ("refresh_token", "auth_nonce")):
|
||||
logger.debug("user has been previously authenticated, trying to recover with refresh_token")
|
||||
auth_start_time = datetime.now(UTC)
|
||||
token_resp = oidc_client.exchange_refresh_token(req.cookies["refresh_token"])
|
||||
if isinstance(token_resp, TokenSuccessResponse):
|
||||
persist_auth_state(oidc_client, resp, token_resp, auth_start_time)
|
||||
|
||||
# return the newly gotten info
|
||||
id_token = oidc_client.decode_id_token(token_resp.id_token)
|
||||
return models.CurrentUser(id_token=id_token)
|
||||
|
||||
# otherwise we can't meaningfully recover any user information or the user is simply not authenticated
|
||||
logger.debug("no currently authenticated user")
|
||||
return None
|
||||
|
||||
|
||||
def persist_auth_state(oidc_client: OpenidClient, resp: Response, tokens: TokenSuccessResponse, auth_start_time: datetime, token_nonce: Optional[str] = None):
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# extract the ID token now to validate its authenticity and properly set the cookie lifetime
|
||||
id_token = oidc_client.decode_id_token(tokens.id_token, nonce=token_nonce)
|
||||
|
||||
# calculate how long each token is valid
|
||||
at_max_age = auth_start_time - now + timedelta(seconds=tokens.expires_in)
|
||||
id_max_age = datetime.fromtimestamp(id_token.exp, UTC) - now
|
||||
nonce_max_age = max(at_max_age, id_max_age)
|
||||
if tokens.refresh_token is not None and tokens.refresh_expires_in is not None:
|
||||
rt_max_age = auth_start_time - now + timedelta(seconds=tokens.refresh_expires_in)
|
||||
nonce_max_age = max(at_max_age, rt_max_age, id_max_age)
|
||||
if token_nonce is None:
|
||||
nonce_max_age = timedelta(0)
|
||||
|
||||
# update cookies
|
||||
resp.set_cookie("access_token", tokens.access_token, max_age=int(at_max_age.total_seconds()), httponly=True, secure=True)
|
||||
if tokens.refresh_token is not None and tokens.refresh_expires_in is not None:
|
||||
resp.set_cookie("refresh_token", tokens.refresh_token, max_age=int(rt_max_age.total_seconds()), httponly=True, secure=True)
|
||||
resp.set_cookie("id_token", tokens.id_token, max_age=int(id_max_age.total_seconds()), httponly=True, secure=True)
|
||||
resp.set_cookie("auth_nonce", token_nonce, max_age=int(nonce_max_age.total_seconds()), httponly=True, secure=True)
|
||||
|
||||
|
||||
CurrentUser = Annotated[Optional[models.CurrentUser], Depends(get_current_user)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue