Compare commits

...

2 commits

6 changed files with 89 additions and 58 deletions

View file

@ -19,3 +19,4 @@ repos:
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies: [mistralai, py-cord]

0
botbotbot/__init__.py Normal file
View file

View file

@ -6,11 +6,11 @@ import tomllib
import discord
from .ai import AIBot
from botbotbot.ai import AIBot
description = """BotBotBot"""
logger = logging.getLogger("botbotbot")
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
with open("config.toml", "rb") as config_file:
@ -26,11 +26,15 @@ Tu dois faire un commentaire pertinent en lien avec ce qui te sera dit.
Ta réponse doit être très courte.
Ta réponse doit être une seule phrase.
TA RÉPONSE DOIT ÊTRE EN FRANÇAIS !!!"""
aibot = AIBot(
config.get("mistral_api_key"),
model="mistral-large-latest",
system_message=system_prompt,
)
aibot: AIBot | None = None
if isinstance(key := config.get("mistral_api_key"), str):
aibot = AIBot(
key,
model="mistral-large-latest",
system_message=system_prompt,
)
intents = discord.Intents.default()
intents.members = True
@ -45,11 +49,11 @@ shuffle_tasks = set()
@bot.listen("on_ready")
async def on_ready():
async def on_ready() -> None:
logger.info(f"We have logged in as {bot.user}")
async def reply(message):
async def reply(message: discord.Message) -> None:
logger.info(f"Reply to {message.author}")
mention = random.choices(
[f"<@{message.author.id}>", "@everyone", "@here"], weights=(98, 1, 1)
@ -61,11 +65,11 @@ async def reply(message):
)
)
if random.random() < 0.1:
if isinstance(message.channel, discord.TextChannel) and random.random() < 0.1:
await send_as_webhook(
message.channel,
message.author.display_name,
message.author.avatar.url,
message.author.avatar.url if message.author.avatar else None,
content,
)
else:
@ -74,12 +78,19 @@ async def reply(message):
await fct(content)
async def ai_reply(message):
async def ai_reply(message: discord.Message) -> None:
if aibot is None:
return
logger.info(f"AI Reply to {message.author}")
prompt = message.clean_content
if prompt == "" and message.embeds:
if prompt == "" and message.embeds and message.embeds[0].description:
prompt = message.embeds[0].description
answer = aibot.answer(prompt)
if not isinstance(answer, str):
return
if len(answer) > 2000:
embed = discord.Embed(
description=answer,
@ -142,14 +153,14 @@ async def rando_shuffle(message: discord.Message) -> None:
await try_shuffle(message.guild)
def save_wordlist():
def save_wordlist() -> None:
logger.info("Saving updated wordlist")
with open("wordlist.pickle", "wb") as word_file:
pickle.dump(word_list, word_file)
@bot.slash_command(name="bibl", guild_ids=guild_ids, description="Ajouter une phrase")
async def bibl(ctx, phrase):
async def bibl(ctx: discord.ApplicationContext, phrase: str) -> None:
logger.info(f"BIBL {ctx.author} {phrase}")
word_list.append(phrase)
embed = discord.Embed(
@ -161,7 +172,7 @@ async def bibl(ctx, phrase):
@bot.slash_command(name="tabl", guild_ids=guild_ids, description="Lister les phrases")
async def tabl(ctx):
async def tabl(ctx: discord.ApplicationContext) -> None:
logger.info(f"TABL {ctx.author}")
embed = discord.Embed(
title="TABL", description="\n".join(word_list), color=discord.Colour.green()
@ -170,7 +181,7 @@ async def tabl(ctx):
@bot.slash_command(name="enle", guild_ids=guild_ids, description="Enlever une phrase")
async def enle(ctx, phrase):
async def enle(ctx: discord.ApplicationContext, phrase: str) -> None:
logger.info(f"ENLE {ctx.author} {phrase}")
try:
word_list.remove(phrase)
@ -189,7 +200,7 @@ async def enle(ctx, phrase):
logger.info("FIN ENLE")
async def try_shuffle(guild):
async def try_shuffle(guild: discord.Guild) -> bool:
if guild.id in shuffle_tasks:
return False
@ -199,10 +210,11 @@ async def try_shuffle(guild):
return True
async def shuffle_nicks(guild):
async def shuffle_nicks(guild: discord.Guild) -> None:
logger.info("Shuffle")
members = guild.members
members.remove(guild.owner)
if guild.owner:
members.remove(guild.owner)
nicks = [member.nick for member in members]
@ -214,7 +226,7 @@ async def shuffle_nicks(guild):
@bot.slash_command(name="alea", guild_ids=guild_ids, description="Modifier les pseudos")
async def alea(ctx):
async def alea(ctx: discord.ApplicationContext) -> None:
logger.info(f"ALEA {ctx.author}")
await ctx.defer()
if await try_shuffle(ctx.guild):
@ -228,7 +240,9 @@ async def alea(ctx):
@bot.listen("on_voice_state_update")
async def voice_random_nicks(member, before, after):
async def voice_random_nicks(
member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
) -> None:
if before.channel is None and random.random() < 5 / 100:
logger.info(f"Voice shuffle from {member}")
await try_shuffle(member.guild)
@ -248,7 +262,7 @@ async def voice_random_nicks(member, before, after):
source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg")
await asyncio.sleep(random.randrange(60))
vo = await after.channel.connect()
vo: discord.VoiceClient = await after.channel.connect()
await asyncio.sleep(random.randrange(10))
await vo.play(source, wait_finish=True)
@ -261,10 +275,12 @@ async def voice_random_nicks(member, before, after):
@bot.slash_command(
name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI"
)
async def indu(ctx, prompt):
async def indu(ctx: discord.ApplicationContext, prompt: str) -> None:
if aibot is None:
return
logger.info(f"INDU {ctx.author} {prompt}")
await ctx.defer()
res_stream = aibot.get_response_stream(prompt)
res_stream = await aibot.get_response_stream(prompt)
embed = discord.Embed(
title=prompt,
@ -275,11 +291,11 @@ async def indu(ctx, prompt):
message = await ctx.respond(embed=embed)
async for chunk in res_stream:
if chunk.choices[0].delta.content is not None:
embed.description += chunk.choices[0].delta.content
if chunk.data.choices[0].delta.content is not None:
embed.description += chunk.data.choices[0].delta.content
await message.edit(embed=embed)
embed.color = None
embed.colour = None
await message.edit(embed=embed)
logger.info("FIN INDU")
@ -287,7 +303,7 @@ async def indu(ctx, prompt):
@bot.slash_command(
name="chan", guild_ids=guild_ids, description="Donner de nouveaux pseudos"
)
async def chan(ctx, file: discord.Attachment):
async def chan(ctx: discord.ApplicationContext, file: discord.Attachment) -> None:
logger.info(f"CHAN {ctx.author}")
await ctx.defer()
@ -295,12 +311,12 @@ async def chan(ctx, file: discord.Attachment):
members.remove(ctx.guild.owner)
nicks = (await file.read()).decode().splitlines()
if len(nicks) != len(members):
if len(nicks) < len(members):
embed = discord.Embed(title="ERRE CHAN", color=discord.Colour.red())
await ctx.respond(embed=embed)
return
random.shuffle(nicks)
nicks = random.choices(nicks, k=len(members))
for member, nick in zip(members, nicks):
logger.info(member, nick)
await member.edit(nick=nick)
@ -315,19 +331,16 @@ async def chan(ctx, file: discord.Attachment):
async def send_as_webhook(
channel: discord.TextChannel,
name: str,
avatar_url: str,
avatar_url: str | None,
content: str,
embed: discord.Embed = None,
):
) -> None:
webhooks = await channel.webhooks()
webhook = discord.utils.get(webhooks, name="BotbotbotHook")
if webhook is None:
webhook = await channel.create_webhook(name="BotbotbotHook")
await webhook.send(
content=content, username=name, avatar_url=avatar_url, embed=embed
)
await webhook.send(content=content, username=name, avatar_url=avatar_url)
bot.run(config.get("token"))

View file

@ -1,34 +1,52 @@
from mistralai import Mistral
from typing import Any, Coroutine
import mistralai
from mistralai.utils import eventstreaming
class AIBot:
def __init__(
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
):
self.client = Mistral(api_key=api_key)
self,
api_key: str,
model: str = "open-mistral-7b",
max_tokens: int | None = None,
system_message: str | None = None,
) -> None:
self.client = mistralai.Mistral(api_key=api_key)
self.model = model
self.max_tokens = max_tokens
self.system_message = system_message
def get_responses(self, message):
return self.client.chat(
def get_responses(self, message: str) -> mistralai.ChatCompletionResponse:
return self.client.chat.complete(
model=self.model,
messages=self.base_message + [{"role": "user", "content": message}],
messages=self.get_message(message),
max_tokens=self.max_tokens,
)
def answer(self, message):
return self.get_responses(message).choices[0].message.content
def answer(self, message: str) -> str | None:
res = self.get_responses(message).choices
if not res:
return None
def get_response_stream(self, message):
return str(res[0].message.content)
def get_response_stream(
self, message: str
) -> Coroutine[
Any, Any, eventstreaming.EventStreamAsync[mistralai.CompletionEvent]
]:
return self.client.chat.stream_async(
model=self.model,
messages=self.base_message + [{"role": "user", "content": message}],
messages=self.get_message(message),
max_tokens=self.max_tokens,
)
@property
def base_message(self):
def get_message(self, content: str) -> list[Any]:
if self.system_message:
return [{"role": "system", "content": self.system_message}]
return []
return [
mistralai.SystemMessage(content=self.system_message),
mistralai.UserMessage(content=content),
]
return [mistralai.UserMessage(content=content)]

View file

@ -1,6 +0,0 @@
def main():
print("Hello from botbotbot-py!")
if __name__ == "__main__":
main()

View file

@ -18,3 +18,8 @@ dev = [
"pre-commit>=4.2.0",
"ruff>=0.11.2",
]
[tool.mypy]
strict = true
disallow_untyped_calls = false
disallow_untyped_decorators = false