# Source: Hugging Face Space by tejasashinde — "Updated app.py" (commit d2ed8d0)
import json
import time
import ast
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from gradio_consilium_roundtable import consilium_roundtable
# === Constants ===
MODEL_NAME = "katanemo/Arch-Router-1.5B"  # routing LLM pulled from the Hugging Face Hub
ARCH_ROUTER = "Arch Router"  # display name of the router participant in the roundtable UI
WAIT_DEPARTMENT = 5  # seconds to simulate a department working on a request
WAIT_SYSTEM = 5  # seconds before the router announces it is ready for the next request
# === Load model/tokenizer ===
# device_map="auto" / torch_dtype="auto" let accelerate pick device placement and
# precision. NOTE(review): trust_remote_code=True executes code shipped with the
# model repo — confirm this is acceptable for the deployment environment.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# === Route Definitions ===
# Route list serialized into the prompt's <routes> section; the model must answer
# with one of these "name" values (or "other" when nothing matches).
route_config = [
    {"name": "code_generation", "description": "Generating code based on prompts"},
    {"name": "bug_fixing", "description": "Fixing errors or bugs in code"},
    {"name": "performance_optimization", "description": "Improving code performance"},
    {"name": "api_help", "description": "Assisting with APIs and libraries"},
    {"name": "programming", "description": "General programming Q&A"},
    {"name": "legal", "description": "Legal"},
    {"name": "healthcare", "description": "Healthcare and medical related"},
]
# Maps a route name to its (emoji, human-readable department label) for the UI.
# Includes the extra "other" key used as the fallback for unknown routes.
departments = {
    "code_generation": ("πŸ’»", "Code Generation"),
    "bug_fixing": ("🐞", "Bug Fixing"),
    "performance_optimization": ("⚑", "Performance Optimization"),
    "api_help": ("πŸ”Œ", "API Help"),
    "programming": ("πŸ“š", "Programming"),
    "legal": ("βš–οΈ", "Legal"),
    "healthcare": ("🩺", "Healthcare"),
    "other": ("❓", "Other / General Inquiry"),
}
# === Prompt Formatting ===
TASK_INSTRUCTION = """
You are a helpful assistant designed to find the best suited route. You are provided with route description within <routes></routes> XML tags:
<routes>
{routes}
</routes>
<conversation>
{conversation}
</conversation>
"""
FORMAT_PROMPT = """
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON format:
{"route": "route_name"}
"""


def format_prompt(route_config, conversation):
    """Assemble the full routing prompt for the model.

    The route definitions and the conversation are JSON-encoded into the
    task template, then the fixed output-format instructions are appended
    verbatim.
    """
    task_section = TASK_INSTRUCTION.format(
        routes=json.dumps(route_config),
        conversation=json.dumps(conversation),
    )
    return task_section + FORMAT_PROMPT
def parse_route(response_text):
    """Extract the route name from the model's reply.

    The model is instructed to answer with ``{"route": "<name>"}``. This
    grabs the outermost brace-delimited span and parses it as a Python
    literal (``ast.literal_eval`` also accepts plain JSON objects). Any
    malformed or unexpected output — no braces, invalid literal, non-dict
    payload, or a non-string route value — falls back to "other" so the
    caller never crashes on bad model output.
    """
    start = response_text.find("{")
    end = response_text.rfind("}") + 1
    if start == -1 or end <= start:
        # No brace-delimited payload in the response at all.
        return "other"
    try:
        parsed = ast.literal_eval(response_text[start:end])
    except (ValueError, SyntaxError, TypeError) as e:  # literal_eval failure modes
        print("Parsing failed:", e)
        return "other"
    # Guard against a non-dict literal or a non-string route value.
    if isinstance(parsed, dict):
        route = parsed.get("route", "other")
        if isinstance(route, str):
            return route
    return "other"
def init_state():
    """Build the initial payload for the consilium_roundtable widget.

    The router and every department sit at the table; only the router's
    speech bubble is shown until a request is being processed.
    """
    avatars = {
        ARCH_ROUTER: "https://avatars.githubusercontent.com/u/112724757?s=200&v=4",
        "code_generation": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f4bb.png",
        "bug_fixing": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f41e.png",
        "performance_optimization": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/26a1.png",
        "api_help": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f50c.png",
        "programming": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1f4da.png",
        "legal": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/2696.png",
        "healthcare": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/1fa7a.png",
        "other": "https://raw.githubusercontent.com/twitter/twemoji/master/assets/72x72/2753.png",
    }
    state = {
        "messages": [],
        "participants": [ARCH_ROUTER, *departments],
        "currentSpeaker": None,
        "thinking": [],
        "showBubbles": [ARCH_ROUTER],
        "avatarImages": avatars,
    }
    return state
