Refactor MQTT client initialization to validate host and port, and add connection testing; move command execution to a new utility module

This commit is contained in:
Edgar P. Burkhart 2025-03-10 09:16:32 +01:00
parent f61f7618bc
commit e4a0b66252
Signed by: edpibu
GPG key ID: 9833D3C5A25BD227
3 changed files with 45 additions and 12 deletions

View file

@ -1,6 +1,6 @@
[Unit] [Unit]
Description=Hasspy Description=Hasspy
After=network.target After=network-online.target
[Service] [Service]
Type=simple Type=simple

View file

@ -2,16 +2,18 @@ import io
import json import json
import logging import logging
import re import re
import time
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from subprocess import run
from threading import Thread, Timer from threading import Thread, Timer
from typing import Any, Mapping, Tuple from typing import Any, Mapping
from paho.mqtt.client import Client, MQTTMessage, MQTTMessageInfo from paho.mqtt.client import Client, MQTTMessage, MQTTMessageInfo
from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode
from PIL import Image from PIL import Image
from .utils import run_command, test_connection
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -31,6 +33,17 @@ class HassClient(Client):
self.cover = "" self.cover = ""
if not isinstance(self.config.get("host"), str):
log.error("Host was not set correctly")
return
if not isinstance(self.config.get("port"), int):
log.error("Port was not set correctly")
return
while not test_connection(self.config["host"], self.config["port"]):
time.sleep(5)
self.connect() self.connect()
def connect(self, *args: Any, **kwargs: Any) -> MQTTErrorCode: def connect(self, *args: Any, **kwargs: Any) -> MQTTErrorCode:
@ -38,7 +51,7 @@ class HassClient(Client):
self.will_set(self.availability_topic, "offline", retain=True) self.will_set(self.availability_topic, "offline", retain=True)
return super().connect( return super().connect(
self.config.get("host", ""), self.config.get("port", 1883) self.config.get("host", ""), self.config.get("port", 1883), *args, **kwargs
) )
def publish(self, *args: Any, **kwargs: Any) -> MQTTMessageInfo: def publish(self, *args: Any, **kwargs: Any) -> MQTTMessageInfo:
@ -413,11 +426,3 @@ class HassUserClient(HassClient):
by.seek(0) by.seek(0)
self.publish(self.cover_topic, by.read()) self.publish(self.cover_topic, by.read())
def run_command(cmd: list[str]) -> Tuple[int, str]:
proc = run(cmd, capture_output=True)
if proc.returncode != 0:
return proc.returncode, ""
return proc.returncode, proc.stdout.decode("utf-8")

28
hasspy/utils.py Normal file
View file

@ -0,0 +1,28 @@
import logging
import socket
from subprocess import run
from typing import Tuple
log = logging.getLogger(__name__)
def run_command(cmd: list[str]) -> Tuple[int, str]:
log.debug(f"Running command {' '.join(cmd)}")
proc = run(cmd, capture_output=True)
if proc.returncode != 0:
return proc.returncode, ""
return proc.returncode, proc.stdout.decode("utf-8")
def test_connection(host: str, port: int) -> bool:
log.debug(f"Testing connection to {host}:{port}")
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
try:
sock.connect((host, port))
sock.close()
return True
except socket.error:
log.warning(f"Could not reach {host}:{port}")
return False