diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9313325..d59bbb6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,3 +19,4 @@ repos: rev: v1.15.0 hooks: - id: mypy + additional_dependencies: [mistralai, py-cord] diff --git a/botbotbot/__init__.py b/botbotbot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/botbotbot/__main__.py b/botbotbot/__main__.py index 9537f9f..f9c60cd 100644 --- a/botbotbot/__main__.py +++ b/botbotbot/__main__.py @@ -6,11 +6,11 @@ import tomllib import discord -from .ai import AIBot +from botbotbot.ai import AIBot description = """BotBotBot""" -logger = logging.getLogger("botbotbot") +logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) with open("config.toml", "rb") as config_file: @@ -26,11 +26,15 @@ Tu dois faire un commentaire pertinent en lien avec ce qui te sera dit. Ta réponse doit être très courte. Ta réponse doit être une seule phrase. TA RÉPONSE DOIT ÊTRE EN FRANÇAIS !!!""" -aibot = AIBot( - config.get("mistral_api_key"), - model="mistral-large-latest", - system_message=system_prompt, -) + +aibot: AIBot | None = None + +if isinstance(key := config.get("mistral_api_key"), str): + aibot = AIBot( + key, + model="mistral-large-latest", + system_message=system_prompt, + ) intents = discord.Intents.default() intents.members = True @@ -45,11 +49,11 @@ shuffle_tasks = set() @bot.listen("on_ready") -async def on_ready(): +async def on_ready() -> None: logger.info(f"We have logged in as {bot.user}") -async def reply(message): +async def reply(message: discord.Message) -> None: logger.info(f"Reply to {message.author}") mention = random.choices( [f"<@{message.author.id}>", "@everyone", "@here"], weights=(98, 1, 1) @@ -61,11 +65,11 @@ async def reply(message): ) ) - if random.random() < 0.1: + if isinstance(message.channel, discord.TextChannel) and random.random() < 0.1: await send_as_webhook( message.channel, message.author.display_name, - message.author.avatar.url, + message.author.avatar.url if message.author.avatar else None, content, ) else: @@ -74,12 +78,19 @@ async def reply(message): await fct(content) -async def ai_reply(message): +async def ai_reply(message: discord.Message) -> None: + if aibot is None: + return + logger.info(f"AI Reply to {message.author}") prompt = message.clean_content - if prompt == "" and message.embeds: + if prompt == "" and message.embeds and message.embeds[0].description: prompt = message.embeds[0].description + answer = aibot.answer(prompt) + if not isinstance(answer, str): + return + if len(answer) > 2000: embed = discord.Embed( description=answer, @@ -142,14 +153,14 @@ async def rando_shuffle(message: discord.Message) -> None: await try_shuffle(message.guild) -def save_wordlist(): +def save_wordlist() -> None: logger.info("Saving updated wordlist") with open("wordlist.pickle", "wb") as word_file: pickle.dump(word_list, word_file) @bot.slash_command(name="bibl", guild_ids=guild_ids, description="Ajouter une phrase") -async def bibl(ctx, phrase): +async def bibl(ctx: discord.ApplicationContext, phrase: str) -> None: logger.info(f"BIBL {ctx.author} {phrase}") word_list.append(phrase) embed = discord.Embed( @@ -161,7 +172,7 @@ async def bibl(ctx, phrase): @bot.slash_command(name="tabl", guild_ids=guild_ids, description="Lister les phrases") -async def tabl(ctx): +async def tabl(ctx: discord.ApplicationContext) -> None: logger.info(f"TABL {ctx.author}") embed = discord.Embed( title="TABL", description="\n".join(word_list), color=discord.Colour.green() @@ -170,7 +181,7 @@ async def tabl(ctx): @bot.slash_command(name="enle", guild_ids=guild_ids, description="Enlever une phrase") -async def enle(ctx, phrase): +async def enle(ctx: discord.ApplicationContext, phrase: str) -> None: logger.info(f"ENLE {ctx.author} {phrase}") try: word_list.remove(phrase) @@ -189,7 +200,7 @@ async def enle(ctx, phrase): logger.info("FIN ENLE") -async def try_shuffle(guild): +async def try_shuffle(guild: discord.Guild) -> bool: if guild.id in shuffle_tasks: return False @@ -199,10 +210,11 @@ async def try_shuffle(guild): return True -async def shuffle_nicks(guild): +async def shuffle_nicks(guild: discord.Guild) -> None: logger.info("Shuffle") members = guild.members - members.remove(guild.owner) + if guild.owner: + members.remove(guild.owner) nicks = [member.nick for member in members] @@ -214,7 +226,7 @@ async def shuffle_nicks(guild): @bot.slash_command(name="alea", guild_ids=guild_ids, description="Modifier les pseudos") -async def alea(ctx): +async def alea(ctx: discord.ApplicationContext) -> None: logger.info(f"ALEA {ctx.author}") await ctx.defer() if await try_shuffle(ctx.guild): @@ -228,7 +240,9 @@ async def alea(ctx): @bot.listen("on_voice_state_update") -async def voice_random_nicks(member, before, after): +async def voice_random_nicks( + member: discord.Member, before: discord.VoiceState, after: discord.VoiceState +) -> None: if before.channel is None and random.random() < 5 / 100: logger.info(f"Voice shuffle from {member}") await try_shuffle(member.guild) @@ -248,7 +262,7 @@ async def voice_random_nicks(member, before, after): source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg") await asyncio.sleep(random.randrange(60)) - vo = await after.channel.connect() + vo: discord.VoiceClient = await after.channel.connect() await asyncio.sleep(random.randrange(10)) await vo.play(source, wait_finish=True) @@ -261,7 +275,9 @@ async def voice_random_nicks(member, before, after): @bot.slash_command( name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI" ) -async def indu(ctx, prompt): +async def indu(ctx: discord.ApplicationContext, prompt: str) -> None: + if aibot is None: + return logger.info(f"INDU {ctx.author} {prompt}") await ctx.defer() res_stream = await aibot.get_response_stream(prompt) @@ -279,7 +295,7 @@ async def indu(ctx, prompt): embed.description += chunk.data.choices[0].delta.content await message.edit(embed=embed) - embed.color = None + embed.colour = None await message.edit(embed=embed) logger.info("FIN INDU") @@ -287,7 +303,7 @@ async def indu(ctx, prompt): @bot.slash_command( name="chan", guild_ids=guild_ids, description="Donner de nouveaux pseudos" ) -async def chan(ctx, file: discord.Attachment): +async def chan(ctx: discord.ApplicationContext, file: discord.Attachment) -> None: logger.info(f"CHAN {ctx.author}") await ctx.defer() @@ -295,12 +311,12 @@ async def chan(ctx, file: discord.Attachment): members.remove(ctx.guild.owner) nicks = (await file.read()).decode().splitlines() - if len(nicks) != len(members): + if len(nicks) < len(members): embed = discord.Embed(title="ERRE CHAN", color=discord.Colour.red()) await ctx.respond(embed=embed) return - random.shuffle(nicks) + nicks = random.choices(nicks, k=len(members)) for member, nick in zip(members, nicks): logger.info(member, nick) await member.edit(nick=nick) @@ -315,19 +331,16 @@ async def chan(ctx, file: discord.Attachment): async def send_as_webhook( channel: discord.TextChannel, name: str, - avatar_url: str, + avatar_url: str | None, content: str, - embed: discord.Embed = None, -): +) -> None: webhooks = await channel.webhooks() webhook = discord.utils.get(webhooks, name="BotbotbotHook") if webhook is None: webhook = await channel.create_webhook(name="BotbotbotHook") - await webhook.send( - content=content, username=name, avatar_url=avatar_url, embed=embed - ) + await webhook.send(content=content, username=name, avatar_url=avatar_url) bot.run(config.get("token")) diff --git a/botbotbot/ai.py b/botbotbot/ai.py index eeb9ab2..fcfe521 100644 --- a/botbotbot/ai.py +++ b/botbotbot/ai.py @@ -1,34 +1,52 @@ -from mistralai import Mistral +from typing import Any, Coroutine + +import mistralai +from mistralai.utils import eventstreaming class AIBot: def __init__( - self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None - ): - self.client = Mistral(api_key=api_key) + self, + api_key: str, + model: str = "open-mistral-7b", + max_tokens: int | None = None, + system_message: str | None = None, + ) -> None: + self.client = mistralai.Mistral(api_key=api_key) self.model = model self.max_tokens = max_tokens self.system_message = system_message - def get_responses(self, message): + def get_responses(self, message: str) -> mistralai.ChatCompletionResponse: return self.client.chat.complete( model=self.model, - messages=self.base_message + [{"role": "user", "content": message}], + messages=self.get_message(message), max_tokens=self.max_tokens, ) - def answer(self, message): - return self.get_responses(message).choices[0].message.content + def answer(self, message: str) -> str | None: + res = self.get_responses(message).choices + if not res: + return None - def get_response_stream(self, message): + return str(res[0].message.content) + + def get_response_stream( + self, message: str + ) -> Coroutine[ + Any, Any, eventstreaming.EventStreamAsync[mistralai.CompletionEvent] + ]: return self.client.chat.stream_async( model=self.model, - messages=self.base_message + [{"role": "user", "content": message}], + messages=self.get_message(message), max_tokens=self.max_tokens, ) - @property - def base_message(self): + def get_message(self, content: str) -> list[Any]: if self.system_message: - return [{"role": "system", "content": self.system_message}] - return [] + return [ + mistralai.SystemMessage(content=self.system_message), + mistralai.UserMessage(content=content), + ] + + return [mistralai.UserMessage(content=content)] diff --git a/main.py b/main.py deleted file mode 100644 index a50a11f..0000000 --- a/main.py +++ /dev/null @@ -1,6 +0,0 @@ -def main(): - print("Hello from botbotbot-py!") - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index aeb5f38..4ac1eb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,3 +18,8 @@ dev = [ "pre-commit>=4.2.0", "ruff>=0.11.2", ] + +[tool.mypy] +strict = true +disallow_untyped_calls = false +disallow_untyped_decorators = false