import json
import time
import ast

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from gradio_consilium_roundtable import consilium_roundtable

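# Router checkpoint and simulated wait times (in seconds) used by the demo.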
MODEL_NAME = "katanemo/Arch-Router-1.5B"
ARCH_ROUTER = "Arch Router"
WAIT_DEPARTMENT = 5
WAIT_SYSTEM = 5

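# Load the router model and tokenizer once at startup; device_map="auto" places
# the weights on a GPU when one is available and falls back to CPU otherwise.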
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

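# Candidate routes the router can pick from; the model matches the user's latest
# intent against these short descriptions and answers with one route name.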
route_config = [
    {"name": "code_generation", "description": "Generating code based on prompts"},
    {"name": "bug_fixing", "description": "Fixing errors or bugs in code"},
    {"name": "performance_optimization", "description": "Improving code performance"},
    {"name": "api_help", "description": "Assisting with APIs and libraries"},
    {"name": "programming", "description": "General programming Q&A"},
    {"name": "legal", "description": "Legal"},
    {"name": "healthcare", "description": "Healthcare and medical related"},
]

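# Display metadata (emoji, human-readable label) for each route; "other" is the
# fallback department when no route matches.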
departments = {
    "code_generation": ("💻", "Code Generation"),
    "bug_fixing": ("🐞", "Bug Fixing"),
    "performance_optimization": ("⚡", "Performance Optimization"),
    "api_help": ("🔌", "API Help"),
    "programming": ("📚", "Programming"),
    "legal": ("⚖️", "Legal"),
    "healthcare": ("🩺", "Healthcare"),
    "other": ("❓", "Other / General Inquiry"),
}

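# Routing prompt templates, kept verbatim (including their original phrasing) so
# they match the instruction format expected by the Arch-Router model.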
TASK_INSTRUCTION = """
You are a helpful assistant designed to find the best suited route. You are provided with route description within <routes></routes> XML tags:
<routes>
{routes}
</routes>

<conversation>
{conversation}
</conversation>
"""

FORMAT_PROMPT = """
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.

Based on your analysis, provide your response in the following JSON format:
{"route": "route_name"}
"""

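# Build the full routing prompt: the task instruction (with the route config and
# conversation serialized as JSON) followed by the fixed output-format instruction.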
def format_prompt(route_config, conversation):
    return TASK_INSTRUCTION.format(
        routes=json.dumps(route_config), conversation=json.dumps(conversation)
    ) + FORMAT_PROMPT

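# Extract the {"route": ...} object from the model's reply; fall back to "other"
# if the output cannot be parsed.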
def parse_route(response_text):
    try:
        start = response_text.find("{")
        end = response_text.rfind("}") + 1
        return ast.literal_eval(response_text[start:end]).get("route", "other")
    except Exception as e:
        print("Parsing failed:", e)
        return "other"

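# Initial state for the consilium_roundtable component: the router plus every
# department sits at the table, and only the router shows a speech bubble.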
def init_state():
    avatar_emojis = {
        ARCH_ROUTER: "https://avatars.githubusercontent.com/u/112724757?s=200&v=4",
        "code_generation": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f4bb.png",
        "bug_fixing": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f41e.png",
        "performance_optimization": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/26a1.png",
        "api_help": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f50c.png",
        "programming": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f4da.png",
        "legal": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/2696.png",
        "healthcare": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1fa7a.png",
        "other": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/2753.png",
    }
    return {
        "messages": [],
        "participants": [ARCH_ROUTER] + list(departments.keys()),
        "currentSpeaker": None,
        "thinking": [],
        "showBubbles": [ARCH_ROUTER],
        "avatarImages": avatar_emojis,
    }

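# Main event handler. It is a generator: each yield pushes an intermediate UI
# update (roundtable state, chat history, stored state, textbox interactivity),
# so the interface animates the routing steps while the user input stays locked.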
def route_and_visualize(user_input_text, rt_state, chat_history):
    chat_history = chat_history or []
    rt_state = rt_state or {"messages": []}
    chat_history.append(("User", user_input_text))

    # Step 1: show the router "thinking" and disable the textbox.
    rt_state["messages"] = [{"speaker": ARCH_ROUTER, "text": "🔍 Identifying route, please wait..."}]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 2: ask the router model which route fits the latest user message.
    conversation = [{"role": "user", "content": user_input_text}]
    route_prompt = format_prompt(route_config, conversation)
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": route_prompt}],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(input_ids=input_ids, max_new_tokens=512)

    # Decode only the newly generated tokens, then parse the route name.
    prompt_len = input_ids.shape[1]
    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
    print("MODEL RAW:", response)
    route = parse_route(response)

    emoji, dept_name = departments.get(route, departments["other"])

    # Step 3: announce the identified department.
    rt_state["messages"][0] = {
        "speaker": ARCH_ROUTER,
        "text": f"📨 Identified department: **{dept_name}**. Forwarding task..."
    }
    chat_history.append((ARCH_ROUTER, f"📨 Identified department: {dept_name}. Forwarding task..."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 4: simulate the department picking up the task.
    time.sleep(3)
    rt_state["messages"].extend([
        {"speaker": route, "text": f"{emoji} {dept_name} simulation is processing your request in {WAIT_DEPARTMENT} secs..."},
        {"speaker": ARCH_ROUTER, "text": "⏳ Waiting for department to respond..."}
    ])
    rt_state["showBubbles"] = [ARCH_ROUTER, route]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 5: mark the simulated task as completed.
    time.sleep(WAIT_DEPARTMENT)
    rt_state["messages"][-2]["text"] = f"✅ {dept_name} completed the task."
    rt_state["messages"][-1]["text"] = f"✅ {dept_name} department has completed the task."
    chat_history.append((ARCH_ROUTER, f"✅ {dept_name} department completed the task."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 6: hide the department bubble again.
    rt_state["showBubbles"] = [ARCH_ROUTER]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)

    # Step 7: after a short pause, re-enable the textbox for the next question.
    time.sleep(WAIT_SYSTEM)
    rt_state["messages"].append({"speaker": ARCH_ROUTER, "text": "Arch Router is ready to discuss."})
    yield rt_state, chat_history, rt_state, gr.update(interactive=True)

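# Gradio UI: the roundtable visualization on the left, chat history and input on the right.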
with gr.Blocks(title="Arch Router Simulation: Smart Department Dispatcher", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
        ## 🧠 Arch Router Simulation: Smart Department Dispatcher
        **This is a demo simulation of <a href="https://huggingface.co/katanemo/Arch-Router-1.5B" target="_blank">katanemo/Arch-Router-1.5B</a>.**
        **Kindly refer to the official documentation for more details.**

        * See how Arch Router identifies the best route **(or Domain, the high-level category)** based on the user prompt and takes the desired **Action (the specific operation the user wants to perform)** by forwarding the task to the matching department.
        """
    )

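    # Layout: roundtable visualization (left) and chat panel with input (right).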
    with gr.Row():
        with gr.Column(scale=2):
            rt_state = gr.State(init_state())
            chat_state = gr.State([])
            roundtable = consilium_roundtable(value=init_state())

        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History", max_height=300)
            textbox = gr.Textbox(placeholder="Describe your issue...", label="Ask Arch Router")
            submit_btn = gr.Button("Submit")

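    # Sample prompts covering the different routes.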
    example_inputs = [
        "How do I optimize this loop in Python?",
        "Generate a function to sort an array in Python",
        "Help me anonymize patient health records before storing them",
        "I'm getting a TypeError in the following code",
        "Do I need to include attribution for MIT-licensed software?",
        "How do I connect to an external API from this code?"
    ]

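    # Submitting via Enter or the button runs the routing generator; each yield
    # streams an intermediate update to the roundtable, chat, state, and textbox.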
    for trigger in (textbox.submit, submit_btn.click):
        trigger(
            route_and_visualize,
            inputs=[textbox, rt_state, chat_state],
            outputs=[roundtable, chatbot, rt_state, textbox],
            concurrency_limit=1
        )

    gr.Examples(
        examples=example_inputs,
        inputs=textbox,
        label="Try one of these examples"
    )


if __name__ == "__main__":
    demo.launch()