Refactor AIBot initialization and response handling; add type hints and update pre-commit configuration

This commit is contained in:
Edgar P. Burkhart 2025-03-22 18:05:42 +01:00
parent 7b010bfd0f
commit 761111bb07
Signed by: edpibu
GPG key ID: 9833D3C5A25BD227
6 changed files with 85 additions and 54 deletions

View file

@ -19,3 +19,4 @@ repos:
rev: v1.15.0 rev: v1.15.0
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: [mistralai, py-cord]

0
botbotbot/__init__.py Normal file
View file

View file

@ -6,11 +6,11 @@ import tomllib
import discord import discord
from .ai import AIBot from botbotbot.ai import AIBot
description = """BotBotBot""" description = """BotBotBot"""
logger = logging.getLogger("botbotbot") logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
with open("config.toml", "rb") as config_file: with open("config.toml", "rb") as config_file:
@ -26,11 +26,15 @@ Tu dois faire un commentaire pertinent en lien avec ce qui te sera dit.
Ta réponse doit être très courte. Ta réponse doit être très courte.
Ta réponse doit être une seule phrase. Ta réponse doit être une seule phrase.
TA RÉPONSE DOIT ÊTRE EN FRANÇAIS !!!""" TA RÉPONSE DOIT ÊTRE EN FRANÇAIS !!!"""
aibot = AIBot(
config.get("mistral_api_key"), aibot: AIBot | None = None
model="mistral-large-latest",
system_message=system_prompt, if isinstance(key := config.get("mistral_api_key"), str):
) aibot = AIBot(
key,
model="mistral-large-latest",
system_message=system_prompt,
)
intents = discord.Intents.default() intents = discord.Intents.default()
intents.members = True intents.members = True
@ -45,11 +49,11 @@ shuffle_tasks = set()
@bot.listen("on_ready") @bot.listen("on_ready")
async def on_ready(): async def on_ready() -> None:
logger.info(f"We have logged in as {bot.user}") logger.info(f"We have logged in as {bot.user}")
async def reply(message): async def reply(message: discord.Message) -> None:
logger.info(f"Reply to {message.author}") logger.info(f"Reply to {message.author}")
mention = random.choices( mention = random.choices(
[f"<@{message.author.id}>", "@everyone", "@here"], weights=(98, 1, 1) [f"<@{message.author.id}>", "@everyone", "@here"], weights=(98, 1, 1)
@ -61,11 +65,11 @@ async def reply(message):
) )
) )
if random.random() < 0.1: if isinstance(message.channel, discord.TextChannel) and random.random() < 0.1:
await send_as_webhook( await send_as_webhook(
message.channel, message.channel,
message.author.display_name, message.author.display_name,
message.author.avatar.url, message.author.avatar.url if message.author.avatar else None,
content, content,
) )
else: else:
@ -74,12 +78,19 @@ async def reply(message):
await fct(content) await fct(content)
async def ai_reply(message): async def ai_reply(message: discord.Message) -> None:
if aibot is None:
return
logger.info(f"AI Reply to {message.author}") logger.info(f"AI Reply to {message.author}")
prompt = message.clean_content prompt = message.clean_content
if prompt == "" and message.embeds: if prompt == "" and message.embeds and message.embeds[0].description:
prompt = message.embeds[0].description prompt = message.embeds[0].description
answer = aibot.answer(prompt) answer = aibot.answer(prompt)
if not isinstance(answer, str):
return
if len(answer) > 2000: if len(answer) > 2000:
embed = discord.Embed( embed = discord.Embed(
description=answer, description=answer,
@ -142,14 +153,14 @@ async def rando_shuffle(message: discord.Message) -> None:
await try_shuffle(message.guild) await try_shuffle(message.guild)
def save_wordlist(): def save_wordlist() -> None:
logger.info("Saving updated wordlist") logger.info("Saving updated wordlist")
with open("wordlist.pickle", "wb") as word_file: with open("wordlist.pickle", "wb") as word_file:
pickle.dump(word_list, word_file) pickle.dump(word_list, word_file)
@bot.slash_command(name="bibl", guild_ids=guild_ids, description="Ajouter une phrase") @bot.slash_command(name="bibl", guild_ids=guild_ids, description="Ajouter une phrase")
async def bibl(ctx, phrase): async def bibl(ctx: discord.ApplicationContext, phrase: str) -> None:
logger.info(f"BIBL {ctx.author} {phrase}") logger.info(f"BIBL {ctx.author} {phrase}")
word_list.append(phrase) word_list.append(phrase)
embed = discord.Embed( embed = discord.Embed(
@ -161,7 +172,7 @@ async def bibl(ctx, phrase):
@bot.slash_command(name="tabl", guild_ids=guild_ids, description="Lister les phrases") @bot.slash_command(name="tabl", guild_ids=guild_ids, description="Lister les phrases")
async def tabl(ctx): async def tabl(ctx: discord.ApplicationContext) -> None:
logger.info(f"TABL {ctx.author}") logger.info(f"TABL {ctx.author}")
embed = discord.Embed( embed = discord.Embed(
title="TABL", description="\n".join(word_list), color=discord.Colour.green() title="TABL", description="\n".join(word_list), color=discord.Colour.green()
@ -170,7 +181,7 @@ async def tabl(ctx):
@bot.slash_command(name="enle", guild_ids=guild_ids, description="Enlever une phrase") @bot.slash_command(name="enle", guild_ids=guild_ids, description="Enlever une phrase")
async def enle(ctx, phrase): async def enle(ctx: discord.ApplicationContext, phrase: str) -> None:
logger.info(f"ENLE {ctx.author} {phrase}") logger.info(f"ENLE {ctx.author} {phrase}")
try: try:
word_list.remove(phrase) word_list.remove(phrase)
@ -189,7 +200,7 @@ async def enle(ctx, phrase):
logger.info("FIN ENLE") logger.info("FIN ENLE")
async def try_shuffle(guild): async def try_shuffle(guild: discord.Guild) -> bool:
if guild.id in shuffle_tasks: if guild.id in shuffle_tasks:
return False return False
@ -199,10 +210,11 @@ async def try_shuffle(guild):
return True return True
async def shuffle_nicks(guild): async def shuffle_nicks(guild: discord.Guild) -> None:
logger.info("Shuffle") logger.info("Shuffle")
members = guild.members members = guild.members
members.remove(guild.owner) if guild.owner:
members.remove(guild.owner)
nicks = [member.nick for member in members] nicks = [member.nick for member in members]
@ -214,7 +226,7 @@ async def shuffle_nicks(guild):
@bot.slash_command(name="alea", guild_ids=guild_ids, description="Modifier les pseudos") @bot.slash_command(name="alea", guild_ids=guild_ids, description="Modifier les pseudos")
async def alea(ctx): async def alea(ctx: discord.ApplicationContext) -> None:
logger.info(f"ALEA {ctx.author}") logger.info(f"ALEA {ctx.author}")
await ctx.defer() await ctx.defer()
if await try_shuffle(ctx.guild): if await try_shuffle(ctx.guild):
@ -228,7 +240,9 @@ async def alea(ctx):
@bot.listen("on_voice_state_update") @bot.listen("on_voice_state_update")
async def voice_random_nicks(member, before, after): async def voice_random_nicks(
member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
) -> None:
if before.channel is None and random.random() < 5 / 100: if before.channel is None and random.random() < 5 / 100:
logger.info(f"Voice shuffle from {member}") logger.info(f"Voice shuffle from {member}")
await try_shuffle(member.guild) await try_shuffle(member.guild)
@ -248,7 +262,7 @@ async def voice_random_nicks(member, before, after):
source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg") source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg")
await asyncio.sleep(random.randrange(60)) await asyncio.sleep(random.randrange(60))
vo = await after.channel.connect() vo: discord.VoiceClient = await after.channel.connect()
await asyncio.sleep(random.randrange(10)) await asyncio.sleep(random.randrange(10))
await vo.play(source, wait_finish=True) await vo.play(source, wait_finish=True)
@ -261,7 +275,9 @@ async def voice_random_nicks(member, before, after):
@bot.slash_command( @bot.slash_command(
name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI" name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI"
) )
async def indu(ctx, prompt): async def indu(ctx: discord.ApplicationContext, prompt: str) -> None:
if aibot is None:
return
logger.info(f"INDU {ctx.author} {prompt}") logger.info(f"INDU {ctx.author} {prompt}")
await ctx.defer() await ctx.defer()
res_stream = await aibot.get_response_stream(prompt) res_stream = await aibot.get_response_stream(prompt)
@ -279,7 +295,7 @@ async def indu(ctx, prompt):
embed.description += chunk.data.choices[0].delta.content embed.description += chunk.data.choices[0].delta.content
await message.edit(embed=embed) await message.edit(embed=embed)
embed.color = None embed.colour = None
await message.edit(embed=embed) await message.edit(embed=embed)
logger.info("FIN INDU") logger.info("FIN INDU")
@ -287,7 +303,7 @@ async def indu(ctx, prompt):
@bot.slash_command( @bot.slash_command(
name="chan", guild_ids=guild_ids, description="Donner de nouveaux pseudos" name="chan", guild_ids=guild_ids, description="Donner de nouveaux pseudos"
) )
async def chan(ctx, file: discord.Attachment): async def chan(ctx: discord.ApplicationContext, file: discord.Attachment) -> None:
logger.info(f"CHAN {ctx.author}") logger.info(f"CHAN {ctx.author}")
await ctx.defer() await ctx.defer()
@ -295,12 +311,12 @@ async def chan(ctx, file: discord.Attachment):
members.remove(ctx.guild.owner) members.remove(ctx.guild.owner)
nicks = (await file.read()).decode().splitlines() nicks = (await file.read()).decode().splitlines()
if len(nicks) != len(members): if len(nicks) < len(members):
embed = discord.Embed(title="ERRE CHAN", color=discord.Colour.red()) embed = discord.Embed(title="ERRE CHAN", color=discord.Colour.red())
await ctx.respond(embed=embed) await ctx.respond(embed=embed)
return return
random.shuffle(nicks) nicks = random.choices(nicks, k=len(members))
for member, nick in zip(members, nicks): for member, nick in zip(members, nicks):
logger.info(member, nick) logger.info(member, nick)
await member.edit(nick=nick) await member.edit(nick=nick)
@ -315,19 +331,16 @@ async def chan(ctx, file: discord.Attachment):
async def send_as_webhook( async def send_as_webhook(
channel: discord.TextChannel, channel: discord.TextChannel,
name: str, name: str,
avatar_url: str, avatar_url: str | None,
content: str, content: str,
embed: discord.Embed = None, ) -> None:
):
webhooks = await channel.webhooks() webhooks = await channel.webhooks()
webhook = discord.utils.get(webhooks, name="BotbotbotHook") webhook = discord.utils.get(webhooks, name="BotbotbotHook")
if webhook is None: if webhook is None:
webhook = await channel.create_webhook(name="BotbotbotHook") webhook = await channel.create_webhook(name="BotbotbotHook")
await webhook.send( await webhook.send(content=content, username=name, avatar_url=avatar_url)
content=content, username=name, avatar_url=avatar_url, embed=embed
)
bot.run(config.get("token")) bot.run(config.get("token"))

View file

@ -1,34 +1,52 @@
from mistralai import Mistral from typing import Any, Coroutine
import mistralai
from mistralai.utils import eventstreaming
class AIBot: class AIBot:
def __init__( def __init__(
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None self,
): api_key: str,
self.client = Mistral(api_key=api_key) model: str = "open-mistral-7b",
max_tokens: int | None = None,
system_message: str | None = None,
) -> None:
self.client = mistralai.Mistral(api_key=api_key)
self.model = model self.model = model
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.system_message = system_message self.system_message = system_message
def get_responses(self, message): def get_responses(self, message: str) -> mistralai.ChatCompletionResponse:
return self.client.chat.complete( return self.client.chat.complete(
model=self.model, model=self.model,
messages=self.base_message + [{"role": "user", "content": message}], messages=self.get_message(message),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
) )
def answer(self, message): def answer(self, message: str) -> str | None:
return self.get_responses(message).choices[0].message.content res = self.get_responses(message).choices
if not res:
return None
def get_response_stream(self, message): return str(res[0].message.content)
def get_response_stream(
self, message: str
) -> Coroutine[
Any, Any, eventstreaming.EventStreamAsync[mistralai.CompletionEvent]
]:
return self.client.chat.stream_async( return self.client.chat.stream_async(
model=self.model, model=self.model,
messages=self.base_message + [{"role": "user", "content": message}], messages=self.get_message(message),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
) )
@property def get_message(self, content: str) -> list[Any]:
def base_message(self):
if self.system_message: if self.system_message:
return [{"role": "system", "content": self.system_message}] return [
return [] mistralai.SystemMessage(content=self.system_message),
mistralai.UserMessage(content=content),
]
return [mistralai.UserMessage(content=content)]

View file

@ -1,6 +0,0 @@
def main():
print("Hello from botbotbot-py!")
if __name__ == "__main__":
main()

View file

@ -18,3 +18,8 @@ dev = [
"pre-commit>=4.2.0", "pre-commit>=4.2.0",
"ruff>=0.11.2", "ruff>=0.11.2",
] ]
[tool.mypy]
strict = true
disallow_untyped_calls = false
disallow_untyped_decorators = false