import threading
import time

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

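# Load the tokenizer and model once at startup; use fp16 on GPU to reduce memory,
# falling back to fp32 on CPU.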
model_id = "lambdaindie/lambda-1v-1B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

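# Minimal dark theme: JetBrains Mono everywhere, near-black background, light text.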
css = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono&display=swap');

* {
    font-family: 'JetBrains Mono', monospace !important;
}

html, body, .gradio-container {
    background-color: #111 !important;
    color: #e0e0e0 !important;
}

textarea, input, button, select {
    background-color: transparent !important;
    color: #e0e0e0 !important;
    border: 1px solid #444 !important;
}
"""

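# Global flag flipped by the "Stop" button to interrupt streaming to the UI.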
stop_signal = False


def stop_stream():
    global stop_signal
    stop_signal = True

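# Stream a response: model.generate runs in a background thread while
# TextIteratorStreamer yields decoded tokens, so the caller can update the UI incrementally.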
def generate_response(message, max_tokens, temperature, top_p):
    global stop_signal
    stop_signal = False

    prompt = f"Question: {message}\nThinking: \nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a separate thread so this generator can consume the streamer.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_text = ""
    for token in streamer:
        if stop_signal:
            # Stop updating the UI; the generation thread finishes on its own.
            break
        full_text += token
        yield full_text.strip()

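# Build the Gradio UI: chat window, message box, send/stop buttons, and sampling controls.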
with gr.Blocks(css=css) as app:
    chatbot = gr.Chatbot(label="λ", elem_id="chatbot")
    msg = gr.Textbox(label="Message", placeholder="Type here...", lines=2)
    send_btn = gr.Button("Send")
    stop_btn = gr.Button("Stop")

    max_tokens = gr.Slider(64, 512, value=128, step=1, label="Max Tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

    state = gr.State([])

    def update_chat(message, chat_history):
        # Append the new user message with an empty reply slot, then clear the textbox.
        chat_history = chat_history + [(message, None)]
        return "", chat_history

    def generate_full(chat_history, max_tokens, temperature, top_p):
        message = chat_history[-1][0]
        visual_history = chat_history[:-1]

        full_response = ""
        for chunk in generate_response(message, max_tokens, temperature, top_p):
            full_response = chunk
            updated = visual_history + [(message, full_response)]
            yield updated, updated

    send_btn.click(update_chat, inputs=[msg, state], outputs=[msg, state]) \
        .then(generate_full, inputs=[state, max_tokens, temperature, top_p], outputs=[chatbot, state])

    stop_btn.click(stop_stream, inputs=[], outputs=[])

app.launch(share=True)