Refactor AIBot initialization and response handling; add type hints and update pre-commit configuration
This commit is contained in:
parent
7b010bfd0f
commit
761111bb07
6 changed files with 85 additions and 54 deletions
|
@ -19,3 +19,4 @@ repos:
|
||||||
rev: v1.15.0
|
rev: v1.15.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
|
additional_dependencies: [mistralai, py-cord]
|
||||||
|
|
0
botbotbot/__init__.py
Normal file
0
botbotbot/__init__.py
Normal file
|
@ -6,11 +6,11 @@ import tomllib
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .ai import AIBot
|
from botbotbot.ai import AIBot
|
||||||
|
|
||||||
description = """BotBotBot"""
|
description = """BotBotBot"""
|
||||||
|
|
||||||
logger = logging.getLogger("botbotbot")
|
logger = logging.getLogger(__name__)
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
with open("config.toml", "rb") as config_file:
|
with open("config.toml", "rb") as config_file:
|
||||||
|
@ -26,11 +26,15 @@ Tu dois faire un commentaire pertinent en lien avec ce qui te sera dit.
|
||||||
Ta réponse doit être très courte.
|
Ta réponse doit être très courte.
|
||||||
Ta réponse doit être une seule phrase.
|
Ta réponse doit être une seule phrase.
|
||||||
TA RÉPONSE DOIT ÊTRE EN FRANÇAIS !!!"""
|
TA RÉPONSE DOIT ÊTRE EN FRANÇAIS !!!"""
|
||||||
aibot = AIBot(
|
|
||||||
config.get("mistral_api_key"),
|
aibot: AIBot | None = None
|
||||||
model="mistral-large-latest",
|
|
||||||
system_message=system_prompt,
|
if isinstance(key := config.get("mistral_api_key"), str):
|
||||||
)
|
aibot = AIBot(
|
||||||
|
key,
|
||||||
|
model="mistral-large-latest",
|
||||||
|
system_message=system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.members = True
|
intents.members = True
|
||||||
|
@ -45,11 +49,11 @@ shuffle_tasks = set()
|
||||||
|
|
||||||
|
|
||||||
@bot.listen("on_ready")
|
@bot.listen("on_ready")
|
||||||
async def on_ready():
|
async def on_ready() -> None:
|
||||||
logger.info(f"We have logged in as {bot.user}")
|
logger.info(f"We have logged in as {bot.user}")
|
||||||
|
|
||||||
|
|
||||||
async def reply(message):
|
async def reply(message: discord.Message) -> None:
|
||||||
logger.info(f"Reply to {message.author}")
|
logger.info(f"Reply to {message.author}")
|
||||||
mention = random.choices(
|
mention = random.choices(
|
||||||
[f"<@{message.author.id}>", "@everyone", "@here"], weights=(98, 1, 1)
|
[f"<@{message.author.id}>", "@everyone", "@here"], weights=(98, 1, 1)
|
||||||
|
@ -61,11 +65,11 @@ async def reply(message):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if random.random() < 0.1:
|
if isinstance(message.channel, discord.TextChannel) and random.random() < 0.1:
|
||||||
await send_as_webhook(
|
await send_as_webhook(
|
||||||
message.channel,
|
message.channel,
|
||||||
message.author.display_name,
|
message.author.display_name,
|
||||||
message.author.avatar.url,
|
message.author.avatar.url if message.author.avatar else None,
|
||||||
content,
|
content,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -74,12 +78,19 @@ async def reply(message):
|
||||||
await fct(content)
|
await fct(content)
|
||||||
|
|
||||||
|
|
||||||
async def ai_reply(message):
|
async def ai_reply(message: discord.Message) -> None:
|
||||||
|
if aibot is None:
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"AI Reply to {message.author}")
|
logger.info(f"AI Reply to {message.author}")
|
||||||
prompt = message.clean_content
|
prompt = message.clean_content
|
||||||
if prompt == "" and message.embeds:
|
if prompt == "" and message.embeds and message.embeds[0].description:
|
||||||
prompt = message.embeds[0].description
|
prompt = message.embeds[0].description
|
||||||
|
|
||||||
answer = aibot.answer(prompt)
|
answer = aibot.answer(prompt)
|
||||||
|
if not isinstance(answer, str):
|
||||||
|
return
|
||||||
|
|
||||||
if len(answer) > 2000:
|
if len(answer) > 2000:
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
description=answer,
|
description=answer,
|
||||||
|
@ -142,14 +153,14 @@ async def rando_shuffle(message: discord.Message) -> None:
|
||||||
await try_shuffle(message.guild)
|
await try_shuffle(message.guild)
|
||||||
|
|
||||||
|
|
||||||
def save_wordlist():
|
def save_wordlist() -> None:
|
||||||
logger.info("Saving updated wordlist")
|
logger.info("Saving updated wordlist")
|
||||||
with open("wordlist.pickle", "wb") as word_file:
|
with open("wordlist.pickle", "wb") as word_file:
|
||||||
pickle.dump(word_list, word_file)
|
pickle.dump(word_list, word_file)
|
||||||
|
|
||||||
|
|
||||||
@bot.slash_command(name="bibl", guild_ids=guild_ids, description="Ajouter une phrase")
|
@bot.slash_command(name="bibl", guild_ids=guild_ids, description="Ajouter une phrase")
|
||||||
async def bibl(ctx, phrase):
|
async def bibl(ctx: discord.ApplicationContext, phrase: str) -> None:
|
||||||
logger.info(f"BIBL {ctx.author} {phrase}")
|
logger.info(f"BIBL {ctx.author} {phrase}")
|
||||||
word_list.append(phrase)
|
word_list.append(phrase)
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
|
@ -161,7 +172,7 @@ async def bibl(ctx, phrase):
|
||||||
|
|
||||||
|
|
||||||
@bot.slash_command(name="tabl", guild_ids=guild_ids, description="Lister les phrases")
|
@bot.slash_command(name="tabl", guild_ids=guild_ids, description="Lister les phrases")
|
||||||
async def tabl(ctx):
|
async def tabl(ctx: discord.ApplicationContext) -> None:
|
||||||
logger.info(f"TABL {ctx.author}")
|
logger.info(f"TABL {ctx.author}")
|
||||||
embed = discord.Embed(
|
embed = discord.Embed(
|
||||||
title="TABL", description="\n".join(word_list), color=discord.Colour.green()
|
title="TABL", description="\n".join(word_list), color=discord.Colour.green()
|
||||||
|
@ -170,7 +181,7 @@ async def tabl(ctx):
|
||||||
|
|
||||||
|
|
||||||
@bot.slash_command(name="enle", guild_ids=guild_ids, description="Enlever une phrase")
|
@bot.slash_command(name="enle", guild_ids=guild_ids, description="Enlever une phrase")
|
||||||
async def enle(ctx, phrase):
|
async def enle(ctx: discord.ApplicationContext, phrase: str) -> None:
|
||||||
logger.info(f"ENLE {ctx.author} {phrase}")
|
logger.info(f"ENLE {ctx.author} {phrase}")
|
||||||
try:
|
try:
|
||||||
word_list.remove(phrase)
|
word_list.remove(phrase)
|
||||||
|
@ -189,7 +200,7 @@ async def enle(ctx, phrase):
|
||||||
logger.info("FIN ENLE")
|
logger.info("FIN ENLE")
|
||||||
|
|
||||||
|
|
||||||
async def try_shuffle(guild):
|
async def try_shuffle(guild: discord.Guild) -> bool:
|
||||||
if guild.id in shuffle_tasks:
|
if guild.id in shuffle_tasks:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -199,10 +210,11 @@ async def try_shuffle(guild):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def shuffle_nicks(guild):
|
async def shuffle_nicks(guild: discord.Guild) -> None:
|
||||||
logger.info("Shuffle")
|
logger.info("Shuffle")
|
||||||
members = guild.members
|
members = guild.members
|
||||||
members.remove(guild.owner)
|
if guild.owner:
|
||||||
|
members.remove(guild.owner)
|
||||||
|
|
||||||
nicks = [member.nick for member in members]
|
nicks = [member.nick for member in members]
|
||||||
|
|
||||||
|
@ -214,7 +226,7 @@ async def shuffle_nicks(guild):
|
||||||
|
|
||||||
|
|
||||||
@bot.slash_command(name="alea", guild_ids=guild_ids, description="Modifier les pseudos")
|
@bot.slash_command(name="alea", guild_ids=guild_ids, description="Modifier les pseudos")
|
||||||
async def alea(ctx):
|
async def alea(ctx: discord.ApplicationContext) -> None:
|
||||||
logger.info(f"ALEA {ctx.author}")
|
logger.info(f"ALEA {ctx.author}")
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
if await try_shuffle(ctx.guild):
|
if await try_shuffle(ctx.guild):
|
||||||
|
@ -228,7 +240,9 @@ async def alea(ctx):
|
||||||
|
|
||||||
|
|
||||||
@bot.listen("on_voice_state_update")
|
@bot.listen("on_voice_state_update")
|
||||||
async def voice_random_nicks(member, before, after):
|
async def voice_random_nicks(
|
||||||
|
member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
|
||||||
|
) -> None:
|
||||||
if before.channel is None and random.random() < 5 / 100:
|
if before.channel is None and random.random() < 5 / 100:
|
||||||
logger.info(f"Voice shuffle from {member}")
|
logger.info(f"Voice shuffle from {member}")
|
||||||
await try_shuffle(member.guild)
|
await try_shuffle(member.guild)
|
||||||
|
@ -248,7 +262,7 @@ async def voice_random_nicks(member, before, after):
|
||||||
source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg")
|
source = await discord.FFmpegOpusAudio.from_probe("assets/allo.ogg")
|
||||||
|
|
||||||
await asyncio.sleep(random.randrange(60))
|
await asyncio.sleep(random.randrange(60))
|
||||||
vo = await after.channel.connect()
|
vo: discord.VoiceClient = await after.channel.connect()
|
||||||
|
|
||||||
await asyncio.sleep(random.randrange(10))
|
await asyncio.sleep(random.randrange(10))
|
||||||
await vo.play(source, wait_finish=True)
|
await vo.play(source, wait_finish=True)
|
||||||
|
@ -261,7 +275,9 @@ async def voice_random_nicks(member, before, after):
|
||||||
@bot.slash_command(
|
@bot.slash_command(
|
||||||
name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI"
|
name="indu", guild_ids=guild_ids, description="Poser une question à MistralAI"
|
||||||
)
|
)
|
||||||
async def indu(ctx, prompt):
|
async def indu(ctx: discord.ApplicationContext, prompt: str) -> None:
|
||||||
|
if aibot is None:
|
||||||
|
return
|
||||||
logger.info(f"INDU {ctx.author} {prompt}")
|
logger.info(f"INDU {ctx.author} {prompt}")
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
res_stream = await aibot.get_response_stream(prompt)
|
res_stream = await aibot.get_response_stream(prompt)
|
||||||
|
@ -279,7 +295,7 @@ async def indu(ctx, prompt):
|
||||||
embed.description += chunk.data.choices[0].delta.content
|
embed.description += chunk.data.choices[0].delta.content
|
||||||
await message.edit(embed=embed)
|
await message.edit(embed=embed)
|
||||||
|
|
||||||
embed.color = None
|
embed.colour = None
|
||||||
await message.edit(embed=embed)
|
await message.edit(embed=embed)
|
||||||
logger.info("FIN INDU")
|
logger.info("FIN INDU")
|
||||||
|
|
||||||
|
@ -287,7 +303,7 @@ async def indu(ctx, prompt):
|
||||||
@bot.slash_command(
|
@bot.slash_command(
|
||||||
name="chan", guild_ids=guild_ids, description="Donner de nouveaux pseudos"
|
name="chan", guild_ids=guild_ids, description="Donner de nouveaux pseudos"
|
||||||
)
|
)
|
||||||
async def chan(ctx, file: discord.Attachment):
|
async def chan(ctx: discord.ApplicationContext, file: discord.Attachment) -> None:
|
||||||
logger.info(f"CHAN {ctx.author}")
|
logger.info(f"CHAN {ctx.author}")
|
||||||
await ctx.defer()
|
await ctx.defer()
|
||||||
|
|
||||||
|
@ -295,12 +311,12 @@ async def chan(ctx, file: discord.Attachment):
|
||||||
members.remove(ctx.guild.owner)
|
members.remove(ctx.guild.owner)
|
||||||
|
|
||||||
nicks = (await file.read()).decode().splitlines()
|
nicks = (await file.read()).decode().splitlines()
|
||||||
if len(nicks) != len(members):
|
if len(nicks) < len(members):
|
||||||
embed = discord.Embed(title="ERRE CHAN", color=discord.Colour.red())
|
embed = discord.Embed(title="ERRE CHAN", color=discord.Colour.red())
|
||||||
await ctx.respond(embed=embed)
|
await ctx.respond(embed=embed)
|
||||||
return
|
return
|
||||||
|
|
||||||
random.shuffle(nicks)
|
nicks = random.choices(nicks, k=len(members))
|
||||||
for member, nick in zip(members, nicks):
|
for member, nick in zip(members, nicks):
|
||||||
logger.info(member, nick)
|
logger.info(member, nick)
|
||||||
await member.edit(nick=nick)
|
await member.edit(nick=nick)
|
||||||
|
@ -315,19 +331,16 @@ async def chan(ctx, file: discord.Attachment):
|
||||||
async def send_as_webhook(
|
async def send_as_webhook(
|
||||||
channel: discord.TextChannel,
|
channel: discord.TextChannel,
|
||||||
name: str,
|
name: str,
|
||||||
avatar_url: str,
|
avatar_url: str | None,
|
||||||
content: str,
|
content: str,
|
||||||
embed: discord.Embed = None,
|
) -> None:
|
||||||
):
|
|
||||||
webhooks = await channel.webhooks()
|
webhooks = await channel.webhooks()
|
||||||
webhook = discord.utils.get(webhooks, name="BotbotbotHook")
|
webhook = discord.utils.get(webhooks, name="BotbotbotHook")
|
||||||
|
|
||||||
if webhook is None:
|
if webhook is None:
|
||||||
webhook = await channel.create_webhook(name="BotbotbotHook")
|
webhook = await channel.create_webhook(name="BotbotbotHook")
|
||||||
|
|
||||||
await webhook.send(
|
await webhook.send(content=content, username=name, avatar_url=avatar_url)
|
||||||
content=content, username=name, avatar_url=avatar_url, embed=embed
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
bot.run(config.get("token"))
|
bot.run(config.get("token"))
|
||||||
|
|
|
@ -1,34 +1,52 @@
|
||||||
from mistralai import Mistral
|
from typing import Any, Coroutine
|
||||||
|
|
||||||
|
import mistralai
|
||||||
|
from mistralai.utils import eventstreaming
|
||||||
|
|
||||||
|
|
||||||
class AIBot:
|
class AIBot:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
|
self,
|
||||||
):
|
api_key: str,
|
||||||
self.client = Mistral(api_key=api_key)
|
model: str = "open-mistral-7b",
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
system_message: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.client = mistralai.Mistral(api_key=api_key)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.system_message = system_message
|
self.system_message = system_message
|
||||||
|
|
||||||
def get_responses(self, message):
|
def get_responses(self, message: str) -> mistralai.ChatCompletionResponse:
|
||||||
return self.client.chat.complete(
|
return self.client.chat.complete(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=self.base_message + [{"role": "user", "content": message}],
|
messages=self.get_message(message),
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def answer(self, message):
|
def answer(self, message: str) -> str | None:
|
||||||
return self.get_responses(message).choices[0].message.content
|
res = self.get_responses(message).choices
|
||||||
|
if not res:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_response_stream(self, message):
|
return str(res[0].message.content)
|
||||||
|
|
||||||
|
def get_response_stream(
|
||||||
|
self, message: str
|
||||||
|
) -> Coroutine[
|
||||||
|
Any, Any, eventstreaming.EventStreamAsync[mistralai.CompletionEvent]
|
||||||
|
]:
|
||||||
return self.client.chat.stream_async(
|
return self.client.chat.stream_async(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=self.base_message + [{"role": "user", "content": message}],
|
messages=self.get_message(message),
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
def get_message(self, content: str) -> list[Any]:
|
||||||
def base_message(self):
|
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
return [{"role": "system", "content": self.system_message}]
|
return [
|
||||||
return []
|
mistralai.SystemMessage(content=self.system_message),
|
||||||
|
mistralai.UserMessage(content=content),
|
||||||
|
]
|
||||||
|
|
||||||
|
return [mistralai.UserMessage(content=content)]
|
||||||
|
|
6
main.py
6
main.py
|
@ -1,6 +0,0 @@
|
||||||
def main():
|
|
||||||
print("Hello from botbotbot-py!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -18,3 +18,8 @@ dev = [
|
||||||
"pre-commit>=4.2.0",
|
"pre-commit>=4.2.0",
|
||||||
"ruff>=0.11.2",
|
"ruff>=0.11.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
strict = true
|
||||||
|
disallow_untyped_calls = false
|
||||||
|
disallow_untyped_decorators = false
|
||||||
|
|
Loading…
Add table
Reference in a new issue