Refactor AIBot to use unified Mistral client and update message format
This commit is contained in:
parent
9448972a9a
commit
5c7a6f8ab0
1 changed files with 6 additions and 9 deletions
|
@ -1,14 +1,11 @@
|
||||||
from mistralai.async_client import MistralAsyncClient
|
from mistralai import Mistral
|
||||||
from mistralai.client import MistralClient
|
|
||||||
from mistralai.models.chat_completion import ChatMessage
|
|
||||||
|
|
||||||
|
|
||||||
class AIBot:
|
class AIBot:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
|
self, api_key, model="open-mistral-7b", max_tokens=None, system_message=None
|
||||||
):
|
):
|
||||||
self.client = MistralClient(api_key=api_key)
|
self.client = Mistral(api_key=api_key)
|
||||||
self.async_client = MistralAsyncClient(api_key=api_key)
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.system_message = system_message
|
self.system_message = system_message
|
||||||
|
@ -16,7 +13,7 @@ class AIBot:
|
||||||
def get_responses(self, message):
|
def get_responses(self, message):
|
||||||
return self.client.chat(
|
return self.client.chat(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=self.base_message + [ChatMessage(role="user", content=message)],
|
messages=self.base_message + [{"role": "user", "content": message}],
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,14 +21,14 @@ class AIBot:
|
||||||
return self.get_responses(message).choices[0].message.content
|
return self.get_responses(message).choices[0].message.content
|
||||||
|
|
||||||
def get_response_stream(self, message):
|
def get_response_stream(self, message):
|
||||||
return self.async_client.chat_stream(
|
return self.client.chat.stream_async(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=self.base_message + [ChatMessage(role="user", content=message)],
|
messages=self.base_message + [{"role": "user", "content": message}],
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_message(self):
|
def base_message(self):
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
return [ChatMessage(role="system", content=self.system_message)]
|
return [{"role": "system", "content": self.system_message}]
|
||||||
return []
|
return []
|
||||||
|
|
Loading…
Add table
Reference in a new issue