import asyncio

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class ChatModel:
    def __init__(self):
        # The tokenizer and model must come from the same checkpoint; the
        # original mixed the Instruct tokenizer with the base model.
        model_id = "mistralai/Mistral-7B-Instruct-v0.1"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",  # place weights on the GPU (requires `accelerate`)
            token=True,         # use the token stored by `huggingface-cli login`
        )

    async def generate_response(self, input_text: str) -> str:
        # Wrap the message in the chat template so the Instruct model sees
        # the [INST] ... [/INST] format it was fine-tuned on.
        inputs = self.tokenizer.apply_chat_template(
            [{"role": "user", "content": input_text}],
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.model.device)
        # generate() blocks; run it in a worker thread so this coroutine
        # does not stall the event loop.
        outputs = await asyncio.to_thread(
            self.model.generate,
            **inputs,
            max_new_tokens=100,  # cap new tokens (max_length counted the prompt too)
            do_sample=True,      # temperature has no effect without sampling
            temperature=0.7,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
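

# Minimal usage sketch (not part of the original snippet): assumes a GPU
# with enough memory for the 7B weights and a Hugging Face token already
# stored via `huggingface-cli login`; the prompt is illustrative only.
if __name__ == "__main__":
    async def main() -> None:
        chat = ChatModel()
        reply = await chat.generate_response("Explain what a tokenizer does in one sentence.")
        print(reply)

    asyncio.run(main())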