import argparse import spaces import torch import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="prithivMLmods/Pocket-Llama-3.2-3B-Instruct") parser.add_argument("--max_length", type=int, default=512) parser.add_argument("--do_sample", action="store_true") # This allows ignoring unrecognized arguments, e.g., from Jupyter return parser.parse_known_args() def load_model(model_name): """Load model and tokenizer from Hugging Face.""" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="auto" ) return model, tokenizer def generate_reply(model, tokenizer, prompt, max_length, do_sample): """Generate text from the model given a prompt.""" inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # We’re returning just the final string; no streaming here output_tokens = model.generate( **inputs, max_length=max_length, do_sample=do_sample ) return tokenizer.decode(output_tokens[0], skip_special_tokens=True) def main(): args, _ = get_args() model, tokenizer = load_model(args.model) @spaces.GPU def respond(user_message, chat_history): """ Gradio expects a function that takes the last user message and the conversation history, then returns the updated history. chat_history is a list of (user_message, bot_reply) pairs. """ # Build a single text prompt from the conversation so far prompt = "" for (old_user_msg, old_bot_msg) in chat_history: prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n" # Add the new user query prompt += f"User: {user_message}\nBot:" # Generate the response bot_message = generate_reply( model=model, tokenizer=tokenizer, prompt=prompt, max_length=args.max_length, do_sample=args.do_sample ) # In many cases, the model output will contain the entire prompt again, # so we can strip that off or just let it show. If you see repeated # text, you can try to remove the prompt prefix from bot_message. if bot_message.startswith(prompt): bot_message = bot_message[len(prompt):] # Append the new user-message and bot-response to the history chat_history.append((user_message, bot_message)) return chat_history, chat_history # Define the Gradio interface with gr.Blocks() as demo: gr.Markdown("

Chat with Your Model

") # A Chatbot component that will display the conversation chatbot = gr.Chatbot(label="Chat") # A text box for user input user_input = gr.Textbox( show_label=False, placeholder="Type your message here and press Enter" ) # A button to clear the conversation clear_button = gr.Button("Clear") # When the user hits Enter in the textbox, call 'respond' # - Inputs: [user_input, chatbot] (the last user message and history) # - Outputs: [chatbot, chatbot] (updates the chatbot display and history) user_input.submit(respond, [user_input, chatbot], [chatbot, chatbot]) # Define a helper function for clearing def clear_conversation(): return [], [] # When "Clear" is clicked, reset the conversation clear_button.click(fn=clear_conversation, outputs=[chatbot, chatbot]) # Launch the Gradio app demo.launch() if __name__ == "__main__": main()