Refactor MQTT client initialization to validate host and port, and add connection testing; move command execution to a new utility module
This commit is contained in:
parent
f61f7618bc
commit
e4a0b66252
3 changed files with 45 additions and 12 deletions
|
@ -1,6 +1,6 @@
|
|||
[Unit]
|
||||
Description=Hasspy
|
||||
After=network.target
|
||||
After=network-online.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
|
|
|
@ -2,16 +2,18 @@ import io
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from subprocess import run
|
||||
from threading import Thread, Timer
|
||||
from typing import Any, Mapping, Tuple
|
||||
from typing import Any, Mapping
|
||||
|
||||
from paho.mqtt.client import Client, MQTTMessage, MQTTMessageInfo
|
||||
from paho.mqtt.enums import CallbackAPIVersion, MQTTErrorCode
|
||||
from PIL import Image
|
||||
|
||||
from .utils import run_command, test_connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -31,6 +33,17 @@ class HassClient(Client):
|
|||
|
||||
self.cover = ""
|
||||
|
||||
if not isinstance(self.config.get("host"), str):
|
||||
log.error("Host was not set correctly")
|
||||
return
|
||||
|
||||
if not isinstance(self.config.get("port"), int):
|
||||
log.error("Port was not set correctly")
|
||||
return
|
||||
|
||||
while not test_connection(self.config["host"], self.config["port"]):
|
||||
time.sleep(5)
|
||||
|
||||
self.connect()
|
||||
|
||||
def connect(self, *args: Any, **kwargs: Any) -> MQTTErrorCode:
|
||||
|
@ -38,7 +51,7 @@ class HassClient(Client):
|
|||
self.will_set(self.availability_topic, "offline", retain=True)
|
||||
|
||||
return super().connect(
|
||||
self.config.get("host", ""), self.config.get("port", 1883)
|
||||
self.config.get("host", ""), self.config.get("port", 1883), *args, **kwargs
|
||||
)
|
||||
|
||||
def publish(self, *args: Any, **kwargs: Any) -> MQTTMessageInfo:
|
||||
|
@ -413,11 +426,3 @@ class HassUserClient(HassClient):
|
|||
|
||||
by.seek(0)
|
||||
self.publish(self.cover_topic, by.read())
|
||||
|
||||
|
||||
def run_command(cmd: list[str]) -> Tuple[int, str]:
|
||||
proc = run(cmd, capture_output=True)
|
||||
if proc.returncode != 0:
|
||||
return proc.returncode, ""
|
||||
|
||||
return proc.returncode, proc.stdout.decode("utf-8")
|
||||
|
|
28
hasspy/utils.py
Normal file
28
hasspy/utils.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import logging
|
||||
import socket
|
||||
from subprocess import run
|
||||
from typing import Tuple
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_command(cmd: list[str]) -> Tuple[int, str]:
|
||||
log.debug(f"Running command {' '.join(cmd)}")
|
||||
proc = run(cmd, capture_output=True)
|
||||
if proc.returncode != 0:
|
||||
return proc.returncode, ""
|
||||
|
||||
return proc.returncode, proc.stdout.decode("utf-8")
|
||||
|
||||
|
||||
def test_connection(host: str, port: int) -> bool:
|
||||
log.debug(f"Testing connection to {host}:{port}")
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(1)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
sock.close()
|
||||
return True
|
||||
except socket.error:
|
||||
log.warning(f"Could not reach {host}:{port}")
|
||||
return False
|
Loading…
Add table
Reference in a new issue