File size: 3,885 Bytes
9810ea7
6ee09b1
9810ea7
9fbf2d1
9810ea7
08d30fe
9810ea7
 
 
 
 
 
 
96784fc
9810ea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fbf2d1
9810ea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364cb51
9810ea7
 
 
 
 
364cb51
9810ea7
 
 
364cb51
9810ea7
 
 
 
 
 
 
 
 
 
 
9fbf2d1
9810ea7
 
 
364cb51
9810ea7
 
 
 
 
 
 
 
 
 
 
a2a8e37
9810ea7
96784fc
9810ea7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_args():
    """Read the chat demo's command-line options.

    Recognised options:
        --model       Hugging Face model id to load.
        --max_length  Token budget forwarded to generation.
        --do_sample   Enable sampling instead of greedy decoding.

    Returns:
        The ``(namespace, leftovers)`` pair from ``parse_known_args``. Unknown
        arguments (for example ones injected by Jupyter) end up in
        ``leftovers`` instead of aborting the program.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--model", {"type": str, "default": "prithivMLmods/Pocket-Llama-3.2-3B-Instruct"}),
        ("--max_length", {"type": int, "default": 512}),
        ("--do_sample", {"action": "store_true"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_known_args()

def load_model(model_name):
    """Fetch a causal language model and its tokenizer from the Hugging Face Hub.

    The weights are loaded in bfloat16, and ``device_map="auto"`` lets
    Transformers decide device placement.

    Args:
        model_name: Hub repository id (or local path) passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    weight_options = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
    lm = AutoModelForCausalLM.from_pretrained(model_name, **weight_options)
    return lm, tok

def generate_reply(model, tokenizer, prompt, max_length, do_sample):
    """Generate a continuation of ``prompt`` and return only the new text.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Text the model should continue.
        max_length: Maximum number of tokens to *generate*. It is forwarded as
            ``max_new_tokens``: ``max_length`` in ``generate`` also counts the
            prompt, so a long prompt (e.g. a growing chat history) would
            otherwise eat the whole budget and leave no room for a reply.
        do_sample: Sample if True, decode greedily if False.

    Returns:
        The decoded continuation, with the prompt tokens and special tokens
        removed. No streaming: the full string is returned at once.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    output_tokens = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=do_sample,
    )
    # Decoder-only models echo the prompt ids at the front of the output; slice
    # them off in token space rather than string-matching after decoding.
    new_tokens = output_tokens[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def _build_prompt(chat_history, user_message):
    """Flatten the conversation into a plain-text "User:/Bot:" transcript.

    Each earlier turn becomes ``"User: ...\\nBot: ...\\n"``. The prompt ends
    with an open ``"Bot:"`` so the model continues speaking as the bot.
    """
    turns = [
        f"User: {old_user_msg}\nBot: {old_bot_msg or ''}\n"
        for old_user_msg, old_bot_msg in chat_history
    ]
    turns.append(f"User: {user_message}\nBot:")
    return "".join(turns)


def _extract_reply(generated, prompt):
    """Reduce raw model output to just the bot's answer.

    Removes an echoed copy of ``prompt`` if the output starts with it. Cuts the
    text at the first ``"\\nUser:"``, because with this plain transcript format
    the model may keep going and invent the user's next turn. Surrounding
    whitespace is stripped.
    """
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    reply, _, _ = generated.partition("\nUser:")
    return reply.strip()


def main():
    """Parse CLI options, load the model, and serve a Gradio chat UI."""
    args, _ = get_args()
    model, tokenizer = load_model(args.model)

    # The decorator goes on the per-message inference function, not on main():
    # main() does not return while demo.launch() is serving, and @spaces.GPU is
    # meant to wrap the short-lived call that actually needs the GPU.
    @spaces.GPU
    def respond(user_message, chat_history):
        """Answer one user message.

        Args:
            user_message: Text typed into the input box.
            chat_history: Conversation so far as ``(user_message, bot_reply)``
                pairs, as delivered by the Chatbot component (may be empty or None).

        Returns:
            ``(updated_history, "")``. The empty string clears the input box.
        """
        history = list(chat_history or [])
        if not (user_message or "").strip():
            # Nothing to answer; leave the conversation unchanged.
            return history, ""

        prompt = _build_prompt(history, user_message)
        generated = generate_reply(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=args.max_length,
            do_sample=args.do_sample,
        )
        history.append((user_message, _extract_reply(generated, prompt)))
        return history, ""

    def clear_conversation():
        """Reset both the chat display and the input box."""
        return [], ""

    with gr.Blocks() as demo:
        gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")

        # Displays the conversation; its value is also the history fed to respond.
        chatbot = gr.Chatbot(label="Chat")

        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your message here and press Enter"
        )

        clear_button = gr.Button("Clear")

        # Enter submits: inputs are (message, history); outputs refresh the
        # chat display and clear the textbox.
        user_input.submit(respond, [user_input, chatbot], [chatbot, user_input])
        clear_button.click(fn=clear_conversation, outputs=[chatbot, user_input])

    demo.launch()

if __name__ == "__main__":
    main()