api: use proper connection shutdown for downstream services

This commit is contained in:
lilly 2026-05-19 13:09:00 +02:00
commit 44d484cfc1
Signed by: lilly
SSH key fingerprint: SHA256:y9T5GFw2A20WVklhetIxG1+kcg/Ce0shnQmbu1LQ37g
3 changed files with 94 additions and 22 deletions

View file

@ -2,7 +2,6 @@ from typing import Optional, List
import logging import logging
import secrets import secrets
import sys import sys
import os
from datetime import datetime, UTC from datetime import datetime, UTC
from fastapi import FastAPI, Request, Response, status from fastapi import FastAPI, Request, Response, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
@ -42,11 +41,13 @@ async def lifespan(app: FastAPI):
auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password), auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password),
mqtt_conn=app_cfg.ccujack_mqtt, mqtt_conn=app_cfg.ccujack_mqtt,
) )
await app.extra["ccujack"].find_locks()
await app.extra["ccujack"].connect_mqtt() await app.extra["ccujack"].connect_mqtt()
await app.extra["ccujack"].find_locks()
yield yield
await app.extra["ccujack"].close_connections()
app = FastAPI( app = FastAPI(
title="Dooris", title="Dooris",
@ -71,9 +72,7 @@ async def get_user_info(
) -> models.UserStatus: ) -> models.UserStatus:
return models.UserStatus( return models.UserStatus(
is_authorized=current_user.may_operate_locks, is_authorized=current_user.may_operate_locks,
guaranteed_session_until=datetime.fromtimestamp( guaranteed_session_until=datetime.fromtimestamp(current_user.id_token.exp, UTC),
current_user.id_token.exp, UTC
),
username=current_user.id_token.preferred_username, username=current_user.id_token.preferred_username,
ccchh_roles=current_user.ccchh_roles, ccchh_roles=current_user.ccchh_roles,
) )
@ -122,7 +121,9 @@ async def login_init(
response_class=RedirectResponse, response_class=RedirectResponse,
status_code=302, status_code=302,
) )
async def login_callback(req: Request, resp: Response, oidc_client: deps.OpenidClient) -> str: async def login_callback(
req: Request, resp: Response, oidc_client: deps.OpenidClient
) -> str:
# check that the user is currently in an authenticating state # check that the user is currently in an authenticating state
# these cookies are set by the login_init() view # these cookies are set by the login_init() view
if ( if (

View file

@ -84,8 +84,16 @@ class CCUJackClient:
async def connect_mqtt(self): async def connect_mqtt(self):
await self.mqtt.connect() await self.mqtt.connect()
async def close_connections(self):
await asyncio.gather(
self.mqtt.disconnect(),
self.http.close()
)
async def find_locks(self): async def find_locks(self):
logger.debug("Inspecting lock devices present in CCUJack") logger.debug("Inspecting lock devices present in CCUJack")
# iterate through the CCUJACK API to find all devices
async with self.http.get("/device") as resp: async with self.http.get("/device") as resp:
devices = CCUDeviceList.model_validate(await resp.json()) devices = CCUDeviceList.model_validate(await resp.json())
@ -99,6 +107,14 @@ class CCUJackClient:
self.locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK] self.locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK]
# update active mqtt subscriptions
mqtt_topics = set()
for i_lock, lock_channels in self.locks:
for i_channel, channel_params in lock_channels:
for i_param in channel_params:
mqtt_topics.add(f"device/status/{i_lock.address}/{i_channel.index}/{i_param.id}")
# await self.mqtt.update_subscriptions(mqtt_topics)
async def query_param_value(self, address: str): async def query_param_value(self, address: str):
logger.debug("Querying parameter value from '%s'", address) logger.debug("Querying parameter value from '%s'", address)
async with self.http.get(f"/device/{address}/~pv") as resp: async with self.http.get(f"/device/{address}/~pv") as resp:

View file

