Refactor MQTT client initialization to validate host and port, and add connection testing; move command execution to a new utility module

This commit is contained in:
Edgar P. Burkhart 2025-03-10 09:16:32 +01:00
parent f61f7618bc
commit e4a0b66252
Signed by: edpibu
GPG key ID: 9833D3C5A25BD227
3 changed files with 45 additions and 12 deletions

View file

@ -1,6 +1,6 @@
[Unit]
Description=Hasspy
After=network.target
After=network-online.target
[Service]
Type=simple

View file

@ -2,16 +2,18 @@ import io
import json
import logging
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from subprocess import run
from threading import Thread, Timer
from typing import Any, Mapping, Tuple
from typing import Any, Mapping
from paho.mqtt.client import Client, MQTTMessage, MQTTMessageInfo
from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode
from PIL import Image
from .utils import run_command, test_connection
log = logging.getLogger(__name__)
@ -31,6 +33,17 @@ class HassClient(Client):
self.cover = ""
if not isinstance(self.config.get("host"), str):
log.error("Host was not set correctly")
return
if not isinstance(self.config.get("port"), int):
log.error("Port was not set correctly")
return
while not test_connection(self.config["host"], self.config["port"]):
time.sleep(5)
self.connect()
def connect(self, *args: Any, **kwargs: Any) -> MQTTErrorCode:
@ -38,7 +51,7 @@ class HassClient(Client):
self.will_set(self.availability_topic, "offline", retain=True)
return super().connect(
self.config.get("host", ""), self.config.get("port", 1883)
self.config.get("host", ""), self.config.get("port", 1883), *args, **kwargs
)
def publish(self, *args: Any, **kwargs: Any) -> MQTTMessageInfo:
@ -413,11 +426,3 @@ class HassUserClient(HassClient):
by.seek(0)
self.publish(self.cover_topic, by.read())
def run_command(cmd: list[str]) -> Tuple[int, str]:
proc = run(cmd, capture_output=True)
if proc.returncode != 0:
return proc.returncode, ""
return proc.returncode, proc.stdout.decode("utf-8")

28
hasspy/utils.py Normal file
View file

@ -0,0 +1,28 @@
import logging
import socket
from subprocess import run
from typing import Tuple
log = logging.getLogger(__name__)
def run_command(cmd: list[str]) -> Tuple[int, str]:
log.debug(f"Running command {' '.join(cmd)}")
proc = run(cmd, capture_output=True)
if proc.returncode != 0:
return proc.returncode, ""
return proc.returncode, proc.stdout.decode("utf-8")
def test_connection(host: str, port: int) -> bool:
log.debug(f"Testing connection to {host}:{port}")
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
try:
sock.connect((host, port))
sock.close()
return True
except socket.error:
log.warning(f"Could not reach {host}:{port}")
return False