Refactor AIBot to use unified Mistral client and update message format

This commit is contained in:
Edgar P. Burkhart 2025-03-22 17:01:52 +01:00
parent 9448972a9a
commit 5c7a6f8ab0
Signed by: edpibu
GPG key ID: 9833D3C5A25BD227

View file

@ -1,14 +1,11 @@
from mistralai.async_client import MistralAsyncClient from mistralai import Mistral
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
class AIBot: class AIBot:
def __init__( def __init__(
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
): ):
self.client = MistralClient(api_key=api_key) self.client = Mistral(api_key=api_key)
self.async_client = MistralAsyncClient(api_key=api_key)
self.model = model self.model = model
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.system_message = system_message self.system_message = system_message
@ -16,7 +13,7 @@ class AIBot:
def get_responses(self, message): def get_responses(self, message):
return self.client.chat( return self.client.chat(
model=self.model, model=self.model,
messages=self.base_message + [ChatMessage(role="user", content=message)], messages=self.base_message + [{"role": "user", "content": message}],
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
) )
@ -24,14 +21,14 @@ class AIBot:
return self.get_responses(message).choices[0].message.content return self.get_responses(message).choices[0].message.content
def get_response_stream(self, message): def get_response_stream(self, message):
return self.async_client.chat_stream( return self.client.chat.stream_async(
model=self.model, model=self.model,
messages=self.base_message + [ChatMessage(role="user", content=message)], messages=self.base_message + [{"role": "user", "content": message}],
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
) )
@property @property
def base_message(self): def base_message(self):
if self.system_message: if self.system_message:
return [ChatMessage(role="system", content=self.system_message)] return [{"role": "system", "content": self.system_message}]
return [] return []