Refactor AIBot to use unified Mistral client and update message format

This commit is contained in:
Edgar P. Burkhart 2025-03-22 17:01:52 +01:00
parent 9448972a9a
commit 5c7a6f8ab0
Signed by: edpibu
GPG key ID: 9833D3C5A25BD227

View file

@ -1,14 +1,11 @@
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from mistralai import Mistral
class AIBot:
def __init__(
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
):
self.client = MistralClient(api_key=api_key)
self.async_client = MistralAsyncClient(api_key=api_key)
self.client = Mistral(api_key=api_key)
self.model = model
self.max_tokens = max_tokens
self.system_message = system_message
@ -16,7 +13,7 @@ class AIBot:
def get_responses(self, message):
return self.client.chat(
model=self.model,
messages=self.base_message + [ChatMessage(role="user", content=message)],
messages=self.base_message + [{"role": "user", "content": message}],
max_tokens=self.max_tokens,
)
@ -24,14 +21,14 @@ class AIBot:
return self.get_responses(message).choices[0].message.content
def get_response_stream(self, message):
return self.async_client.chat_stream(
return self.client.chat.stream_async(
model=self.model,
messages=self.base_message + [ChatMessage(role="user", content=message)],
messages=self.base_message + [{"role": "user", "content": message}],
max_tokens=self.max_tokens,
)
@property
def base_message(self):
if self.system_message:
return [ChatMessage(role="system", content=self.system_message)]
return [{"role": "system", "content": self.system_message}]
return []