From e4a0b662529dffcc7a7f7e6ec32e8b8d0a5dbc45 Mon Sep 17 00:00:00 2001 From: "Edgar P. Burkhart" <git@edgarpierre.fr> Date: Mon, 10 Mar 2025 09:16:32 +0100 Subject: [PATCH] Refactor MQTT client initialization to validate host and port, and add connection testing; move command execution to a new utility module --- hasspy.service | 2 +- hasspy/mqtt.py | 27 ++++++++++++++++----------- hasspy/utils.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 12 deletions(-) create mode 100644 hasspy/utils.py diff --git a/hasspy.service b/hasspy.service index ee0b781..d5e7815 100644 --- a/hasspy.service +++ b/hasspy.service @@ -1,6 +1,6 @@ [Unit] Description=Hasspy -After=network.target +After=network-online.target [Service] Type=simple diff --git a/hasspy/mqtt.py b/hasspy/mqtt.py index e15fd6d..0facb2d 100644 --- a/hasspy/mqtt.py +++ b/hasspy/mqtt.py @@ -2,16 +2,18 @@ import io import json import logging import re +import time from datetime import datetime, timezone from pathlib import Path -from subprocess import run from threading import Thread, Timer -from typing import Any, Mapping, Tuple +from typing import Any, Mapping from paho.mqtt.client import Client, MQTTMessage, MQTTMessageInfo from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode from PIL import Image +from .utils import run_command, test_connection + log = logging.getLogger(__name__) @@ -31,6 +33,17 @@ class HassClient(Client): self.cover = "" + if not isinstance(self.config.get("host"), str): + log.error("Host was not set correctly") + return + + if not isinstance(self.config.get("port"), int): + log.error("Port was not set correctly") + return + + while not test_connection(self.config["host"], self.config["port"]): + time.sleep(5) + self.connect() def connect(self, *args: Any, **kwargs: Any) -> MQTTErrorCode: @@ -38,7 +51,7 @@ class HassClient(Client): self.will_set(self.availability_topic, "offline", retain=True) return super().connect( - self.config.get("host", ""), self.config.get("port", 1883) + self.config.get("host", ""), self.config.get("port", 1883), *args, **kwargs ) def publish(self, *args: Any, **kwargs: Any) -> MQTTMessageInfo: @@ -413,11 +426,3 @@ class HassUserClient(HassClient): by.seek(0) self.publish(self.cover_topic, by.read()) - - -def run_command(cmd: list[str]) -> Tuple[int, str]: - proc = run(cmd, capture_output=True) - if proc.returncode != 0: - return proc.returncode, "" - - return proc.returncode, proc.stdout.decode("utf-8") diff --git a/hasspy/utils.py b/hasspy/utils.py new file mode 100644 index 0000000..b7b50bd --- /dev/null +++ b/hasspy/utils.py @@ -0,0 +1,28 @@ +import logging +import socket +from subprocess import run +from typing import Tuple + +log = logging.getLogger(__name__) + + +def run_command(cmd: list[str]) -> Tuple[int, str]: + log.debug(f"Running command {' '.join(cmd)}") + proc = run(cmd, capture_output=True) + if proc.returncode != 0: + return proc.returncode, "" + + return proc.returncode, proc.stdout.decode("utf-8") + + +def test_connection(host: str, port: int) -> bool: + log.debug(f"Testing connection to {host}:{port}") + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(1) + try: + sock.connect((host, port)) + sock.close() + return True + except socket.error: + log.warning(f"Could not reach {host}:{port}") + return False