from transformers import pipeline
from huggingface_hub import InferenceClient
import os

# System prompts: "STRICT" forces bare Yes/No answers, "HELP" allows short free-form replies.
system_messages = {
    "STRICT": "You are a chatbot evaluating github repositories, their python codes and corresponding readme files. Strictly answer the questions with Yes or No.",
    "HELP": "You are a chatbot evaluating github repositories, their python codes and corresponding readme files. Please help me answer the following question. Keep your answers short, and informative.",
}
class LocalLLM:
    """Runs a chat model locally through the transformers text-generation pipeline."""

    def __init__(self, model_name):
        # device_map={"": 0} places the entire model on GPU 0.
        self.pipe = pipeline("text-generation", model=model_name, max_new_tokens=1000, device_map={"": 0})

    def predict(self, response_type, prompt):
        messages = [
            {"role": "system", "content": system_messages[response_type]},
            {"role": "user", "content": prompt},
        ]
        res = self.pipe(messages)
        # The pipeline echoes the whole conversation; keep only the assistant's reply.
        res = res[0]["generated_text"]
        res = [response for response in res if response["role"] == "assistant"][0]["content"]
        return res.strip()
class RemoteLLM:
    """Calls a hosted chat model through the Hugging Face Inference API."""

    def __init__(self, model_name):
        # Read the Hugging Face API token from the environment.
        token = os.getenv("hfToken")
        self.model_name = model_name
        self.client = InferenceClient(api_key=token)

    def predict(self, response_type, prompt):
        message = self.client.chat_completion(
            model=self.model_name, max_tokens=500, stream=False,
            messages=[{"role": "system", "content": system_messages[response_type]},
                      {"role": "user", "content": prompt}])
        return message["choices"][0]["message"]["content"]
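

# Minimal usage sketch, not part of the original Space: the model ids and the
# question below are illustrative placeholders; any chat-capable model works.
if __name__ == "__main__":
    question = "Does the readme explain how to install the package?"

    # Local inference: loads the model onto GPU 0 via the pipeline above.
    local = LocalLLM("HuggingFaceH4/zephyr-7b-beta")  # placeholder model id
    print(local.predict("STRICT", question))

    # Remote inference: requires the hfToken environment variable to hold a valid token.
    remote = RemoteLLM("meta-llama/Llama-3.1-8B-Instruct")  # placeholder model id
    print(remote.predict("HELP", question))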