Update AI with system prompt
parent 130013f1e1
commit a331d90b5b
2 changed files with 41 additions and 6 deletions
@@ -16,7 +16,14 @@ with open("wordlist.pickle", "rb") as word_file:
 guild_ids = config.get("guild_ids")
 delay = config.get("delay", 60)
 
-aibot = AIBot(config.get("mistral_api_key"), model="open-mixtral-8x7b")
+system_prompt = """Tu es une intelligence artificelle qui répond en français.
+Ta réponse doit être très courte.
+Ta réponse doit être longue d'une phrase."""
+aibot = AIBot(
+    config.get("mistral_api_key"),
+    model="open-mixtral-8x7b",
+    system_message=system_prompt,
+)
 
 intents = discord.Intents.default()
 intents.members = True
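
In English, the new system prompt reads roughly: "You are an artificial intelligence that answers in French. Your answer must be very short. Your answer must be one sentence long." AIBot prepends it to every request as a system-role message (see the base_message property in the second changed file). For reference, a minimal sketch of the equivalent raw call against the mistralai 0.x client this repo imports; "YOUR_API_KEY" is a placeholder and the prompt is shortened here:

# Sketch of the request AIBot now builds for each user prompt.
# Assumes mistralai 0.x; "YOUR_API_KEY" is a placeholder.
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key="YOUR_API_KEY")
system_prompt = "Tu es une intelligence artificelle qui répond en français."
response = client.chat(
    model="open-mixtral-8x7b",
    messages=[
        ChatMessage(role="system", content=system_prompt),  # injected by AIBot
        ChatMessage(role="user", content="Bonjour !"),
    ],
)
print(response.choices[0].message.content)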
@@ -163,13 +170,23 @@ async def alea(ctx):
 )
 async def indu(ctx, prompt):
     await ctx.defer()
-    answer = aibot.answer(prompt)
+    res_stream = aibot.get_response_stream(prompt)
 
     embed = discord.Embed(
         title=prompt,
-        description=answer,
+        description="",
         thumbnail="https://mistral.ai/images/favicon/favicon-32x32.png",
         color=discord.Colour.orange(),
     )
-    await ctx.respond(embed=embed)
+    message = await ctx.respond(embed=embed)
+
+    async for chunk in res_stream:
+        if chunk.choices[0].delta.content is not None:
+            embed.description += chunk.choices[0].delta.content
+            await message.edit(embed=embed)
+
+    embed.color = None
+    await message.edit(embed=embed)
+
+
 @bot.slash_command(
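
The indu command no longer blocks on aibot.answer(): it posts an empty orange embed, streams each delta into embed.description while editing the message in place, and clears the colour at the end to mark completion. One caveat: editing the Discord message once per chunk can run into Discord's rate limits on fast streams. A hedged variant, not part of this commit, that batches edits every few chunks, reusing the aibot, prompt, embed, and message objects from the hunk above inside the same coroutine:

    # Hypothetical batching variant, NOT in this commit: fewer message
    # edits to stay clear of Discord's rate limits.
    EDIT_EVERY = 5  # tunable: edit once per 5 received chunks
    pending = 0

    res_stream = aibot.get_response_stream(prompt)
    async for chunk in res_stream:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            embed.description += delta
            pending += 1
            if pending >= EDIT_EVERY:
                await message.edit(embed=embed)
                pending = 0

    embed.color = None
    await message.edit(embed=embed)  # final edit flushes any leftover text

The second changed file is the AIBot helper class itself: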
@@ -1,19 +1,37 @@
+from mistralai.async_client import MistralAsyncClient
 from mistralai.client import MistralClient
 from mistralai.models.chat_completion import ChatMessage
 
 
 class AIBot:
-    def __init__(self, api_key, model="open-mistral-7b", max_tokens=None):
+    def __init__(
+        self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
+    ):
         self.client = MistralClient(api_key=api_key)
+        self.async_client = MistralAsyncClient(api_key=api_key)
         self.model = model
         self.max_tokens = max_tokens
+        self.system_message = system_message
 
     def get_responses(self, message):
         return self.client.chat(
             model=self.model,
-            messages=[ChatMessage(role="user", content=message)],
+            messages=self.base_message + [ChatMessage(role="user", content=message)],
             max_tokens=self.max_tokens,
         )
 
     def answer(self, message):
         return self.get_responses(message).choices[0].message.content
 
+    def get_response_stream(self, message):
+        return self.async_client.chat_stream(
+            model=self.model,
+            messages=self.base_message + [ChatMessage(role="user", content=message)],
+            max_tokens=self.max_tokens,
+        )
+
+    @property
+    def base_message(self):
+        if self.system_message:
+            return [ChatMessage(role="system", content=self.system_message)]
+        return []
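
With the property in place, the system message stays optional: when system_message is None, base_message returns an empty list and both chat paths fall back to the previous user-only payload. A minimal standalone sketch of the updated class, assuming a valid key ("YOUR_API_KEY" is a placeholder):

# Exercise AIBot outside Discord; "YOUR_API_KEY" is a placeholder.
import asyncio

async def main():
    bot = AIBot(
        "YOUR_API_KEY",
        model="open-mixtral-8x7b",
        system_message="Answer in a single sentence.",  # any system prompt
    )
    print(bot.answer("Bonjour !"))  # blocking: full completion at once

    # Streaming: print deltas as they are generated.
    async for chunk in bot.get_response_stream("Bonjour !"):
        if chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="", flush=True)

asyncio.run(main())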