@ -3,7 +3,7 @@
# https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py # https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py
# #
from typing import Any from typing import Any, List, Set, Iterable
import logging import logging
import asyncio import asyncio
import socket import socket
@ -58,9 +58,14 @@ class AsyncLooper:
class AsyncMqttClient: class AsyncMqttClient:
loop: asyncio.AbstractEventLoop loop: asyncio.AbstractEventLoop
connection_string: str
looper: AsyncLooper
client: mqtt.Client
active_subscriptions: Set[str]
def __init__(self, connection_string: str, username: str, password: str): def __init__(self, connection_string: str, username: str, password: str):
self.connection_string = connection_string self.connection_string = connection_string
self.active_subscriptions = set()
self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="dooris") self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="dooris")
self.client.username = username self.client.username = username
self.client.password = password self.client.password = password
@ -68,25 +73,75 @@ class AsyncMqttClient:
self.client.on_message = self.on_message self.client.on_message = self.on_message
self.client.on_disconnect = self.on_disconnect self.client.on_disconnect = self.on_disconnect
def on_connect(self, client: mqtt.Client, userdata: Any, flags: mqtt.ConnectFlags, reason_code, properties: mqtt.Properties): def on_connect(
self,
client: mqtt.Client,
userdata: Any,
flags: mqtt.ConnectFlags,
reason_code,
properties: mqtt.Properties,
):
logger.debug(f"mqtt client connected with message '{reason_code}'") logger.debug(f"mqtt client connected with message '{reason_code}'")
self.fut_connected.set_result(None)
def on_disconnect(self, client: mqtt.Client, userdata: Any, flags: mqtt.DisconnectFlags, reason_code, properties: mqtt.Properties): def on_disconnect(
logger.debug("mqtt client disconnected") self,
print("flags", type(flags), flags) client: mqtt.Client,
print("reason_code", type(reason_code), reason_code) userdata: Any,
print("properties", type(properties), properties) flags: mqtt.DisconnectFlags,
reason_code,
properties: mqtt.Properties,
):
logger.debug(f"mqtt client disconnected with message '{reason_code}'")
if self.fut_disconnect:
self.fut_disconnect.set_result(None)
def on_message(self, client: mqtt.Client, userdata: Any, msg): def on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
logger.debug("mqtt client got message") logger.debug("mqtt client got message")
print("msg", type(msg), msg) print("msg", type(msg), msg)
def on_subscribe(self, client, userdata, mid, reason_code, properties):
logger.debug(f"mqtt client subscribed to topics with message '{reason_code}'")
self.fut_subscribe.set_result(None)
def on_unsubscribe(self, client, userdata, mid, reason_code, properties):
logger.debug(f"mqtt client unsubscribed from topics with message '{reason_code}'")
self.fut_unsubscribe.set_result(None)
async def update_subscriptions(self, topics: Iterable[str]):
"""
Update MQTT subscriptions so that the client is subscribed to exactly the given list of topics
"""
to_add = topics.difference(self.active_subscriptions)
if to_add:
logger.info(f"mqtt client subscribing to topics {', '.join(to_add)}")
qos = 2
self.fut_subscribe = asyncio.get_running_loop().create_future()
self.client.subscribe([(i, qos) for i in to_add])
await self.fut_subscribe
to_remove = self.active_subscriptions.difference(topics)
if to_remove:
logger.info(f"mqtt client unsubscribing from topics {','.join(to_remove)}")
self.fut_unsubscribe = asyncio.get_running_loop().create_future()
self.client.unsubscribe(list(to_remove))
await self.fut_unsubscribe
async def connect(self): async def connect(self):
server_host, server_port = self.connection_string.rsplit(":", maxsplit=1) server_host, server_port = self.connection_string.rsplit(":", maxsplit=1)
looper = AsyncLooper(asyncio.get_running_loop(), self.client) self.looper = AsyncLooper(asyncio.get_running_loop(), self.client)
logger.info("Connecting to mqtt server at %s:%s", server_host, server_port) logger.info("Connecting to mqtt server at %s:%s", server_host, server_port)
self.fut_connected = asyncio.get_running_loop().create_future()
self.client.connect(server_host, int(server_port)) self.client.connect(server_host, int(server_port))
self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
await self.fut_connected
async def disconnect(self):
logger.info("Disconnecting mqtt client from broker")
self.fut_disconnect = asyncio.get_running_loop().create_future()
self.client.disconnect()
await self.fut_disconnect