diff --git a/api/src/dooris_api/app.py b/api/src/dooris_api/app.py index 3b53d35..32d626d 100644 --- a/api/src/dooris_api/app.py +++ b/api/src/dooris_api/app.py @@ -2,7 +2,6 @@ from typing import Optional, List import logging import secrets import sys -import os from datetime import datetime, UTC from fastapi import FastAPI, Request, Response, status from fastapi.responses import RedirectResponse @@ -42,11 +41,13 @@ async def lifespan(app: FastAPI): auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password), mqtt_conn=app_cfg.ccujack_mqtt, ) - await app.extra["ccujack"].find_locks() await app.extra["ccujack"].connect_mqtt() + await app.extra["ccujack"].find_locks() yield + await app.extra["ccujack"].close_connections() + app = FastAPI( title="Dooris", @@ -71,9 +72,7 @@ async def get_user_info( ) -> models.UserStatus: return models.UserStatus( is_authorized=current_user.may_operate_locks, - guaranteed_session_until=datetime.fromtimestamp( - current_user.id_token.exp, UTC - ), + guaranteed_session_until=datetime.fromtimestamp(current_user.id_token.exp, UTC), username=current_user.id_token.preferred_username, ccchh_roles=current_user.ccchh_roles, ) @@ -122,7 +121,9 @@ async def login_init( response_class=RedirectResponse, status_code=302, ) -async def login_callback(req: Request, resp: Response, oidc_client: deps.OpenidClient) -> str: +async def login_callback( + req: Request, resp: Response, oidc_client: deps.OpenidClient +) -> str: # check that the user is currently in an authenticating state # these cookies are set by the login_init() view if ( diff --git a/api/src/dooris_api/ccujack.py b/api/src/dooris_api/ccujack.py index 78c9e00..ed1e038 100644 --- a/api/src/dooris_api/ccujack.py +++ b/api/src/dooris_api/ccujack.py @@ -84,8 +84,16 @@ class CCUJackClient: async def connect_mqtt(self): await self.mqtt.connect() + async def close_connections(self): + await asyncio.gather( + self.mqtt.disconnect(), + self.http.close() + ) + async def find_locks(self): logger.debug("Inspecting lock devices present in CCUJack") + + # iterate through the CCUJACK API to find all devices async with self.http.get("/device") as resp: devices = CCUDeviceList.model_validate(await resp.json()) @@ -99,6 +107,14 @@ class CCUJackClient: self.locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK] + # update active mqtt subscriptions + mqtt_topics = set() + for i_lock, lock_channels in self.locks: + for i_channel, channel_params in lock_channels: + for i_param in channel_params: + mqtt_topics.add(f"device/status/{i_lock.address}/{i_channel.index}/{i_param.id}") + # await self.mqtt.update_subscriptions(mqtt_topics) + async def query_param_value(self, address: str): logger.debug("Querying parameter value from '%s'", address) async with self.http.get(f"/device/{address}/~pv") as resp: diff --git a/api/src/dooris_api/mqtt_client.py b/api/src/dooris_api/mqtt_client.py index f7cb22b..3f6de6f 100644 --- a/api/src/dooris_api/mqtt_client.py +++ b/api/src/dooris_api/mqtt_client.py @@ -1,9 +1,9 @@ # # This whole implementation is adapted from the upstream GitHub example # https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py -# +# -from typing import Any +from typing import Any, List, Set, Iterable import logging import asyncio import socket @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class AsyncLooper: """ - Helper class to implement loopgin with asyncio for the underlying mqtt IO + Helper class to implement loopgin with asyncio for the underlying mqtt IO """ def __init__(self, loop: asyncio.AbstractEventLoop, client: mqtt.Client): @@ -58,9 +58,14 @@ class AsyncLooper: class AsyncMqttClient: loop: asyncio.AbstractEventLoop - + connection_string: str + looper: AsyncLooper + client: mqtt.Client + active_subscriptions: Set[str] + def __init__(self, connection_string: str, username: str, password: str): self.connection_string = connection_string + self.active_subscriptions = set() self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="dooris") self.client.username = username self.client.password = password @@ -68,25 +73,75 @@ class AsyncMqttClient: self.client.on_message = self.on_message self.client.on_disconnect = self.on_disconnect - def on_connect(self, client: mqtt.Client, userdata: Any, flags: mqtt.ConnectFlags, reason_code, properties: mqtt.Properties): + def on_connect( + self, + client: mqtt.Client, + userdata: Any, + flags: mqtt.ConnectFlags, + reason_code, + properties: mqtt.Properties, + ): logger.debug(f"mqtt client connected with message '{reason_code}'") + self.fut_connected.set_result(None) - def on_disconnect(self, client: mqtt.Client, userdata: Any, flags: mqtt.DisconnectFlags, reason_code, properties: mqtt.Properties): - logger.debug("mqtt client disconnected") - print("flags", type(flags), flags) - print("reason_code", type(reason_code), reason_code) - print("properties", type(properties), properties) + def on_disconnect( + self, + client: mqtt.Client, + userdata: Any, + flags: mqtt.DisconnectFlags, + reason_code, + properties: mqtt.Properties, + ): + logger.debug(f"mqtt client disconnected with message '{reason_code}'") + if self.fut_disconnect: + self.fut_disconnect.set_result(None) - def on_message(self, client: mqtt.Client, userdata: Any, msg): + def on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug("mqtt client got message") print("msg", type(msg), msg) + def on_subscribe(self, client, userdata, mid, reason_code, properties): + logger.debug(f"mqtt client subscribed to topics with message '{reason_code}'") + self.fut_subscribe.set_result(None) + + def on_unsubscribe(self, client, userdata, mid, reason_code, properties): + logger.debug(f"mqtt client unsubscribed from topics with message '{reason_code}'") + self.fut_unsubscribe.set_result(None) + + async def update_subscriptions(self, topics: Iterable[str]): + """ + Update MQTT subscriptions so that the client is subscribed to exactly the given list of topics + """ + + to_add = topics.difference(self.active_subscriptions) + if to_add: + logger.info(f"mqtt client subscribing to topics {', '.join(to_add)}") + qos = 2 + self.fut_subscribe = asyncio.get_running_loop().create_future() + self.client.subscribe([(i, qos) for i in to_add]) + await self.fut_subscribe + + to_remove = self.active_subscriptions.difference(topics) + if to_remove: + logger.info(f"mqtt client unsubscribing from topics {','.join(to_remove)}") + self.fut_unsubscribe = asyncio.get_running_loop().create_future() + self.client.unsubscribe(list(to_remove)) + await self.fut_unsubscribe + async def connect(self): server_host, server_port = self.connection_string.rsplit(":", maxsplit=1) - - looper = AsyncLooper(asyncio.get_running_loop(), self.client) + + self.looper = AsyncLooper(asyncio.get_running_loop(), self.client) logger.info("Connecting to mqtt server at %s:%s", server_host, server_port) + self.fut_connected = asyncio.get_running_loop().create_future() self.client.connect(server_host, int(server_port)) - self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) + self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) + await self.fut_connected + + async def disconnect(self): + logger.info("Disconnecting mqtt client from broker") + self.fut_disconnect = asyncio.get_running_loop().create_future() + self.client.disconnect() + await self.fut_disconnect