From be0567ff0fc87a3fa684a9a00ffe5f41c3b5cc25 Mon Sep 17 00:00:00 2001
From: "Edgar P. Burkhart" <git@edgarpierre.fr>
Date: Sat, 22 Mar 2025 20:52:22 +0100
Subject: [PATCH] Update .gitignore, add requests dependency, and implement
 CambAI TTS integration

---
 .gitignore              |  5 ++-
 .pre-commit-config.yaml |  2 +-
 botbotbot.service       |  2 +-
 botbotbot/__init__.py   | 39 +++++++++++-------
 botbotbot/tts.py        | 89 +++++++++++++++++++++++++++++++++++++++++
 pyproject.toml          |  1 +
 uv.lock                 | 48 ++++++++++++++++++++++
 7 files changed, 167 insertions(+), 19 deletions(-)
 create mode 100644 botbotbot/tts.py

diff --git a/.gitignore b/.gitignore
index 3931c59..a3ee3aa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,5 +10,6 @@ wheels/
 .venv
 
 # BotBotBot
-config.toml
-wordlist.pickle
+/config.toml
+/wordlist.pickle
+/cambai
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d59bbb6..3d48248 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,4 +19,4 @@ repos:
     rev: v1.15.0
     hooks:
       - id: mypy
-        additional_dependencies: [mistralai, py-cord]
+        additional_dependencies: [mistralai, py-cord, types-requests]
diff --git a/botbotbot.service b/botbotbot.service
index 343c6a3..abcfd6c 100644
--- a/botbotbot.service
+++ b/botbotbot.service
@@ -7,7 +7,7 @@ After=multi-user.target
 Type=simple
 User=edpibu
 WorkingDirectory=/data/code/botbotbot.py
-ExecStart=/data/code/botbotbot.py/env/bin/python -u -m botbotbot
+ExecStart=/usr/bin/uv run python -u -m botbotbot
 
 [Install]
 WantedBy=multi-user.target
diff --git a/botbotbot/__init__.py b/botbotbot/__init__.py
index 0b21a27..46d82be 100644
--- a/botbotbot/__init__.py
+++ b/botbotbot/__init__.py
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 import pickle
 import random
@@ -7,6 +6,7 @@ import tomllib
 import discord
 
 from botbotbot.ai import AIBot
+from botbotbot.tts import CambAI
 
 
 def main() -> None:
@@ -38,6 +38,10 @@ def main() -> None:
             system_message=system_prompt,
         )
 
+    cambai: CambAI | None = None
+    if isinstance(key := config.get("cambai_api_key"), str):
+        cambai = CambAI(key)
+
     intents = discord.Intents.default()
     intents.members = True
     intents.message_content = True
@@ -235,7 +239,7 @@ def main() -> None:
             logger.info("ERRE ALEA")
 
     @bot.listen("on_voice_state_update")
-    async def voice_random_nicks(
+    async def on_voice_state_update(
         member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
     ) -> None:
         if before.channel is None and random.random() < 5 / 100:
@@ -247,24 +251,29 @@ def main() -> None:
         logger.debug(after.channel)
         if after.channel:
             logger.debug(after.channel.members)
-        if (
-            before.channel is None
-            and after.channel is not None
-            and random.random() < 5 / 100
-            and bot not in after.channel.members
-        ):
-            logger.info(f"Voice connect from {member}")
-            source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg")
 
-            await asyncio.sleep(random.randrange(60))
+        if (
+            cambai is not None
+            and before.channel is None
+            and after.channel is not None
+            and bot not in after.channel.members
+            and bot.user
+            and member.id != bot.user.id
+            and random.random() < 5 / 100
+        ):
+            logger.info("Generating tts")
+            script = random.choice(
+                [
+                    "Salut la jeunesse !",
+                    f"Salut {member.display_name}, ça va bien ?",
+                    "Allo ? À l'huile !",
+                ]
+            )
+            source = await discord.FFmpegOpusAudio.from_probe(cambai.tts(script))
             vo: discord.VoiceClient = await after.channel.connect()
 
-            await asyncio.sleep(random.randrange(10))
             await vo.play(source, wait_finish=True)
-
-            await asyncio.sleep(random.randrange(60))
             await vo.disconnect()
-            logger.info("Voice disconnect")
 
     @bot.slash_command(
         name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI"
diff --git a/botbotbot/tts.py b/botbotbot/tts.py
new file mode 100644
index 0000000..dc39a1f
--- /dev/null
+++ b/botbotbot/tts.py
@@ -0,0 +1,89 @@
+import hashlib
+import logging
+import pathlib
+import time
+from typing import Any
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+class CambAI:
+    base_url = "https://client.camb.ai/apis"
+    cambai_root = pathlib.Path("cambai")
+
+    def __init__(self, apikey: str) -> None:
+        self.apikey = apikey
+
+        if not self.cambai_root.is_dir():
+            self.cambai_root.mkdir()
+
+    @property
+    def headers(self) -> dict[str, str]:
+        return {"x-api-key": self.apikey}
+
+    def tts(self, text: str) -> Any:
+        if (path := self.get_path(text)).exists():
+            return path
+
+        task_id = self.gen_task(text)
+        run_id = self.get_runid(task_id)
+        return self.get_iostream(text, run_id)
+
+    def gen_task(self, text: str) -> str | None:
+        tts_payload = {
+            "text": text,
+            "voice_id": 20299,
+            "language": 1,
+            "age": 30,
+            "gender": 1,
+        }
+
+        res = requests.post(
+            f"{self.base_url}/tts", json=tts_payload, headers=self.headers
+        )
+        task_id = res.json().get("task_id")
+        if not isinstance(task_id, str):
+            logger.error(f"Got response {res.json()}")
+            return None
+
+        return task_id
+
+    def get_runid(self, task_id: str | None) -> int | None:
+        if task_id is None:
+            return None
+
+        status = "PENDING"
+        while status == "PENDING":
+            res = requests.get(f"{self.base_url}/tts/{task_id}", headers=self.headers)
+            status = res.json()["status"]
+            print(f"Polling: {status}")
+            time.sleep(1.5)
+
+        run_id = res.json().get("run_id")
+        if not isinstance(run_id, int):
+            return None
+
+        return run_id
+
+    def get_iostream(self, text: str, run_id: int | None) -> pathlib.Path | None:
+        if run_id is None:
+            return None
+
+        res = requests.get(
+            f"{self.base_url}/tts-result/{run_id}", headers=self.headers, stream=True
+        )
+
+        path = self.get_path(text)
+        with open(path, "wb") as f:
+            for chunk in res.iter_content(chunk_size=1024):
+                f.write(chunk)
+
+        return path
+
+    def get_name(self, text: str) -> str:
+        return hashlib.sha256(text.encode()).hexdigest()
+
+    def get_path(self, text: str) -> pathlib.Path:
+        return self.cambai_root.joinpath(f"{self.get_name(text)}.wav")
diff --git a/pyproject.toml b/pyproject.toml
index 704c48d..73588be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ dependencies = [
     "mistralai>=1.6.0",
     "py-cord>=2.6.1",
     "pynacl>=1.5.0",
+    "requests>=2.32.3",
 ]
 
 [project.scripts]
diff --git a/uv.lock b/uv.lock
index 96e6860..8584e4e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -136,6 +136,7 @@ dependencies = [
     { name = "mistralai" },
     { name = "py-cord" },
     { name = "pynacl" },
+    { name = "requests" },
 ]
 
 [package.dev-dependencies]
@@ -152,6 +153,7 @@ requires-dist = [
     { name = "mistralai", specifier = ">=1.6.0" },
     { name = "py-cord", specifier = ">=2.6.1" },
     { name = "pynacl", specifier = ">=1.5.0" },
+    { name = "requests", specifier = ">=2.32.3" },
 ]
 
 [package.metadata.requires-dev]
@@ -202,6 +204,28 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 },
 ]
 