def route_and_visualize(user_input_text, rt_state, chat_history):
    """Route one user message and stream staged UI updates to Gradio.

    Generator yielding (roundtable_state, chat_history, rt_state,
    textbox_update) tuples — one per visual stage: route detection,
    dispatch, simulated department work, completion, bubble reset, and
    ready-again. The textbox stays disabled (interactive=False) until the
    final yield re-enables it. Note: rt_state is mutated in place and the
    same dict is yielded each time.
    """
    chat_history = chat_history or []
    rt_state = rt_state or {"messages": []}
    chat_history.append(("User", user_input_text))
    # Step 1: Disable input and show route detection
    rt_state["messages"] = [{"speaker": ARCH_ROUTER, "text": "πŸ”Ž Identifying route, please wait..."}]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 2: Prepare prompt and get route
    # Only the latest user turn is routed; earlier turns are not included.
    conversation = [{"role": "user", "content": user_input_text}]
    route_prompt = format_prompt(route_config, conversation)
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": route_prompt}],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, max_new_tokens=512)
    # Decode only the newly generated tokens (everything after the prompt).
    prompt_len = input_ids.shape[1]
    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()
    print("MODEL RAW:", response)
    route = parse_route(response)
    # Unknown routes fall back to the "other" department entry.
    emoji, dept_name = departments.get(route, departments["other"])
    # Step 3: Show route identified
    rt_state["messages"][0] = {
        "speaker": ARCH_ROUTER,
        "text": f"πŸ“Œ Identified department: **{dept_name}**. Forwarding task..."
    }
    chat_history.append((ARCH_ROUTER, f"πŸ“Œ Identified department: {dept_name}. Forwarding task..."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 4: Show processing
    # NOTE(review): hard-coded 3s pause here while the message below quotes
    # WAIT_DEPARTMENT — confirm whether this should also use a constant.
    time.sleep(3)
    rt_state["messages"].extend([
        {"speaker": route, "text": f"{emoji} {dept_name} simulation is processing your request in {WAIT_DEPARTMENT} secs..."},
        {"speaker": ARCH_ROUTER, "text": "⏳ Waiting for department to respond..."}
    ])
    rt_state["showBubbles"] = [ARCH_ROUTER, route]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 5: Simulate delay and complete
    # Rewrites the two messages appended in Step 4 (department + router).
    time.sleep(WAIT_DEPARTMENT)
    rt_state["messages"][-2]["text"] = f"βœ… {dept_name} completed the task."
    rt_state["messages"][-1]["text"] = f"βœ… {dept_name} department has completed the task."
    chat_history.append((ARCH_ROUTER, f"βœ… {dept_name} department completed the task."))
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 6: Reset visible bubbles
    rt_state["showBubbles"] = [ARCH_ROUTER]
    yield rt_state, chat_history, rt_state, gr.update(interactive=False)
    # Step 7: System ready — re-enable the textbox for the next request.
    time.sleep(WAIT_SYSTEM)
    rt_state["messages"].append({"speaker": ARCH_ROUTER, "text": "Arch Router is ready to discuss."})
    yield rt_state, chat_history, rt_state, gr.update(interactive=True)
# === Gradio UI ===
with gr.Blocks(title="Arch Router Simulation: Smart Department Dispatcher", theme=gr.themes.Ocean()) as demo:
    # Header / intro copy.
    gr.Markdown(
        """
## 🧭 Arch Router Simulation: Smart Department Dispatcher
**This is a demo simulation of <a href="https://huggingface.co/katanemo/Arch-Router-1.5B" target="_blank">katanemo/Arch-Router-1.5B</a>.**
**Kindly refer official documentation for more details**
* See how Arch Router identifies the best route **(or Domain – the high-level category)** based on user prompt and take desired **Action (specific type of operation user wants to perform)** by forwarding it to respective department.
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            # Server-side session state: roundtable payload and chat tuples.
            rt_state = gr.State(init_state())
            chat_state = gr.State([])
            roundtable = consilium_roundtable(value=init_state())
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="Chat History", max_height=300)
            textbox = gr.Textbox(placeholder="Describe your issue...", label="Ask Arch Router")
            submit_btn = gr.Button("Submit")
    # Sample prompts covering the different routes.
    example_inputs = [
        "How do I optimize this loop in Python?",
        "Generate a function to sort an array in python",
        "Help me anonymize patient health records before storing them",
        "I'm getting a TypeError in following code",
        "Do I need to include attribution for MIT-licensed software?",
        "How do I connect to external API from this code?"
    ]
    # Trigger submission via Enter or Button; concurrency_limit=1 keeps the
    # streaming generator from running for two requests at once.
    for trigger in (textbox.submit, submit_btn.click):
        trigger(
            route_and_visualize,
            inputs=[textbox, rt_state, chat_state],
            outputs=[roundtable, chatbot, rt_state, textbox],
            concurrency_limit=1
        )
    # Example block
    gr.Examples(
        examples=example_inputs,
        inputs=textbox,
        label="Try one of these examples"
    )
if __name__ == "__main__":
    demo.launch()