diff --git a/oin_thermostat/mqtt.py b/oin_thermostat/mqtt.py index c20dc86..934e002 100644 --- a/oin_thermostat/mqtt.py +++ b/oin_thermostat/mqtt.py @@ -1,5 +1,8 @@ import json import logging +import sys +from collections.abc import Callable +from typing import Any import paho.mqtt.client as mqtt @@ -14,7 +17,7 @@ class HAClient: self, entity: str, secondary_entities: list[str] = [], - mqtt_config: dict = dict(), + mqtt_config: dict[str, str] = dict(), ) -> None: self.entity = entity self.secondary_entities = secondary_entities @@ -54,34 +57,48 @@ class HAClient: host = self.config.get("host") port = self.config.get("port", 1883) - logger.debug(f"Connecting to <{host}> on port <{port}>") + logger.debug(f"Connecting to <{host}> on port <{port}>.") self.client.connect(host, port) self.subscribe(entity_topic(self.entity), self.state_update) - for entity in self.secondary_entities: - self.subscribe(entity_topic(entity, "state"), self.secondary_state_update) + self.subscribe( + [entity_topic(entity) for entity in self.secondary_entities], + self.secondary_state_update, + ) self.publish("homeassistant/device/oin/config", self.ha_options, retain=True) self.client.publish(self.availability_topic, "online", retain=True) - def publish(self, topic, data, **kwargs): + def publish(self, topic: str, data: Any, **kwargs) -> mqtt.MQTTMessageInfo: logger.debug(f"Sending message on topic <{topic}>: {json.dumps(data)}") - self.client.publish(topic, json.dumps(data), **kwargs) + return self.client.publish(topic, json.dumps(data), **kwargs) - def subscribe(self, topic, callback): - logger.debug(f"Subscribe to <{topic}>") - self.client.subscribe(topic) - self.client.message_callback_add(topic, callback) + def subscribe(self, topic: str | list[str], callback: Callable) -> None: + logger.debug(f"Subscribing to <{topic}>.") - def unsubscribe(self, topic): - logger.debug(f"Unsubscribe from <{topic}>") - self.client.unsubscribe(topic) + match topic: + case str(): + self.client.message_callback_add(topic, callback) + code, _ = self.client.subscribe(topic) + case list(): + for top in topic: + self.client.message_callback_add(top, callback) + code, _ = self.client.subscribe([(top, 0) for top in topic]) - def loop(self): - logger.info("Starting MQTT client loop") - self.client.loop_forever() + if code != 0: + logger.error(f"Failed subscribing to topic <{topic}> with code <{code}>.") + sys.exit(1) - def state_update(self, client: mqtt.Client, userdata, message: mqtt.MQTTMessage): + def loop(self) -> mqtt.MQTTErrorCode: + logger.info("Starting MQTT client loop.") + code = self.client.loop_forever(retry_first_connection=True) + + if code != 0: + logger.error("MQTT client loop failed with code <{code}>.") + + def state_update( + self, client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage + ) -> None: logger.debug(f"Message received on topic <{message.topic}>: {message.payload}.") subtopic = message.topic.rsplit("/", maxsplit=1)[1] @@ -109,8 +126,8 @@ class HAClient: self.selector.switch = False def secondary_state_update( - self, client: mqtt.Client, userdata, message: mqtt.MQTTMessage - ): + self, client: mqtt.Client, userdata: Any, message: mqtt.MQTTMessage + ) -> None: logger.debug(f"Message received on topic <{message.topic}>: {message.payload}.") _, grp, ent, subtopic = message.topic.split("/") @@ -119,14 +136,14 @@ class HAClient: if subtopic == "state": self.screen.secondary |= {idx: message.payload.decode()} - def send_data(self, data): - self.publish(self.state_topic, data) + def send_data(self, data: Any) -> mqtt.MQTTMessageInfo: + return self.publish(self.state_topic, data) -def parse(message): +def parse(message: mqtt.MQTTMessage) -> Any: return json.loads(message.payload.decode()) -def entity_topic(entity, subtopic="#"): +def entity_topic(entity: str, subtopic: str = "#") -> str: topic = entity.replace(".", "/") return f"homeassistant/{topic}/{subtopic}"