52 lines
1.5 KiB
Python
52 lines
1.5 KiB
Python
from typing import Any, Coroutine
|
|
|
|
import mistralai
|
|
from mistralai.utils import eventstreaming
|
|
|
|
|
|
class AIBot:
    """Small convenience wrapper around a ``mistralai.Mistral`` chat client.

    Stores the model name, an optional token limit and an optional system
    message, and applies them to every chat request it issues.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "open-mistral-7b",
        max_tokens: int | None = None,
        system_message: str | None = None,
    ) -> None:
        """Create the underlying client and remember the request settings.

        Args:
            api_key: Key handed to ``mistralai.Mistral``.
            model: Model identifier sent with every request.
            max_tokens: Forwarded unchanged as ``max_tokens`` on every
                request (``None`` included).
            system_message: Optional system prompt placed before the user
                message; a falsy value (``None`` or ``""``) means no system
                message is sent.
        """
        self.client = mistralai.Mistral(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        self.system_message = system_message

    def get_responses(self, message: str) -> mistralai.ChatCompletionResponse:
        """Send *message* as a single user turn and return the raw response.

        Args:
            message: Text of the user message.

        Returns:
            Whatever ``client.chat.complete`` returns for this request.
        """
        return self.client.chat.complete(
            model=self.model,
            messages=self.get_message(message),
            max_tokens=self.max_tokens,
        )

    def answer(self, message: str) -> str | None:
        """Return the content of the first completion choice for *message*.

        Args:
            message: Text of the user message.

        Returns:
            The first choice's message content converted with ``str``, or
            ``None`` when the response carries no choices or the first
            choice carries no content.
        """
        choices = self.get_responses(message).choices
        if not choices:
            return None

        content = choices[0].message.content
        # str(None) would hand the caller the literal text "None"; report
        # missing content the same way as missing choices instead.
        if content is None:
            return None
        return str(content)

    def get_response_stream(
        self, message: str
    ) -> Coroutine[
        Any, Any, eventstreaming.EventStreamAsync[mistralai.CompletionEvent]
    ]:
        """Start a streaming chat request for *message*.

        This method is not ``async`` itself: it returns the result of
        ``client.chat.stream_async`` without awaiting it, so the caller is
        expected to await the returned object to obtain the event stream.

        Args:
            message: Text of the user message.

        Returns:
            The un-awaited result of ``client.chat.stream_async``.
        """
        return self.client.chat.stream_async(
            model=self.model,
            messages=self.get_message(message),
            max_tokens=self.max_tokens,
        )

    def get_message(self, content: str) -> list[Any]:
        """Build the message list for a single-turn conversation.

        Args:
            content: Text of the user message.

        Returns:
            ``[SystemMessage, UserMessage]`` when a non-empty system message
            is configured, otherwise ``[UserMessage]`` only.
        """
        if self.system_message:
            return [
                mistralai.SystemMessage(content=self.system_message),
                mistralai.UserMessage(content=content),
            ]

        return [mistralai.UserMessage(content=content)]
|