159 lines
5.6 KiB
Python
159 lines
5.6 KiB
Python
#
# This whole implementation is adapted from the upstream paho-mqtt GitHub example:
# https://github.com/eclipse-paho/paho.mqtt.python/blob/master/examples/loop_asyncio.py
#
|
|
|
|
import asyncio
import logging
import socket
from typing import Any, Iterable, List, Optional, Set

import paho.mqtt.client as mqtt
|
|
|
|
|
|
# Module logger, named after this module (``__name__``).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AsyncLooper:
    """
    Drive a paho-mqtt client's network IO from an asyncio event loop.

    The constructor installs paho socket callbacks so that:
    * while the socket is open, readability calls ``client.loop_read()``;
    * between paho's register-write and unregister-write notifications,
      writability calls ``client.loop_write()``;
    * a background task calls ``client.loop_misc()`` about once per second
      until it stops returning success or the socket is closed.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop, client: mqtt.Client):
        self.loop = loop
        self.client = client

        # Route paho's socket lifecycle notifications into this object.
        client.on_socket_open = self.on_socket_open
        client.on_socket_close = self.on_socket_close
        client.on_socket_register_write = self.on_socket_register_write
        client.on_socket_unregister_write = self.on_socket_unregister_write

    def on_socket_open(self, client, userdata, sock):
        """Watch *sock* for incoming data and start the housekeeping task."""
        self.loop.add_reader(sock, client.loop_read)
        self.task_misc = self.loop.create_task(self.misc_loop())

    def on_socket_close(self, client, userdata, sock):
        """Stop watching *sock* and cancel the housekeeping task."""
        self.loop.remove_reader(sock)
        self.task_misc.cancel()

    def on_socket_register_write(self, client, userdata, sock):
        """paho wants to send data: call ``loop_write`` whenever *sock* is writable."""
        self.loop.add_writer(sock, client.loop_write)

    def on_socket_unregister_write(self, client, userdata, sock):
        """paho has nothing more to send: stop watching *sock* for writability."""
        self.loop.remove_writer(sock)

    async def misc_loop(self):
        """
        Call ``client.loop_misc()`` roughly once per second.

        Ends when ``loop_misc`` returns anything other than
        ``MQTT_ERR_SUCCESS``, or when the task is cancelled. Cancellation is
        swallowed, so the task finishes normally rather than as cancelled.
        """
        try:
            while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            pass
|
|
|
|
|
|
class AsyncMqttClient:
    """
    Asyncio wrapper around a paho-mqtt ``Client``.

    Network IO is driven by an :class:`AsyncLooper` created in :meth:`connect`.
    Request/acknowledge operations (connect, subscribe, unsubscribe,
    disconnect) are made awaitable as follows: a future is created before the
    request is issued and resolved from the matching paho callback. Every
    received message is put on :attr:`messages`.
    """

    loop: asyncio.AbstractEventLoop
    connection_string: str
    looper: AsyncLooper
    client: mqtt.Client
    # Topics the broker has acknowledged via update_subscriptions(); cleared on disconnect.
    active_subscriptions: Set[str]
    # Received MQTTMessage objects, in arrival order.
    messages: asyncio.Queue
    # Futures resolved by the paho callbacks; None until the matching
    # operation has been started.
    fut_connected: Optional[asyncio.Future]
    fut_subscribe: Optional[asyncio.Future]
    fut_unsubscribe: Optional[asyncio.Future]
    fut_disconnect: Optional[asyncio.Future]

    def __init__(
        self,
        connection_string: str,
        username: str,
        password: str,
        client_id: str = "dooris",
    ):
        """
        Create the paho client and wire up its callbacks; does not connect.

        :param connection_string: broker address as ``host:port``; a bare
            ``host`` uses the standard MQTT port 1883.
        :param username: user name sent to the broker.
        :param password: password sent to the broker.
        :param client_id: MQTT client identifier (defaults to ``"dooris"``).
        """
        self.connection_string = connection_string
        self.active_subscriptions = set()
        self.messages = asyncio.Queue()

        # Initialised so callbacks that fire before the matching operation
        # (e.g. an unsolicited disconnect) do not hit an AttributeError.
        self.fut_connected = None
        self.fut_subscribe = None
        self.fut_unsubscribe = None
        self.fut_disconnect = None

        self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id=client_id)
        self.client.username = username
        self.client.password = password

        self.client.on_connect = self.on_connect
        self.client.on_connect_fail = self.on_connect_fail
        self.client.on_message = self.on_message
        self.client.on_disconnect = self.on_disconnect
        self.client.on_subscribe = self.on_subscribe
        self.client.on_unsubscribe = self.on_unsubscribe

    @staticmethod
    def _settle(fut: Optional[asyncio.Future], exc: Optional[BaseException] = None) -> None:
        """Resolve *fut* with ``None`` (or fail it with *exc*) if it exists and is still pending."""
        # paho callbacks can fire when nobody is waiting, or more than once;
        # touching a finished future would raise InvalidStateError inside paho.
        if fut is None or fut.done():
            return
        if exc is None:
            fut.set_result(None)
        else:
            fut.set_exception(exc)

    @staticmethod
    def _check_request(result: int, action: str) -> None:
        """Raise if paho refused to send a request; its acknowledgement callback would never fire."""
        if result != mqtt.MQTT_ERR_SUCCESS:
            raise ConnectionError(
                f"mqtt client could not {action}: {mqtt.error_string(result)}"
            )

    def on_connect(
        self,
        client: mqtt.Client,
        userdata: Any,
        flags: mqtt.ConnectFlags,
        reason_code,
        properties: mqtt.Properties,
    ):
        """paho callback: CONNACK received; succeed or fail the pending :meth:`connect`."""
        logger.debug(f"mqtt client connected with message '{reason_code}'")
        # A CONNACK also arrives when the broker *refuses* the connection
        # (e.g. bad credentials); that must not look like success.
        if reason_code.is_failure:
            logger.error("mqtt broker refused the connection: %s", reason_code)
            self._settle(
                self.fut_connected,
                ConnectionError(f"mqtt broker refused the connection: {reason_code}"),
            )
        else:
            self._settle(self.fut_connected)

    def on_connect_fail(self, client, userdata):
        """paho callback: connecting to the broker failed; fail the pending :meth:`connect`."""
        logger.error("mqtt client could not connect to broker")
        self._settle(
            self.fut_connected,
            Exception("mqtt client could not connect to broker"),
        )

    def on_disconnect(
        self,
        client: mqtt.Client,
        userdata: Any,
        flags: mqtt.DisconnectFlags,
        reason_code,
        properties: mqtt.Properties,
    ):
        """
        paho callback: the connection to the broker was closed.

        Completes a pending :meth:`disconnect`. It also fails any connect,
        subscribe or unsubscribe still waiting for an acknowledgement, since
        none will arrive on this connection.
        """
        logger.debug(f"mqtt client disconnected with message '{reason_code}'")

        # The client is built without overriding clean_session, so (assuming
        # paho's clean-session default) the broker forgets our subscriptions
        # together with the connection.
        self.active_subscriptions.clear()

        for pending in (self.fut_connected, self.fut_subscribe, self.fut_unsubscribe):
            self._settle(
                pending,
                ConnectionError(f"mqtt client disconnected with message '{reason_code}'"),
            )
        self._settle(self.fut_disconnect)

    def on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
        """paho callback: enqueue every received message on :attr:`messages`."""
        self.messages.put_nowait(msg)

    def on_subscribe(self, client, userdata, mid, reason_code, properties):
        """paho callback: SUBACK received; resolve the pending subscribe."""
        logger.debug(f"mqtt client subscribed to topics with message '{reason_code}'")
        self._settle(self.fut_subscribe)

    def on_unsubscribe(self, client, userdata, mid, reason_code, properties):
        """paho callback: UNSUBACK received; resolve the pending unsubscribe."""
        logger.debug(
            f"mqtt client unsubscribed from topics with message '{reason_code}'"
        )
        self._settle(self.fut_unsubscribe)

    async def update_subscriptions(self, topics: Iterable[str], qos: int = 1):
        """
        Update MQTT subscriptions so that the client is subscribed to exactly the given topics.

        Subscribes to requested topics not yet active, then unsubscribes from
        active topics that are no longer requested. Each request waits for the
        broker's acknowledgement. :attr:`active_subscriptions` is updated only
        after that acknowledgement.

        :param topics: any iterable of topic filters (materialised into a set).
        :param qos: QoS level used for new subscriptions.
        :raises ConnectionError: if paho cannot send a request, or the
            connection drops while waiting for an acknowledgement.
        """
        wanted = set(topics)

        to_add = wanted - self.active_subscriptions
        if to_add:
            logger.info(f"mqtt client subscribing to topics {', '.join(to_add)}")
            self.fut_subscribe = asyncio.get_running_loop().create_future()
            result, _mid = self.client.subscribe([(i, qos) for i in to_add])
            self._check_request(result, "subscribe")
            await self.fut_subscribe
            self.active_subscriptions |= to_add

        to_remove = self.active_subscriptions - wanted
        if to_remove:
            logger.info(f"mqtt client unsubscribing from topics {','.join(to_remove)}")
            self.fut_unsubscribe = asyncio.get_running_loop().create_future()
            result, _mid = self.client.unsubscribe(list(to_remove))
            self._check_request(result, "unsubscribe")
            await self.fut_unsubscribe
            self.active_subscriptions -= to_remove

    async def connect(self):
        """
        Connect to the broker named by :attr:`connection_string` and wait for its CONNACK.

        :raises ConnectionError: if the broker refuses the connection or the
            connection is lost before it is acknowledged.
        Errors raised by paho's ``Client.connect`` propagate unchanged.
        """
        server_host, sep, server_port = self.connection_string.rpartition(":")
        if not sep:
            # No ":port" suffix: fall back to the standard MQTT port.
            server_host, server_port = self.connection_string, "1883"

        loop = asyncio.get_running_loop()
        self.looper = AsyncLooper(loop, self.client)

        logger.info("Connecting to mqtt server at %s:%s", server_host, server_port)
        self.fut_connected = loop.create_future()
        # NOTE(review): paho's connect() is synchronous, so this presumably
        # blocks the event loop while the TCP connection is established.
        self.client.connect(server_host, int(server_port))
        # Kept from the upstream asyncio example this file is adapted from.
        self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)

        await self.fut_connected

    async def disconnect(self):
        """Disconnect from the broker and wait until paho reports the disconnect."""
        logger.info("Disconnecting mqtt client from broker")
        self.fut_disconnect = asyncio.get_running_loop().create_future()
        result = self.client.disconnect()
        if result == mqtt.MQTT_ERR_NO_CONN:
            # paho reports there is no open connection, so no on_disconnect
            # is expected; waiting for one would hang forever.
            self._settle(self.fut_disconnect)
            return
        await self.fut_disconnect
|