diff --git a/.gitignore b/.gitignore index 3931c59..a3ee3aa 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,6 @@ wheels/ .venv # BotBotBot -config.toml -wordlist.pickle +/config.toml +/wordlist.pickle +/cambai diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d59bbb6..3d48248 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,4 +19,4 @@ repos: rev: v1.15.0 hooks: - id: mypy - additional_dependencies: [mistralai, py-cord] + additional_dependencies: [mistralai, py-cord, types-requests] diff --git a/botbotbot.service b/botbotbot.service index 343c6a3..abcfd6c 100644 --- a/botbotbot.service +++ b/botbotbot.service @@ -7,7 +7,7 @@ After=multi-user.target Type=simple User=edpibu WorkingDirectory=/data/code/botbotbot.py -ExecStart=/data/code/botbotbot.py/env/bin/python -u -m botbotbot +ExecStart=/usr/bin/uv run python -u -m botbotbot [Install] WantedBy=multi-user.target diff --git a/botbotbot/__init__.py b/botbotbot/__init__.py index 0b21a27..46d82be 100644 --- a/botbotbot/__init__.py +++ b/botbotbot/__init__.py @@ -1,4 +1,3 @@ -import asyncio import logging import pickle import random @@ -7,6 +6,7 @@ import tomllib import discord from botbotbot.ai import AIBot +from botbotbot.tts import CambAI def main() -> None: @@ -38,6 +38,10 @@ def main() -> None: system_message=system_prompt, ) + cambai: CambAI | None = None + if isinstance(key := config.get("cambai_api_key"), str): + cambai = CambAI(key) + intents = discord.Intents.default() intents.members = True intents.message_content = True @@ -235,7 +239,7 @@ def main() -> None: logger.info("ERRE ALEA") @bot.listen("on_voice_state_update") - async def voice_random_nicks( + async def on_voice_state_update( member: discord.Member, before: discord.VoiceState, after: discord.VoiceState ) -> None: if before.channel is None and random.random() < 5 / 100: @@ -247,24 +251,29 @@ def main() -> None: logger.debug(after.channel) if after.channel: logger.debug(after.channel.members) - if ( - before.channel is None - and after.channel is not None - and random.random() < 5 / 100 - and bot not in after.channel.members - ): - logger.info(f"Voice connect from {member}") - source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg") - await asyncio.sleep(random.randrange(60)) + if ( + cambai is not None + and before.channel is None + and after.channel is not None + and bot not in after.channel.members + and bot.user + and member.id != bot.user.id + and random.random() < 5 / 100 + ): + logger.info("Generating tts") + script = random.choice( + [ + "Salut la jeunesse !", + f"Salut {member.display_name}, ça va bien ?", + "Allo ? À l'huile !", + ] + ) + source = await discord.FFmpegOpusAudio.from_probe(cambai.tts(script)) vo: discord.VoiceClient = await after.channel.connect() - await asyncio.sleep(random.randrange(10)) await vo.play(source, wait_finish=True) - - await asyncio.sleep(random.randrange(60)) await vo.disconnect() - logger.info("Voice disconnect") @bot.slash_command( name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI" diff --git a/botbotbot/tts.py b/botbotbot/tts.py new file mode 100644 index 0000000..dc39a1f --- /dev/null +++ b/botbotbot/tts.py @@ -0,0 +1,89 @@ +import hashlib +import logging +import pathlib +import time +from typing import Any + +import requests + +logger = logging.getLogger(__name__) + + +class CambAI: + base_url = "https://client.camb.ai/apis" + cambai_root = pathlib.Path("cambai") + + def __init__(self, apikey: str) -> None: + self.apikey = apikey + + if not self.cambai_root.is_dir(): + self.cambai_root.mkdir() + + @property + def headers(self) -> dict[str, str]: + return {"x-api-key": self.apikey} + + def tts(self, text: str) -> Any: + if (path := self.get_path(text)).exists(): + return path + + task_id = self.gen_task(text) + run_id = self.get_runid(task_id) + return self.get_iostream(text, run_id) + + def gen_task(self, text: str) -> str | None: + tts_payload = { + "text": text, + "voice_id": 20299, + "language": 1, + "age": 30, + "gender": 1, + } + + res = requests.post( + f"{self.base_url}/tts", json=tts_payload, headers=self.headers + ) + task_id = res.json().get("task_id") + if not isinstance(task_id, str): + logger.error(f"Got response {res.json()}") + return None + + return task_id + + def get_runid(self, task_id: str | None) -> int | None: + if task_id is None: + return None + + status = "PENDING" + while status == "PENDING": + res = requests.get(f"{self.base_url}/tts/{task_id}", headers=self.headers) + status = res.json()["status"] + print(f"Polling: {status}") + time.sleep(1.5) + + run_id = res.json().get("run_id") + if not isinstance(run_id, int): + return None + + return run_id + + def get_iostream(self, text: str, run_id: int | None) -> pathlib.Path | None: + if run_id is None: + return None + + res = requests.get( + f"{self.base_url}/tts-result/{run_id}", headers=self.headers, stream=True + ) + + path = self.get_path(text) + with open(path, "wb") as f: + for chunk in res.iter_content(chunk_size=1024): + f.write(chunk) + + return path + + def get_name(self, text: str) -> str: + return hashlib.sha256(text.encode()).hexdigest() + + def get_path(self, text: str) -> pathlib.Path: + return self.cambai_root.joinpath(f"{self.get_name(text)}.wav") diff --git a/pyproject.toml b/pyproject.toml index 704c48d..73588be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "mistralai>=1.6.0", "py-cord>=2.6.1", "pynacl>=1.5.0", + "requests>=2.32.3", ] [project.scripts] diff --git a/uv.lock b/uv.lock index 96e6860..8584e4e 100644 --- a/uv.lock +++ b/uv.lock @@ -136,6 +136,7 @@ dependencies = [ { name = "mistralai" }, { name = "py-cord" }, { name = "pynacl" }, + { name = "requests" }, ] [package.dev-dependencies] @@ -152,6 +153,7 @@ requires-dist = [ { name = "mistralai", specifier = ">=1.6.0" }, { name = "py-cord", specifier = ">=2.6.1" }, { name = "pynacl", specifier = ">=1.5.0" }, + { name = "requests", specifier = ">=2.32.3" }, ] [package.metadata.requires-dev] @@ -202,6 +204,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 }, ] +[[package]] +name = "charset-normalizer" +version = "3.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/16/b0/572805e227f01586461c80e0fd25d65a2115599cc9dad142fee4b747c357/charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3", size = 123188 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/94/ce8e6f63d18049672c76d07d119304e1e2d7c6098f0841b51c666e9f44a0/charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda", size = 195698 }, + { url = "https://files.pythonhosted.org/packages/24/2e/dfdd9770664aae179a96561cc6952ff08f9a8cd09a908f259a9dfa063568/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313", size = 140162 }, + { url = "https://files.pythonhosted.org/packages/24/4e/f646b9093cff8fc86f2d60af2de4dc17c759de9d554f130b140ea4738ca6/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9", size = 150263 }, + { url = "https://files.pythonhosted.org/packages/5e/67/2937f8d548c3ef6e2f9aab0f6e21001056f692d43282b165e7c56023e6dd/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b", size = 142966 }, + { url = "https://files.pythonhosted.org/packages/52/ed/b7f4f07de100bdb95c1756d3a4d17b90c1a3c53715c1a476f8738058e0fa/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11", size = 144992 }, + { url = "https://files.pythonhosted.org/packages/96/2c/d49710a6dbcd3776265f4c923bb73ebe83933dfbaa841c5da850fe0fd20b/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f", size = 147162 }, + { url = "https://files.pythonhosted.org/packages/b4/41/35ff1f9a6bd380303dea55e44c4933b4cc3c4850988927d4082ada230273/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd", size = 140972 }, + { url = "https://files.pythonhosted.org/packages/fb/43/c6a0b685fe6910d08ba971f62cd9c3e862a85770395ba5d9cad4fede33ab/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2", size = 149095 }, + { url = "https://files.pythonhosted.org/packages/4c/ff/a9a504662452e2d2878512115638966e75633519ec11f25fca3d2049a94a/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886", size = 152668 }, + { url = "https://files.pythonhosted.org/packages/6c/71/189996b6d9a4b932564701628af5cee6716733e9165af1d5e1b285c530ed/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601", size = 150073 }, + { url = "https://files.pythonhosted.org/packages/e4/93/946a86ce20790e11312c87c75ba68d5f6ad2208cfb52b2d6a2c32840d922/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd", size = 145732 }, + { url = "https://files.pythonhosted.org/packages/cd/e5/131d2fb1b0dddafc37be4f3a2fa79aa4c037368be9423061dccadfd90091/charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407", size = 95391 }, + { url = "https://files.pythonhosted.org/packages/27/f2/4f9a69cc7712b9b5ad8fdb87039fd89abba997ad5cbe690d1835d40405b0/charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971", size = 102702 }, + { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, +] + [[package]] name = "distlib" version = "0.3.9" @@ -584,6 +608,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 }, ] +[[package]] +name = "requests" +version = "2.32.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, +] + [[package]] name = "ruff" version = "0.11.2" @@ -648,6 +687,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, ] +[[package]] +name = "urllib3" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, +] + [[package]] name = "virtualenv" version = "20.29.3"