Refactor AIBot to use unified Mistral client and update message format
This commit is contained in:
parent
9448972a9a
commit
5c7a6f8ab0
1 changed files with 6 additions and 9 deletions
|
@ -1,14 +1,11 @@
|
|||
from mistralai.async_client import MistralAsyncClient
|
||||
from mistralai.client import MistralClient
|
||||
from mistralai.models.chat_completion import ChatMessage
|
||||
from mistralai import Mistral
|
||||
|
||||
|
||||
class AIBot:
|
||||
def __init__(
|
||||
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
|
||||
):
|
||||
self.client = MistralClient(api_key=api_key)
|
||||
self.async_client = MistralAsyncClient(api_key=api_key)
|
||||
self.client = Mistral(api_key=api_key)
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.system_message = system_message
|
||||
|
@ -16,7 +13,7 @@ class AIBot:
|
|||
def get_responses(self, message):
|
||||
return self.client.chat(
|
||||
model=self.model,
|
||||
messages=self.base_message + [ChatMessage(role="user", content=message)],
|
||||
messages=self.base_message + [{"role": "user", "content": message}],
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
|
@ -24,14 +21,14 @@ class AIBot:
|
|||
return self.get_responses(message).choices[0].message.content
|
||||
|
||||
def get_response_stream(self, message):
|
||||
return self.async_client.chat_stream(
|
||||
return self.client.chat.stream_async(
|
||||
model=self.model,
|
||||
messages=self.base_message + [ChatMessage(role="user", content=message)],
|
||||
messages=self.base_message + [{"role": "user", "content": message}],
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
@property
|
||||
def base_message(self):
|
||||
if self.system_message:
|
||||
return [ChatMessage(role="system", content=self.system_message)]
|
||||
return [{"role": "system", "content": self.system_message}]
|
||||
return []
|
||||
|
|
Loading…
Add table
Reference in a new issue