|
import torch |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
from peft import PeftModel |
|
import random |
|
|
|
# Hugging Face Hub ids of the base checkpoint (4-bit, going by its name) and of
# the PEFT adapter that is loaded on top of it.
base_model_name = "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
adapter_model_name = "christinashihan/Specificsituation_empathy_transform"

# bitsandbytes settings: 4-bit NF4 weights with double quantization, float16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# NOTE(review): trust_remote_code=True permits code shipped in the tokenizer repo
# to run at load time -- confirm this checkpoint actually needs it.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# Load the base weights; device_map="auto" leaves device placement to the loader.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the fine-tuned adapter to the base model and switch to inference mode.
model = PeftModel.from_pretrained(base_model, adapter_model_name)
model.eval()
|
|
|
def _as_text(value):
    """Return *value* stripped if it is a string; anything else becomes ``""``."""
    return value.strip() if isinstance(value, str) else ""


def build_prompt(message, history):
    """Render the system prompt, earlier turns and the new message as one string.

    Args:
        message: The user's newest message.
        history: Earlier conversation, oldest first. Each entry is either a
            ``(user, assistant)`` pair or a ``{"role": ..., "content": ...}``
            dict. ``None`` or an empty sequence means there are no earlier
            turns. Entries whose text is ``None`` or not a string are rendered
            as empty text instead of raising.

    Returns:
        The prompt text. It always ends with an open ``<|assistant|>`` marker
        so that generation continues as the assistant.
    """
    parts = [
        "<|system|>You are a psychologically supportive assistant.\n"
        "Structure your reply into three parts: (1) Emotional validation, "
        "(2) Situation understanding, (3) Practical advice.\n"
    ]
    for turn in history or ():
        if isinstance(turn, dict):
            # Role/content style history: one entry per speaker. Any role other
            # than "assistant" is rendered as a user turn, because the prompt
            # format only has these two turn markers.
            role = "assistant" if turn.get("role") == "assistant" else "user"
            parts.append(f"<|{role}|>{_as_text(turn.get('content'))}\n")
        else:
            user, assistant = turn
            parts.append(
                f"<|user|>{_as_text(user)}\n<|assistant|>{_as_text(assistant)}\n"
            )
    parts.append(f"<|user|>{_as_text(message)}\n<|assistant|>")
    return "".join(parts)
|
|
|
def _clean_reply(text):
    """Cut *text* at the first role marker and drop repeated sentences."""
    # Anything after a role marker is the model inventing further turns.
    for marker in ("<|user|>", "<|assistant|>"):
        text = text.split(marker, 1)[0]
    sentences = (chunk.strip() for chunk in text.strip().split(". "))
    # dict.fromkeys keeps the first occurrence of each sentence, in order.
    unique = dict.fromkeys(chunk for chunk in sentences if chunk)
    return ". ".join(unique).strip()


def generate_reply(message, history):
    """Generate the assistant's reply to *message* given the chat *history*.

    Args:
        message: The user's newest message.
        history: Earlier turns, oldest first, in whatever form ``build_prompt``
            accepts. ``None`` is treated as no history.

    Returns:
        The sampled reply as plain text, cut at the first ``<|user|>`` or
        ``<|assistant|>`` marker and with repeated sentences removed.
    """
    max_prompt_tokens = 2048
    turns = list(history or ())
    prompt = build_prompt(message, turns)

    # Tokenizer truncation removes the END of the text by default, which would
    # drop the newest message and the open assistant marker. Shed the oldest
    # turns instead until the prompt fits.
    while turns and len(tokenizer(prompt)["input_ids"]) > max_prompt_tokens:
        del turns[0]
        prompt = build_prompt(message, turns)

    # truncation=True remains as a last resort for a single over-long message.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_prompt_tokens,
    ).to(model.device)

    # Fall back to the EOS id when the tokenizer defines no padding token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # early_stopping is a beam-search option, so it is not passed when sampling.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.5,
        top_p=0.85,
        repetition_penalty=1.3,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # Decode only the newly generated tokens, so neither the prompt nor an
    # "<|assistant|>" typed inside it can be mistaken for the reply.
    prompt_length = inputs["input_ids"].shape[-1]
    generated = outputs[0][prompt_length:]
    return _clean_reply(tokenizer.decode(generated, skip_special_tokens=True))
|
|
|
|
|
# Candidate opening lines; one is chosen at random each time the script starts.
initial_messages = [
    "I'm here with you. Is there anything I can help you with?",
    "You can talk to me about anything that's on your mind.",
    "When you're feeling down, try taking a deep breath. I'm here to listen.",
]
initial_message = random.choice(initial_messages)

# Chat front end. The chatbot is pre-seeded with the greeting as an assistant
# message paired with an empty user message, so that pair is also part of the
# history handed to generate_reply on later turns.
chat_ui = gr.ChatInterface(
    fn=generate_reply,
    chatbot=gr.Chatbot(value=[("", initial_message)]),
    title="🧠 YourCloseFriend Empathy AI",
    description="A warm and emotionally supportive assistant fine-tuned by Christina Shihan.",
    theme="soft",
    examples=[
        "I feel like nothing really matters anymore.",
        "I'm overwhelmed with anxiety.",
        "Can you give me advice on dealing with burnout?",
        "My boss yelled at me in front of everyone.",
        "I feel stuck in my life right now.",
    ],
)

chat_ui.launch()