Compare commits
5 commits
8083330032
...
7ac0a4106c
| Author | SHA1 | Date | |
|---|---|---|---|
|
7ac0a4106c |
|||
|
4103c0ca5f |
|||
|
44d484cfc1 |
|||
|
0331dd6406 |
|||
|
1a50d67df6 |
6 changed files with 263 additions and 11 deletions
|
|
@ -6,6 +6,7 @@ requires-python = ">=3.12"
|
|||
dependencies = [
|
||||
"aiohttp>=3.13.5",
|
||||
"fastapi>=0.136.1",
|
||||
"paho-mqtt>=2.1.0",
|
||||
"simple-openid-connect>=2.4.0",
|
||||
"uvicorn>=0.46.0",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -52,6 +52,12 @@ def main():
|
|||
default=os.environ.get("DOORIS_CCUJACK_URL", "https://hmdooris-ccu.ccchh.net:2122"),
|
||||
help="The URL under which a CCUJACK instance is hosted that actually operates the locks",
|
||||
)
|
||||
argp.add_argument(
|
||||
"--ccujack-mqtt",
|
||||
required=False,
|
||||
default=os.environ.get("DOORIS_CCUJACK_MQTT", "hmdooris-ccu.ccchh.net:1883"),
|
||||
help="The $HOSTNAME:$PORT of the CCUJack embedded MQTT server",
|
||||
)
|
||||
argp.add_argument(
|
||||
"--ccujack-user",
|
||||
required="DOORIS_CCUJACK_USER" not in os.environ,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Optional, List
|
||||
from typing import Optional, List, AsyncIterable
|
||||
import logging
|
||||
import secrets
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime, UTC
|
||||
from fastapi import FastAPI, Request, Response, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.sse import EventSourceResponse
|
||||
from contextlib import asynccontextmanager
|
||||
from simple_openid_connect.client import OpenidClient
|
||||
from simple_openid_connect.data import TokenSuccessResponse, RpInitiatedLogoutRequest
|
||||
|
|
@ -36,14 +36,19 @@ async def lifespan(app: FastAPI):
|
|||
scope=app_cfg.openid_scope,
|
||||
)
|
||||
|
||||
# TODO: regularly re-query CCUJACK to discover new locks
|
||||
app.extra["ccujack"] = CCUJackClient(
|
||||
base_uri=app_cfg.ccujack_url,
|
||||
auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password)
|
||||
auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password),
|
||||
mqtt_conn=app_cfg.ccujack_mqtt,
|
||||
)
|
||||
await app.extra["ccujack"].connect_mqtt()
|
||||
await app.extra["ccujack"].find_locks()
|
||||
|
||||
yield
|
||||
|
||||
await app.extra["ccujack"].close_connections()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Dooris",
|
||||
|
|
@ -68,9 +73,7 @@ async def get_user_info(
|
|||
) -> models.UserStatus:
|
||||
return models.UserStatus(
|
||||
is_authorized=current_user.may_operate_locks,
|
||||
guaranteed_session_until=datetime.fromtimestamp(
|
||||
current_user.id_token.exp, UTC
|
||||
),
|
||||
guaranteed_session_until=datetime.fromtimestamp(current_user.id_token.exp, UTC),
|
||||
username=current_user.id_token.preferred_username,
|
||||
ccchh_roles=current_user.ccchh_roles,
|
||||
)
|
||||
|
|
@ -119,7 +122,9 @@ async def login_init(
|
|||
response_class=RedirectResponse,
|
||||
status_code=302,
|
||||
)
|
||||
async def login_callback(req: Request, resp: Response, oidc_client: deps.OpenidClient) -> str:
|
||||
async def login_callback(
|
||||
req: Request, resp: Response, oidc_client: deps.OpenidClient
|
||||
) -> str:
|
||||
# check that the user is currently in an authenticating state
|
||||
# these cookies are set by the login_init() view
|
||||
if (
|
||||
|
|
@ -237,6 +242,18 @@ async def list_locks(ccujack: deps.CCUJackClient) -> List[models.Lock]:
|
|||
return result
|
||||
|
||||
|
||||
@app.get(
|
||||
"/api/locks/stream",
|
||||
tags=["locks"],
|
||||
responses={status.HTTP_401_UNAUTHORIZED: {"model": models.HttpProblemDetail}},
|
||||
response_class=EventSourceResponse,
|
||||
)
|
||||
async def watch_locks(ccujack: deps.CCUJackClient) -> AsyncIterable[List[models.Lock]]:
|
||||
while True:
|
||||
yield await list_locks(ccujack)
|
||||
await ccujack.data_updated.wait()
|
||||
|
||||
|
||||
@app.patch(
|
||||
"/api/locks/{lock_id}",
|
||||
tags=["locks"],
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from typing import List, Tuple, Optional, Any
|
||||
from typing import List, Tuple, Optional, Any, Dict
|
||||
from aiohttp import ClientSession, BasicAuth, TCPConnector
|
||||
import logging
|
||||
import asyncio
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dooris_api.mqtt_client import AsyncMqttClient
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -68,21 +70,41 @@ LockData = List[Tuple[CCUDeviceInfo, List[Tuple[CCUChannelInfo, List[CCUParamInf
|
|||
class CCUJackClient:
|
||||
base_uri: str
|
||||
locks: LockData
|
||||
param_values: Dict[str, Any]
|
||||
task_process_messages: asyncio.Task
|
||||
data_updated: asyncio.Event
|
||||
|
||||
def __init__(self, base_uri: str, auth: BasicAuth):
|
||||
def __init__(self, base_uri: str, auth: BasicAuth, mqtt_conn: str):
|
||||
self.http = ClientSession(
|
||||
base_url=base_uri,
|
||||
auth=auth,
|
||||
raise_for_status=True,
|
||||
connector=TCPConnector(ssl=False),
|
||||
)
|
||||
self.mqtt = AsyncMqttClient(mqtt_conn, auth.login, auth.password)
|
||||
self.locks = None
|
||||
self.param_values = dict()
|
||||
self.task_process_messages = None
|
||||
self.data_updated = asyncio.Event()
|
||||
|
||||
async def connect_mqtt(self):
|
||||
await self.mqtt.connect()
|
||||
self.task_process_messages = asyncio.get_running_loop().create_task(
|
||||
self.process_mqt_messages(), name="process-mqtt-messages"
|
||||
)
|
||||
|
||||
async def close_connections(self):
|
||||
await asyncio.gather(self.mqtt.disconnect(), self.http.close())
|
||||
self.task_process_messages.cancel()
|
||||
self.task_process_messages = None
|
||||
|
||||
async def find_locks(self):
|
||||
logger.debug("Inspecting lock devices present in CCUJack")
|
||||
|
||||
async with self.http.get("/device") as resp:
|
||||
devices = CCUDeviceList.model_validate(await resp.json())
|
||||
|
||||
# inspect CCUJACK for locks
|
||||
device_infos = await asyncio.gather(
|
||||
*[
|
||||
self._inspect_ccu_device(i)
|
||||
|
|
@ -91,9 +113,45 @@ class CCUJackClient:
|
|||
]
|
||||
)
|
||||
|
||||
self.locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK]
|
||||
# save the result
|
||||
new_locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK]
|
||||
if new_locks != self.locks:
|
||||
self.locks = new_locks
|
||||
self.data_updated.set()
|
||||
self.data_updated.clear()
|
||||
|
||||
# update active mqtt subscriptions based on newly discovered devices
|
||||
mqtt_topics = set()
|
||||
for i_lock, lock_channels in self.locks:
|
||||
for i_channel, channel_params in lock_channels:
|
||||
for i_param in channel_params:
|
||||
mqtt_topics.add(
|
||||
f"device/status/{i_lock.address}/{i_channel.index}/{i_param.id}"
|
||||
)
|
||||
await self.mqtt.update_subscriptions(mqtt_topics)
|
||||
|
||||
async def process_mqt_messages(self):
|
||||
while True:
|
||||
try:
|
||||
msg = await self.mqtt.messages.get()
|
||||
|
||||
param_name = msg.topic.removeprefix("device/status/")
|
||||
param_value = CCUValue.model_validate_json(msg.payload)
|
||||
logger.debug(
|
||||
f"Got new value from MQTT for parameter {param_name}: {param_value}"
|
||||
)
|
||||
self.param_values[param_name] = param_value
|
||||
self.data_updated.set()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"could not process incoming mqtt message: {e}")
|
||||
finally:
|
||||
self.data_updated.clear()
|
||||
|
||||
async def query_param_value(self, address: str) -> CCUValue:
|
||||
if address in self.param_values:
|
||||
return self.param_values[address]
|
||||
|
||||
async def query_param_value(self, address: str):
|
||||
logger.debug("Querying parameter value from '%s'", address)
|
||||
async with self.http.get(f"/device/{address}/~pv") as resp:
|
||||
return CCUValue.model_validate(await resp.json())
|
||||
|
|
|
|||
159
api/src/dooris_api/mqtt_client.py
Normal file
159
api/src/dooris_api/mqtt_client.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
#
|
||||
# This whole implementation is adapted from the upstream GitHub example
|
||||
# https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py
|
||||
#
|
||||
|
||||
from typing import Any, List, Set, Iterable
|
||||
import logging
|
||||
import asyncio
|
||||
import socket
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncLooper:
|
||||
"""
|
||||
Helper class to implement loopgin with asyncio for the underlying mqtt IO
|
||||
"""
|
||||
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop, client: mqtt.Client):
|
||||
self.loop = loop
|
||||
|
||||
self.client = client
|
||||
self.client.on_socket_open = self.on_socket_open
|
||||
self.client.on_socket_close = self.on_socket_close
|
||||
self.client.on_socket_register_write = self.on_socket_register_write
|
||||
self.client.on_socket_unregister_write = self.on_socket_unregister_write
|
||||
|
||||
def on_socket_open(self, client, userdata, sock):
|
||||
def cb():
|
||||
client.loop_read()
|
||||
|
||||
self.loop.add_reader(sock, cb)
|
||||
self.task_misc = self.loop.create_task(self.misc_loop())
|
||||
|
||||
def on_socket_close(self, client, userdata, sock):
|
||||
self.loop.remove_reader(sock)
|
||||
self.task_misc.cancel()
|
||||
|
||||
def on_socket_register_write(self, client, userdata, sock):
|
||||
def cb():
|
||||
client.loop_write()
|
||||
|
||||
self.loop.add_writer(sock, cb)
|
||||
|
||||
def on_socket_unregister_write(self, client, userdata, sock):
|
||||
self.loop.remove_writer(sock)
|
||||
|
||||
async def misc_loop(self):
|
||||
while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
|
||||
class AsyncMqttClient:
|
||||
loop: asyncio.AbstractEventLoop
|
||||
connection_string: str
|
||||
looper: AsyncLooper
|
||||
client: mqtt.Client
|
||||
active_subscriptions: Set[str]
|
||||
messages: asyncio.Queue
|
||||
|
||||
def __init__(self, connection_string: str, username: str, password: str):
|
||||
self.connection_string = connection_string
|
||||
self.active_subscriptions = set()
|
||||
self.messages = asyncio.Queue()
|
||||
self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="dooris")
|
||||
self.client.username = username
|
||||
self.client.password = password
|
||||
self.client.on_connect = self.on_connect
|
||||
self.client.on_connect_fail = self.on_connect_fail
|
||||
self.client.on_message = self.on_message
|
||||
self.client.on_disconnect = self.on_disconnect
|
||||
self.client.on_subscribe = self.on_subscribe
|
||||
self.client.on_unsubscribe = self.on_unsubscribe
|
||||
self.client.on_disconnect = self.on_disconnect
|
||||
|
||||
def on_connect(
|
||||
self,
|
||||
client: mqtt.Client,
|
||||
userdata: Any,
|
||||
flags: mqtt.ConnectFlags,
|
||||
reason_code,
|
||||
properties: mqtt.Properties,
|
||||
):
|
||||
logger.debug(f"mqtt client connected with message '{reason_code}'")
|
||||
self.fut_connected.set_result(None)
|
||||
|
||||
def on_connect_fail(self, client, userdata):
|
||||
logger.error("mqtt client could not connect to broker")
|
||||
self.fut_connected.set_exception(
|
||||
Exception("mqtt client could not connect to broker")
|
||||
)
|
||||
|
||||
def on_disconnect(
|
||||
self,
|
||||
client: mqtt.Client,
|
||||
userdata: Any,
|
||||
flags: mqtt.DisconnectFlags,
|
||||
reason_code,
|
||||
properties: mqtt.Properties,
|
||||
):
|
||||
logger.debug(f"mqtt client disconnected with message '{reason_code}'")
|
||||
if self.fut_disconnect:
|
||||
self.fut_disconnect.set_result(None)
|
||||
|
||||
def on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
|
||||
self.messages.put_nowait(msg)
|
||||
|
||||
def on_subscribe(self, client, userdata, mid, reason_code, properties):
|
||||
logger.debug(f"mqtt client subscribed to topics with message '{reason_code}'")
|
||||
self.fut_subscribe.set_result(None)
|
||||
|
||||
def on_unsubscribe(self, client, userdata, mid, reason_code, properties):
|
||||
logger.debug(
|
||||
f"mqtt client unsubscribed from topics with message '{reason_code}'"
|
||||
)
|
||||
self.fut_unsubscribe.set_result(None)
|
||||
|
||||
async def update_subscriptions(self, topics: Iterable[str], qos: int = 1):
|
||||
"""
|
||||
Update MQTT subscriptions so that the client is subscribed to exactly the given list of topics
|
||||
"""
|
||||
|
||||
to_add = topics.difference(self.active_subscriptions)
|
||||
if to_add:
|
||||
logger.info(f"mqtt client subscribing to topics {', '.join(to_add)}")
|
||||
self.fut_subscribe = asyncio.get_running_loop().create_future()
|
||||
self.client.subscribe([(i, qos) for i in to_add])
|
||||
await self.fut_subscribe
|
||||
|
||||
to_remove = self.active_subscriptions.difference(topics)
|
||||
if to_remove:
|
||||
logger.info(f"mqtt client unsubscribing from topics {','.join(to_remove)}")
|
||||
self.fut_unsubscribe = asyncio.get_running_loop().create_future()
|
||||
self.client.unsubscribe(list(to_remove))
|
||||
await self.fut_unsubscribe
|
||||
|
||||
async def connect(self):
|
||||
server_host, server_port = self.connection_string.rsplit(":", maxsplit=1)
|
||||
|
||||
self.looper = AsyncLooper(asyncio.get_running_loop(), self.client)
|
||||
|
||||
logger.info("Connecting to mqtt server at %s:%s", server_host, server_port)
|
||||
self.fut_connected = asyncio.get_running_loop().create_future()
|
||||
self.client.connect(server_host, int(server_port))
|
||||
self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
||||
|
||||
await self.fut_connected
|
||||
|
||||
async def disconnect(self):
|
||||
logger.info("Disconnecting mqtt client from broker")
|
||||
self.fut_disconnect = asyncio.get_running_loop().create_future()
|
||||
self.client.disconnect()
|
||||
await self.fut_disconnect
|
||||
11
api/uv.lock
generated
11
api/uv.lock
generated
|
|
@ -400,6 +400,7 @@ source = { editable = "." }
|
|||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "paho-mqtt" },
|
||||
{ name = "simple-openid-connect" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
|
|
@ -413,6 +414,7 @@ dev = [
|
|||
requires-dist = [
|
||||
{ name = "aiohttp", specifier = ">=3.13.5" },
|
||||
{ name = "fastapi", specifier = ">=0.136.1" },
|
||||
{ name = "paho-mqtt", specifier = ">=2.1.0" },
|
||||
{ name = "simple-openid-connect", specifier = ">=2.4.0" },
|
||||
{ name = "uvicorn", specifier = ">=0.46.0" },
|
||||
]
|
||||
|
|
@ -734,6 +736,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/b2/6c/d8a02ffb24876b5f51fbd781f479fc6525a518553a4196bd0433dae9ff8e/orderedmultidict-1.0.2-py2.py3-none-any.whl", hash = "sha256:ab5044c1dca4226ae4c28524cfc5cc4c939f0b49e978efa46a6ad6468049f79b", size = 11897, upload-time = "2025-11-18T08:00:41.44Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paho-mqtt"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/39/15/0a6214e76d4d32e7f663b109cf71fb22561c2be0f701d67f93950cd40542/paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834", size = 148848, upload-time = "2024-04-29T19:52:55.591Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/cb/00451c3cf31790287768bb12c6bec834f5d292eaf3022afc88e14b8afc94/paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee", size = 67219, upload-time = "2024-04-29T19:52:48.345Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parso"
|
||||
version = "0.8.7"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue