Implement Server-Sent-Events API #3

Merged
lilly merged 5 commits from lilly/sse into main 2026-05-19 14:54:49 +02:00
6 changed files with 263 additions and 11 deletions

View file

@ -6,6 +6,7 @@ requires-python = ">=3.12"
dependencies = [ dependencies = [
"aiohttp>=3.13.5", "aiohttp>=3.13.5",
"fastapi>=0.136.1", "fastapi>=0.136.1",
"paho-mqtt>=2.1.0",
"simple-openid-connect>=2.4.0", "simple-openid-connect>=2.4.0",
"uvicorn>=0.46.0", "uvicorn>=0.46.0",
] ]

View file

@ -52,6 +52,12 @@ def main():
default=os.environ.get("DOORIS_CCUJACK_URL", "https://hmdooris-ccu.ccchh.net:2122"), default=os.environ.get("DOORIS_CCUJACK_URL", "https://hmdooris-ccu.ccchh.net:2122"),
help="The URL under which a CCUJACK instance is hosted that actually operates the locks", help="The URL under which a CCUJACK instance is hosted that actually operates the locks",
) )
argp.add_argument(
"--ccujack-mqtt",
required=False,
default=os.environ.get("DOORIS_CCUJACK_MQTT", "hmdooris-ccu.ccchh.net:1883"),
help="The $HOSTNAME:$PORT of the CCUJack embedded MQTT server",
)
argp.add_argument( argp.add_argument(
"--ccujack-user", "--ccujack-user",
required="DOORIS_CCUJACK_USER" not in os.environ, required="DOORIS_CCUJACK_USER" not in os.environ,

View file

@ -1,11 +1,11 @@
from typing import Optional, List from typing import Optional, List, AsyncIterable
import logging import logging
import secrets import secrets
import sys import sys
import os
from datetime import datetime, UTC from datetime import datetime, UTC
from fastapi import FastAPI, Request, Response, status from fastapi import FastAPI, Request, Response, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.sse import EventSourceResponse
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from simple_openid_connect.client import OpenidClient from simple_openid_connect.client import OpenidClient
from simple_openid_connect.data import TokenSuccessResponse, RpInitiatedLogoutRequest from simple_openid_connect.data import TokenSuccessResponse, RpInitiatedLogoutRequest
@ -36,14 +36,19 @@ async def lifespan(app: FastAPI):
scope=app_cfg.openid_scope, scope=app_cfg.openid_scope,
) )
# TODO: regularly re-query CCUJACK to discover new locks
app.extra["ccujack"] = CCUJackClient( app.extra["ccujack"] = CCUJackClient(
base_uri=app_cfg.ccujack_url, base_uri=app_cfg.ccujack_url,
auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password) auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password),
mqtt_conn=app_cfg.ccujack_mqtt,
) )
await app.extra["ccujack"].connect_mqtt()
await app.extra["ccujack"].find_locks() await app.extra["ccujack"].find_locks()
yield yield
await app.extra["ccujack"].close_connections()
app = FastAPI( app = FastAPI(
title="Dooris", title="Dooris",
@ -68,9 +73,7 @@ async def get_user_info(
) -> models.UserStatus: ) -> models.UserStatus:
return models.UserStatus( return models.UserStatus(
is_authorized=current_user.may_operate_locks, is_authorized=current_user.may_operate_locks,
guaranteed_session_until=datetime.fromtimestamp( guaranteed_session_until=datetime.fromtimestamp(current_user.id_token.exp, UTC),
current_user.id_token.exp, UTC
),
username=current_user.id_token.preferred_username, username=current_user.id_token.preferred_username,
ccchh_roles=current_user.ccchh_roles, ccchh_roles=current_user.ccchh_roles,
) )
@ -119,7 +122,9 @@ async def login_init(
response_class=RedirectResponse, response_class=RedirectResponse,
status_code=302, status_code=302,
) )
async def login_callback(req: Request, resp: Response, oidc_client: deps.OpenidClient) -> str: async def login_callback(
req: Request, resp: Response, oidc_client: deps.OpenidClient
) -> str:
# check that the user is currently in an authenticating state # check that the user is currently in an authenticating state
# these cookies are set by the login_init() view # these cookies are set by the login_init() view
if ( if (
@ -237,6 +242,18 @@ async def list_locks(ccujack: deps.CCUJackClient) -> List[models.Lock]:
return result return result
@app.get(
"/api/locks/stream",
tags=["locks"],
responses={status.HTTP_401_UNAUTHORIZED: {"model": models.HttpProblemDetail}},
response_class=EventSourceResponse,
)
async def watch_locks(ccujack: deps.CCUJackClient) -> AsyncIterable[List[models.Lock]]:
while True:
yield await list_locks(ccujack)
await ccujack.data_updated.wait()
@app.patch( @app.patch(
"/api/locks/{lock_id}", "/api/locks/{lock_id}",
tags=["locks"], tags=["locks"],

View file

@ -1,9 +1,11 @@
from typing import List, Tuple, Optional, Any from typing import List, Tuple, Optional, Any, Dict
from aiohttp import ClientSession, BasicAuth, TCPConnector from aiohttp import ClientSession, BasicAuth, TCPConnector
import logging import logging
import asyncio import asyncio
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from dooris_api.mqtt_client import AsyncMqttClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,21 +70,41 @@ LockData = List[Tuple[CCUDeviceInfo, List[Tuple[CCUChannelInfo, List[CCUParamInf
class CCUJackClient: class CCUJackClient:
base_uri: str base_uri: str
locks: LockData locks: LockData
param_values: Dict[str, Any]
task_process_messages: asyncio.Task
data_updated: asyncio.Event
def __init__(self, base_uri: str, auth: BasicAuth): def __init__(self, base_uri: str, auth: BasicAuth, mqtt_conn: str):
self.http = ClientSession( self.http = ClientSession(
base_url=base_uri, base_url=base_uri,
auth=auth, auth=auth,
raise_for_status=True, raise_for_status=True,
connector=TCPConnector(ssl=False), connector=TCPConnector(ssl=False),
) )
self.mqtt = AsyncMqttClient(mqtt_conn, auth.login, auth.password)
self.locks = None self.locks = None
self.param_values = dict()
self.task_process_messages = None
self.data_updated = asyncio.Event()
async def connect_mqtt(self):
await self.mqtt.connect()
self.task_process_messages = asyncio.get_running_loop().create_task(
self.process_mqt_messages(), name="process-mqtt-messages"
)
async def close_connections(self):
await asyncio.gather(self.mqtt.disconnect(), self.http.close())
self.task_process_messages.cancel()
self.task_process_messages = None
async def find_locks(self): async def find_locks(self):
logger.debug("Inspecting lock devices present in CCUJack") logger.debug("Inspecting lock devices present in CCUJack")
async with self.http.get("/device") as resp: async with self.http.get("/device") as resp:
devices = CCUDeviceList.model_validate(await resp.json()) devices = CCUDeviceList.model_validate(await resp.json())
# inspect CCUJACK for locks
device_infos = await asyncio.gather( device_infos = await asyncio.gather(
*[ *[
self._inspect_ccu_device(i) self._inspect_ccu_device(i)
@ -91,9 +113,45 @@ class CCUJackClient:
] ]
) )
self.locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK] # save the result
new_locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK]
if new_locks != self.locks:
self.locks = new_locks
self.data_updated.set()
self.data_updated.clear()
# update active mqtt subscriptions based on newly discovered devices
mqtt_topics = set()
for i_lock, lock_channels in self.locks:
for i_channel, channel_params in lock_channels:
for i_param in channel_params:
mqtt_topics.add(
f"device/status/{i_lock.address}/{i_channel.index}/{i_param.id}"
)
await self.mqtt.update_subscriptions(mqtt_topics)
async def process_mqt_messages(self):
while True:
try:
msg = await self.mqtt.messages.get()
param_name = msg.topic.removeprefix("device/status/")
param_value = CCUValue.model_validate_json(msg.payload)
logger.debug(
f"Got new value from MQTT for parameter {param_name}: {param_value}"
)
self.param_values[param_name] = param_value
self.data_updated.set()
except Exception as e:
logger.exception(f"could not process incoming mqtt message: {e}")
finally:
self.data_updated.clear()
async def query_param_value(self, address: str) -> CCUValue:
if address in self.param_values:
return self.param_values[address]
async def query_param_value(self, address: str):
logger.debug("Querying parameter value from '%s'", address) logger.debug("Querying parameter value from '%s'", address)
async with self.http.get(f"/device/{address}/~pv") as resp: async with self.http.get(f"/device/{address}/~pv") as resp:
return CCUValue.model_validate(await resp.json()) return CCUValue.model_validate(await resp.json())

View file

@ -0,0 +1,159 @@
#
# This whole implementation is adapted from the upstream GitHub example
# https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py
#
from typing import Any, List, Set, Iterable
import logging
import asyncio
import socket
import paho.mqtt.client as mqtt
logger = logging.getLogger(__name__)
class AsyncLooper:
"""
Helper class to implement loopgin with asyncio for the underlying mqtt IO
"""
def __init__(self, loop: asyncio.AbstractEventLoop, client: mqtt.Client):
self.loop = loop
self.client = client
self.client.on_socket_open = self.on_socket_open
self.client.on_socket_close = self.on_socket_close
self.client.on_socket_register_write = self.on_socket_register_write
self.client.on_socket_unregister_write = self.on_socket_unregister_write
def on_socket_open(self, client, userdata, sock):
def cb():
client.loop_read()
self.loop.add_reader(sock, cb)
self.task_misc = self.loop.create_task(self.misc_loop())
def on_socket_close(self, client, userdata, sock):
self.loop.remove_reader(sock)
self.task_misc.cancel()
def on_socket_register_write(self, client, userdata, sock):
def cb():
client.loop_write()
self.loop.add_writer(sock, cb)
def on_socket_unregister_write(self, client, userdata, sock):
self.loop.remove_writer(sock)
async def misc_loop(self):
while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
break
class AsyncMqttClient:
loop: asyncio.AbstractEventLoop
connection_string: str
looper: AsyncLooper
client: mqtt.Client
active_subscriptions: Set[str]
messages: asyncio.Queue
def __init__(self, connection_string: str, username: str, password: str):
self.connection_string = connection_string
self.active_subscriptions = set()
self.messages = asyncio.Queue()
self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="dooris")
self.client.username = username
self.client.password = password
self.client.on_connect = self.on_connect
self.client.on_connect_fail = self.on_connect_fail
self.client.on_message = self.on_message
self.client.on_disconnect = self.on_disconnect
self.client.on_subscribe = self.on_subscribe
self.client.on_unsubscribe = self.on_unsubscribe
self.client.on_disconnect = self.on_disconnect
def on_connect(
self,
client: mqtt.Client,
userdata: Any,
flags: mqtt.ConnectFlags,
reason_code,
properties: mqtt.Properties,
):
logger.debug(f"mqtt client connected with message '{reason_code}'")
self.fut_connected.set_result(None)
def on_connect_fail(self, client, userdata):
logger.error("mqtt client could not connect to broker")
self.fut_connected.set_exception(
Exception("mqtt client could not connect to broker")
)
def on_disconnect(
self,
client: mqtt.Client,
userdata: Any,
flags: mqtt.DisconnectFlags,
reason_code,
properties: mqtt.Properties,
):
logger.debug(f"mqtt client disconnected with message '{reason_code}'")
if self.fut_disconnect:
self.fut_disconnect.set_result(None)
def on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
self.messages.put_nowait(msg)
def on_subscribe(self, client, userdata, mid, reason_code, properties):
logger.debug(f"mqtt client subscribed to topics with message '{reason_code}'")
self.fut_subscribe.set_result(None)
def on_unsubscribe(self, client, userdata, mid, reason_code, properties):
logger.debug(
f"mqtt client unsubscribed from topics with message '{reason_code}'"
)
self.fut_unsubscribe.set_result(None)
async def update_subscriptions(self, topics: Iterable[str], qos: int = 1):
"""
Update MQTT subscriptions so that the client is subscribed to exactly the given list of topics
"""
to_add = topics.difference(self.active_subscriptions)
if to_add:
logger.info(f"mqtt client subscribing to topics {', '.join(to_add)}")
self.fut_subscribe = asyncio.get_running_loop().create_future()
self.client.subscribe([(i, qos) for i in to_add])
await self.fut_subscribe
to_remove = self.active_subscriptions.difference(topics)
if to_remove:
logger.info(f"mqtt client unsubscribing from topics {','.join(to_remove)}")
self.fut_unsubscribe = asyncio.get_running_loop().create_future()
self.client.unsubscribe(list(to_remove))
await self.fut_unsubscribe
async def connect(self):
server_host, server_port = self.connection_string.rsplit(":", maxsplit=1)
self.looper = AsyncLooper(asyncio.get_running_loop(), self.client)
logger.info("Connecting to mqtt server at %s:%s", server_host, server_port)
self.fut_connected = asyncio.get_running_loop().create_future()
self.client.connect(server_host, int(server_port))
self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
await self.fut_connected
async def disconnect(self):
logger.info("Disconnecting mqtt client from broker")
self.fut_disconnect = asyncio.get_running_loop().create_future()
self.client.disconnect()
await self.fut_disconnect

11
api/uv.lock generated
View file

@ -400,6 +400,7 @@ source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiohttp" }, { name = "aiohttp" },
{ name = "fastapi" }, { name = "fastapi" },
{ name = "paho-mqtt" },
{ name = "simple-openid-connect" }, { name = "simple-openid-connect" },
{ name = "uvicorn" }, { name = "uvicorn" },
] ]
@ -413,6 +414,7 @@ dev = [
requires-dist = [ requires-dist = [
{ name = "aiohttp", specifier = ">=3.13.5" }, { name = "aiohttp", specifier = ">=3.13.5" },
{ name = "fastapi", specifier = ">=0.136.1" }, { name = "fastapi", specifier = ">=0.136.1" },
{ name = "paho-mqtt", specifier = ">=2.1.0" },
{ name = "simple-openid-connect", specifier = ">=2.4.0" }, { name = "simple-openid-connect", specifier = ">=2.4.0" },
{ name = "uvicorn", specifier = ">=0.46.0" }, { name = "uvicorn", specifier = ">=0.46.0" },
] ]
@ -734,6 +736,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/6c/d8a02ffb24876b5f51fbd781f479fc6525a518553a4196bd0433dae9ff8e/orderedmultidict-1.0.2-py2.py3-none-any.whl", hash = "sha256:ab5044c1dca4226ae4c28524cfc5cc4c939f0b49e978efa46a6ad6468049f79b", size = 11897, upload-time = "2025-11-18T08:00:41.44Z" }, { url = "https://files.pythonhosted.org/packages/b2/6c/d8a02ffb24876b5f51fbd781f479fc6525a518553a4196bd0433dae9ff8e/orderedmultidict-1.0.2-py2.py3-none-any.whl", hash = "sha256:ab5044c1dca4226ae4c28524cfc5cc4c939f0b49e978efa46a6ad6468049f79b", size = 11897, upload-time = "2025-11-18T08:00:41.44Z" },
] ]
[[package]]
name = "paho-mqtt"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/39/15/0a6214e76d4d32e7f663b109cf71fb22561c2be0f701d67f93950cd40542/paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834", size = 148848, upload-time = "2024-04-29T19:52:55.591Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c4/cb/00451c3cf31790287768bb12c6bec834f5d292eaf3022afc88e14b8afc94/paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee", size = 67219, upload-time = "2024-04-29T19:52:48.345Z" },
]
[[package]] [[package]]
name = "parso" name = "parso"
version = "0.8.7" version = "0.8.7"