Update .gitignore, add requests dependency, and implement CambAI TTS integration
This commit is contained in:
parent
bc9d5a8943
commit
be0567ff0f
7 changed files with 167 additions and 19 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -10,5 +10,6 @@ wheels/
|
|||
.venv
|
||||
|
||||
# BotBotBot
|
||||
config.toml
|
||||
wordlist.pickle
|
||||
/config.toml
|
||||
/wordlist.pickle
|
||||
/cambai
|
||||
|
|
|
@ -19,4 +19,4 @@ repos:
|
|||
rev: v1.15.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies: [mistralai, py-cord]
|
||||
additional_dependencies: [mistralai, py-cord, types-requests]
|
||||
|
|
|
@ -7,7 +7,7 @@ After=multi-user.target
|
|||
Type=simple
|
||||
User=edpibu
|
||||
WorkingDirectory=/data/code/botbotbot.py
|
||||
ExecStart=/data/code/botbotbot.py/env/bin/python -u -m botbotbot
|
||||
ExecStart=/usr/bin/uv run python -u -m botbotbot
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
|
@ -7,6 +6,7 @@ import tomllib
|
|||
import discord
|
||||
|
||||
from botbotbot.ai import AIBot
|
||||
from botbotbot.tts import CambAI
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
@ -38,6 +38,10 @@ def main() -> None:
|
|||
system_message=system_prompt,
|
||||
)
|
||||
|
||||
cambai: CambAI | None = None
|
||||
if isinstance(key := config.get("cambai_api_key"), str):
|
||||
cambai = CambAI(key)
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.members = True
|
||||
intents.message_content = True
|
||||
|
@ -235,7 +239,7 @@ def main() -> None:
|
|||
logger.info("ERRE ALEA")
|
||||
|
||||
@bot.listen("on_voice_state_update")
|
||||
async def voice_random_nicks(
|
||||
async def on_voice_state_update(
|
||||
member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
|
||||
) -> None:
|
||||
if before.channel is None and random.random() < 5 / 100:
|
||||
|
@ -247,24 +251,29 @@ def main() -> None:
|
|||
logger.debug(after.channel)
|
||||
if after.channel:
|
||||
logger.debug(after.channel.members)
|
||||
if (
|
||||
before.channel is None
|
||||
and after.channel is not None
|
||||
and random.random() < 5 / 100
|
||||
and bot not in after.channel.members
|
||||
):
|
||||
logger.info(f"Voice connect from {member}")
|
||||
source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg")
|
||||
|
||||
await asyncio.sleep(random.randrange(60))
|
||||
if (
|
||||
cambai is not None
|
||||
and before.channel is None
|
||||
and after.channel is not None
|
||||
and bot not in after.channel.members
|
||||
and bot.user
|
||||
and member.id != bot.user.id
|
||||
and random.random() < 5 / 100
|
||||
):
|
||||
logger.info("Generating tts")
|
||||
script = random.choice(
|
||||
[
|
||||
"Salut la jeunesse !",
|
||||
f"Salut {member.display_name}, ça va bien ?",
|
||||
"Allo ? À l'huile !",
|
||||
]
|
||||
)
|
||||
source = await discord.FFmpegOpusAudio.from_probe(cambai.tts(script))
|
||||
vo: discord.VoiceClient = await after.channel.connect()
|
||||
|
||||
await asyncio.sleep(random.randrange(10))
|
||||
await vo.play(source, wait_finish=True)
|
||||
|
||||
await asyncio.sleep(random.randrange(60))
|
||||
await vo.disconnect()
|
||||
logger.info("Voice disconnect")
|
||||
|
||||
@bot.slash_command(
|
||||
name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI"
|
||||
|
|
89
botbotbot/tts.py
Normal file
89
botbotbot/tts.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
import hashlib
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CambAI:
|
||||
base_url = "https://client.camb.ai/apis"
|
||||
cambai_root = pathlib.Path("cambai")
|
||||
|
||||
def __init__(self, apikey: str) -> None:
|
||||
self.apikey = apikey
|
||||
|
||||
if not self.cambai_root.is_dir():
|
||||
self.cambai_root.mkdir()
|
||||
|
||||
@property
|
||||
def headers(self) -> dict[str, str]:
|
||||
return {"x-api-key": self.apikey}
|
||||
|
||||
def tts(self, text: str) -> Any:
|
||||
if (path := self.get_path(text)).exists():
|
||||
return path
|
||||
|
||||
task_id = self.gen_task(text)
|
||||
run_id = self.get_runid(task_id)
|
||||
return self.get_iostream(text, run_id)
|
||||
|
||||
def gen_task(self, text: str) -> str | None:
|
||||
tts_payload = {
|
||||
"text": text,
|
||||
"voice_id": 20299,
|
||||
"language": 1,
|
||||
"age": 30,
|
||||
"gender": 1,
|
||||
}
|
||||
|
||||
res = requests.post(
|
||||
f"{self.base_url}/tts", json=tts_payload, headers=self.headers
|
||||
)
|
||||
task_id = res.json().get("task_id")
|
||||
if not isinstance(task_id, str):
|
||||
logger.error(f"Got response {res.json()}")
|
||||
return None
|
||||
|
||||
return task_id
|
||||
|
||||
def get_runid(self, task_id: str | None) -> int | None:
|
||||
if task_id is None:
|
||||
return None
|
||||
|
||||
status = "PENDING"
|
||||
while status == "PENDING":
|
||||
res = requests.get(f"{self.base_url}/tts/{task_id}", headers=self.headers)
|
||||
status = res.json()["status"]
|
||||
print(f"Polling: {status}")
|
||||
time.sleep(1.5)
|
||||
|
||||
run_id = res.json().get("run_id")
|
||||
if not isinstance(run_id, int):
|
||||
return None
|
||||
|
||||
return run_id
|
||||
|
||||
def get_iostream(self, text: str, run_id: int | None) -> pathlib.Path | None:
|
||||
if run_id is None:
|
||||
return None
|
||||
|
||||
res = requests.get(
|
||||
f"{self.base_url}/tts-result/{run_id}", headers=self.headers, stream=True
|
||||
)
|
||||
|
||||
path = self.get_path(text)
|
||||
with open(path, "wb") as f:
|
||||
for chunk in res.iter_content(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
|
||||
return path
|
||||
|
||||
def get_name(self, text: str) -> str:
|
||||
return hashlib.sha256(text.encode()).hexdigest()
|
||||
|
||||
def get_path(self, text: str) -> pathlib.Path:
|
||||
return self.cambai_root.joinpath(f"{self.get_name(text)}.wav")
|
|
@ -9,6 +9,7 @@ dependencies = [
|
|||
"mistralai>=1.6.0",
|
||||
"py-cord>=2.6.1",
|
||||
"pynacl>=1.5.0",
|
||||
"requests>=2.32.3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
48
uv.lock
generated
48
uv.lock
generated
|
@ -136,6 +136,7 @@ dependencies = [
|
|||
{ name = "mistralai" },
|
||||
{ name = "py-cord" },
|
||||
{ name = "pynacl" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
|
@ -152,6 +153,7 @@ requires-dist = [
|
|||
{ name = "mistralai", specifier = ">=1.6.0" },
|
||||
{ name = "py-cord", specifier = ">=2.6.1" },
|
||||
{ name = "pynacl", specifier = ">=1.5.0" },
|
||||
{ name = "requests", specifier = ">=2.32.3" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
|
@ -202,6 +204,28 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/16/b0/572805e227f01586461c80e0fd25d65a2115599cc9dad142fee4b747c357/charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3", size = 123188 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/38/94/ce8e6f63d18049672c76d07d119304e1e2d7c6098f0841b51c666e9f44a0/charset_normalizer-3.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda", size = 195698 },
|
||||
{ url = "https://files.pythonhosted.org/packages/24/2e/dfdd9770664aae179a96561cc6952ff08f9a8cd09a908f259a9dfa063568/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313", size = 140162 },
|
||||
{ url = "https://files.pythonhosted.org/packages/24/4e/f646b9093cff8fc86f2d60af2de4dc17c759de9d554f130b140ea4738ca6/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9", size = 150263 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/67/2937f8d548c3ef6e2f9aab0f6e21001056f692d43282b165e7c56023e6dd/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b", size = 142966 },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/ed/b7f4f07de100bdb95c1756d3a4d17b90c1a3c53715c1a476f8738058e0fa/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11", size = 144992 },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/2c/d49710a6dbcd3776265f4c923bb73ebe83933dfbaa841c5da850fe0fd20b/charset_normalizer-3.4.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f", size = 147162 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/41/35ff1f9a6bd380303dea55e44c4933b4cc3c4850988927d4082ada230273/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd", size = 140972 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/43/c6a0b685fe6910d08ba971f62cd9c3e862a85770395ba5d9cad4fede33ab/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2", size = 149095 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/ff/a9a504662452e2d2878512115638966e75633519ec11f25fca3d2049a94a/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886", size = 152668 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/71/189996b6d9a4b932564701628af5cee6716733e9165af1d5e1b285c530ed/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601", size = 150073 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e4/93/946a86ce20790e11312c87c75ba68d5f6ad2208cfb52b2d6a2c32840d922/charset_normalizer-3.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd", size = 145732 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/e5/131d2fb1b0dddafc37be4f3a2fa79aa4c037368be9423061dccadfd90091/charset_normalizer-3.4.1-cp313-cp313-win32.whl", hash = "sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407", size = 95391 },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/f2/4f9a69cc7712b9b5ad8fdb87039fd89abba997ad5cbe690d1835d40405b0/charset_normalizer-3.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971", size = 102702 },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distlib"
|
||||
version = "0.3.9"
|
||||
|
@ -584,6 +608,21 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "charset-normalizer" },
|
||||
{ name = "idna" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.11.2"
|
||||
|
@ -648,6 +687,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "20.29.3"
|
||||
|
|
Loading…
Add table
Reference in a new issue