+[[package]]
+name = "charset-normalizer"
+version = "3.4.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/16/b0/572805e227f01586461c80e0fd25d65a2115599cc9dad142fee4b747c357/charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3", size = 123188 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/38/94/ce8e6f63d18049672c76d07d119304e1e2d7c6098f0841b51c666e9f44a0/charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda", size = 195698 },
+    { url = "https://files.pythonhosted.org/packages/24/2e/dfdd9770664aae179a96561cc6952ff08f9a8cd09a908f259a9dfa063568/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313", size = 140162 },
+    { url = "https://files.pythonhosted.org/packages/24/4e/f646b9093cff8fc86f2d60af2de4dc17c759de9d554f130b140ea4738ca6/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9", size = 150263 },
+    { url = "https://files.pythonhosted.org/packages/5e/67/2937f8d548c3ef6e2f9aab0f6e21001056f692d43282b165e7c56023e6dd/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b", size = 142966 },
+    { url = "https://files.pythonhosted.org/packages/52/ed/b7f4f07de100bdb95c1756d3a4d17b90c1a3c53715c1a476f8738058e0fa/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11", size = 144992 },
+    { url = "https://files.pythonhosted.org/packages/96/2c/d49710a6dbcd3776265f4c923bb73ebe83933dfbaa841c5da850fe0fd20b/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f", size = 147162 },
+    { url = "https://files.pythonhosted.org/packages/b4/41/35ff1f9a6bd380303dea55e44c4933b4cc3c4850988927d4082ada230273/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd", size = 140972 },
+    { url = "https://files.pythonhosted.org/packages/fb/43/c6a0b685fe6910d08ba971f62cd9c3e862a85770395ba5d9cad4fede33ab/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2", size = 149095 },
+    { url = "https://files.pythonhosted.org/packages/4c/ff/a9a504662452e2d2878512115638966e75633519ec11f25fca3d2049a94a/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886", size = 152668 },
+    { url = "https://files.pythonhosted.org/packages/6c/71/189996b6d9a4b932564701628af5cee6716733e9165af1d5e1b285c530ed/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601", size = 150073 },
+    { url = "https://files.pythonhosted.org/packages/e4/93/946a86ce20790e11312c87c75ba68d5f6ad2208cfb52b2d6a2c32840d922/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd", size = 145732 },
+    { url = "https://files.pythonhosted.org/packages/cd/e5/131d2fb1b0dddafc37be4f3a2fa79aa4c037368be9423061dccadfd90091/charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407", size = 95391 },
+    { url = "https://files.pythonhosted.org/packages/27/f2/4f9a69cc7712b9b5ad8fdb87039fd89abba997ad5cbe690d1835d40405b0/charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971", size = 102702 },
+    { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
+]
+
 [[package]]
 name = "distlib"
 version = "0.3.9"
@@ -584,6 +608,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
 ]
 
+[[package]]
+name = "requests"
+version = "2.32.3"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "certifi" },
+    { name = "charset-normalizer" },
+    { name = "idna" },
+    { name = "urllib3" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
+]
+
 [[package]]
 name = "ruff"
 version = "0.11.2"
@@ -648,6 +687,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 },
 ]
 
+[[package]]
+name = "urllib3"
+version = "2.3.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
+]
+
 [[package]]
 name = "virtualenv"
 version = "20.29.3"