api: use proper connection shutdown for downstream services
This commit is contained in:
parent
0331dd6406
commit
44d484cfc1
3 changed files with 94 additions and 22 deletions
|
|
@ -2,7 +2,6 @@ from typing import Optional, List
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
import sys
|
import sys
|
||||||
import os
|
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from fastapi import FastAPI, Request, Response, status
|
from fastapi import FastAPI, Request, Response, status
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
@ -42,11 +41,13 @@ async def lifespan(app: FastAPI):
|
||||||
auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password),
|
auth=BasicAuth(app_cfg.ccujack_user, app_cfg.ccujack_password),
|
||||||
mqtt_conn=app_cfg.ccujack_mqtt,
|
mqtt_conn=app_cfg.ccujack_mqtt,
|
||||||
)
|
)
|
||||||
await app.extra["ccujack"].find_locks()
|
|
||||||
await app.extra["ccujack"].connect_mqtt()
|
await app.extra["ccujack"].connect_mqtt()
|
||||||
|
await app.extra["ccujack"].find_locks()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
await app.extra["ccujack"].close_connections()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Dooris",
|
title="Dooris",
|
||||||
|
|
@ -71,9 +72,7 @@ async def get_user_info(
|
||||||
) -> models.UserStatus:
|
) -> models.UserStatus:
|
||||||
return models.UserStatus(
|
return models.UserStatus(
|
||||||
is_authorized=current_user.may_operate_locks,
|
is_authorized=current_user.may_operate_locks,
|
||||||
guaranteed_session_until=datetime.fromtimestamp(
|
guaranteed_session_until=datetime.fromtimestamp(current_user.id_token.exp, UTC),
|
||||||
current_user.id_token.exp, UTC
|
|
||||||
),
|
|
||||||
username=current_user.id_token.preferred_username,
|
username=current_user.id_token.preferred_username,
|
||||||
ccchh_roles=current_user.ccchh_roles,
|
ccchh_roles=current_user.ccchh_roles,
|
||||||
)
|
)
|
||||||
|
|
@ -122,7 +121,9 @@ async def login_init(
|
||||||
response_class=RedirectResponse,
|
response_class=RedirectResponse,
|
||||||
status_code=302,
|
status_code=302,
|
||||||
)
|
)
|
||||||
async def login_callback(req: Request, resp: Response, oidc_client: deps.OpenidClient) -> str:
|
async def login_callback(
|
||||||
|
req: Request, resp: Response, oidc_client: deps.OpenidClient
|
||||||
|
) -> str:
|
||||||
# check that the user is currently in an authenticating state
|
# check that the user is currently in an authenticating state
|
||||||
# these cookies are set by the login_init() view
|
# these cookies are set by the login_init() view
|
||||||
if (
|
if (
|
||||||
|
|
|
||||||
|
|
@ -84,8 +84,16 @@ class CCUJackClient:
|
||||||
async def connect_mqtt(self):
|
async def connect_mqtt(self):
|
||||||
await self.mqtt.connect()
|
await self.mqtt.connect()
|
||||||
|
|
||||||
|
async def close_connections(self):
|
||||||
|
await asyncio.gather(
|
||||||
|
self.mqtt.disconnect(),
|
||||||
|
self.http.close()
|
||||||
|
)
|
||||||
|
|
||||||
async def find_locks(self):
|
async def find_locks(self):
|
||||||
logger.debug("Inspecting lock devices present in CCUJack")
|
logger.debug("Inspecting lock devices present in CCUJack")
|
||||||
|
|
||||||
|
# iterate through the CCUJACK API to find all devices
|
||||||
async with self.http.get("/device") as resp:
|
async with self.http.get("/device") as resp:
|
||||||
devices = CCUDeviceList.model_validate(await resp.json())
|
devices = CCUDeviceList.model_validate(await resp.json())
|
||||||
|
|
||||||
|
|
@ -99,6 +107,14 @@ class CCUJackClient:
|
||||||
|
|
||||||
self.locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK]
|
self.locks = [i for i in device_infos if i[0].type == DEVICE_TYPE_LOCK]
|
||||||
|
|
||||||
|
# update active mqtt subscriptions
|
||||||
|
mqtt_topics = set()
|
||||||
|
for i_lock, lock_channels in self.locks:
|
||||||
|
for i_channel, channel_params in lock_channels:
|
||||||
|
for i_param in channel_params:
|
||||||
|
mqtt_topics.add(f"device/status/{i_lock.address}/{i_channel.index}/{i_param.id}")
|
||||||
|
# await self.mqtt.update_subscriptions(mqtt_topics)
|
||||||
|
|
||||||
async def query_param_value(self, address: str):
|
async def query_param_value(self, address: str):
|
||||||
logger.debug("Querying parameter value from '%s'", address)
|
logger.debug("Querying parameter value from '%s'", address)
|
||||||
async with self.http.get(f"/device/{address}/~pv") as resp:
|
async with self.http.get(f"/device/{address}/~pv") as resp:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
#
|
#
|
||||||
# This whole implementation is adapted from the upstream GitHub example
|
# This whole implementation is adapted from the upstream GitHub example
|
||||||
# https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py
|
# https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py
|
||||||
#
|
#
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, List, Set, Iterable
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
import socket
|
||||||
|
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class AsyncLooper:
|
class AsyncLooper:
|
||||||
"""
|
"""
|
||||||
Helper class to implement loopgin with asyncio for the underlying mqtt IO
|
Helper class to implement loopgin with asyncio for the underlying mqtt IO
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, loop: asyncio.AbstractEventLoop, client: mqtt.Client):
|
def __init__(self, loop: asyncio.AbstractEventLoop, client: mqtt.Client):
|
||||||
|
|
@ -58,9 +58,14 @@ class AsyncLooper:
|
||||||
|
|
||||||
class AsyncMqttClient:
|
class AsyncMqttClient:
|
||||||
loop: asyncio.AbstractEventLoop
|
loop: asyncio.AbstractEventLoop
|
||||||
|
connection_string: str
|
||||||
|
looper: AsyncLooper
|
||||||
|
client: mqtt.Client
|
||||||
|
active_subscriptions: Set[str]
|
||||||
|
|
||||||
def __init__(self, connection_string: str, username: str, password: str):
|
def __init__(self, connection_string: str, username: str, password: str):
|
||||||
self.connection_string = connection_string
|
self.connection_string = connection_string
|
||||||
|
self.active_subscriptions = set()
|
||||||
self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="dooris")
|
self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="dooris")
|
||||||
self.client.username = username
|
self.client.username = username
|
||||||
self.client.password = password
|
self.client.password = password
|
||||||
|
|
@ -68,25 +73,75 @@ class AsyncMqttClient:
|
||||||
self.client.on_message = self.on_message
|
self.client.on_message = self.on_message
|
||||||
self.client.on_disconnect = self.on_disconnect
|
self.client.on_disconnect = self.on_disconnect
|
||||||
|
|
||||||
def on_connect(self, client: mqtt.Client, userdata: Any, flags: mqtt.ConnectFlags, reason_code, properties: mqtt.Properties):
|
def on_connect(
|
||||||
|
self,
|
||||||
|
client: mqtt.Client,
|
||||||
|
userdata: Any,
|
||||||
|
flags: mqtt.ConnectFlags,
|
||||||
|
reason_code,
|
||||||
|
properties: mqtt.Properties,
|
||||||
|
):
|
||||||
logger.debug(f"mqtt client connected with message '{reason_code}'")
|
logger.debug(f"mqtt client connected with message '{reason_code}'")
|
||||||
|
self.fut_connected.set_result(None)
|
||||||
|
|
||||||
def on_disconnect(self, client: mqtt.Client, userdata: Any, flags: mqtt.DisconnectFlags, reason_code, properties: mqtt.Properties):
|
def on_disconnect(
|
||||||
logger.debug("mqtt client disconnected")
|
self,
|
||||||
print("flags", type(flags), flags)
|
client: mqtt.Client,
|
||||||
print("reason_code", type(reason_code), reason_code)
|
userdata: Any,
|
||||||
print("properties", type(properties), properties)
|
flags: mqtt.DisconnectFlags,
|
||||||
|
reason_code,
|
||||||
|
properties: mqtt.Properties,
|
||||||
|
):
|
||||||
|
logger.debug(f"mqtt client disconnected with message '{reason_code}'")
|
||||||
|
if self.fut_disconnect:
|
||||||
|
self.fut_disconnect.set_result(None)
|
||||||
|
|
||||||
def on_message(self, client: mqtt.Client, userdata: Any, msg):
|
def on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
|
||||||
logger.debug("mqtt client got message")
|
logger.debug("mqtt client got message")
|
||||||
print("msg", type(msg), msg)
|
print("msg", type(msg), msg)
|
||||||
|
|
||||||
|
def on_subscribe(self, client, userdata, mid, reason_code, properties):
|
||||||
|
logger.debug(f"mqtt client subscribed to topics with message '{reason_code}'")
|
||||||
|
self.fut_subscribe.set_result(None)
|
||||||
|
|
||||||
|
def on_unsubscribe(self, client, userdata, mid, reason_code, properties):
|
||||||
|
logger.debug(f"mqtt client unsubscribed from topics with message '{reason_code}'")
|
||||||
|
self.fut_unsubscribe.set_result(None)
|
||||||
|
|
||||||
|
async def update_subscriptions(self, topics: Iterable[str]):
|
||||||
|
"""
|
||||||
|
Update MQTT subscriptions so that the client is subscribed to exactly the given list of topics
|
||||||
|
"""
|
||||||
|
|
||||||
|
to_add = topics.difference(self.active_subscriptions)
|
||||||
|
if to_add:
|
||||||
|
logger.info(f"mqtt client subscribing to topics {', '.join(to_add)}")
|
||||||
|
qos = 2
|
||||||
|
self.fut_subscribe = asyncio.get_running_loop().create_future()
|
||||||
|
self.client.subscribe([(i, qos) for i in to_add])
|
||||||
|
await self.fut_subscribe
|
||||||
|
|
||||||
|
to_remove = self.active_subscriptions.difference(topics)
|
||||||
|
if to_remove:
|
||||||
|
logger.info(f"mqtt client unsubscribing from topics {','.join(to_remove)}")
|
||||||
|
self.fut_unsubscribe = asyncio.get_running_loop().create_future()
|
||||||
|
self.client.unsubscribe(list(to_remove))
|
||||||
|
await self.fut_unsubscribe
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
server_host, server_port = self.connection_string.rsplit(":", maxsplit=1)
|
server_host, server_port = self.connection_string.rsplit(":", maxsplit=1)
|
||||||
|
|
||||||
looper = AsyncLooper(asyncio.get_running_loop(), self.client)
|
self.looper = AsyncLooper(asyncio.get_running_loop(), self.client)
|
||||||
|
|
||||||
logger.info("Connecting to mqtt server at %s:%s", server_host, server_port)
|
logger.info("Connecting to mqtt server at %s:%s", server_host, server_port)
|
||||||
|
self.fut_connected = asyncio.get_running_loop().create_future()
|
||||||
self.client.connect(server_host, int(server_port))
|
self.client.connect(server_host, int(server_port))
|
||||||
self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
||||||
|
|
||||||
|
await self.fut_connected
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
logger.info("Disconnecting mqtt client from broker")
|
||||||
|
self.fut_disconnect = asyncio.get_running_loop().create_future()
|
||||||
|
self.client.disconnect()
|
||||||
|
await self.fut_disconnect
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue