diff --git a/botbotbot/ai.py b/botbotbot/ai.py index 764596b..31c5538 100644 --- a/botbotbot/ai.py +++ b/botbotbot/ai.py @@ -1,14 +1,11 @@ -from mistralai.async_client import MistralAsyncClient -from mistralai.client import MistralClient -from mistralai.models.chat_completion import ChatMessage +from mistralai import Mistral class AIBot: def __init__( self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None ): - self.client = MistralClient(api_key=api_key) - self.async_client = MistralAsyncClient(api_key=api_key) + self.client = Mistral(api_key=api_key) self.model = model self.max_tokens = max_tokens self.system_message = system_message @@ -16,7 +13,7 @@ class AIBot: def get_responses(self, message): return self.client.chat( model=self.model, - messages=self.base_message + [ChatMessage(role="user", content=message)], + messages=self.base_message + [{"role": "user", "content": message}], max_tokens=self.max_tokens, ) @@ -24,14 +21,14 @@ class AIBot: return self.get_responses(message).choices[0].message.content def get_response_stream(self, message): - return self.async_client.chat_stream( + return self.client.chat.stream_async( model=self.model, - messages=self.base_message + [ChatMessage(role="user", content=message)], + messages=self.base_message + [{"role": "user", "content": message}], max_tokens=self.max_tokens, ) @property def base_message(self): if self.system_message: - return [ChatMessage(role="system", content=self.system_message)] + return [{"role": "system", "content": self.system_message}] return []