prithivMLmods committed on
Commit 9fbf2d1 · verified · 1 Parent(s): 0e7f75b

Update app.py

Files changed (1)
  1. app.py +127 -172
app.py CHANGED
@@ -1,182 +1,137 @@
- import gradio as gr
  import torch
- import spaces
- from transformers import AutoModel, AutoProcessor
-
- from gender_classification import gender_classification
- from emotion_classification import emotion_classification
- from dog_breed import dog_breed_classification
- from deepfake_quality import deepfake_classification
- from gym_workout_classification import workout_classification
- from augmented_waste_classifier import waste_classification
- from age_classification import age_classification
- from mnist_digits import classify_digit
- from fashion_mnist_cloth import fashion_mnist_classification
- from indian_western_food_classify import food_classification
- from bird_species import bird_classification
- from alphabet_sign_language_detection import sign_language_classification
- from rice_leaf_disease import classify_leaf_disease
- from traffic_density import traffic_density_classification
- from clip_art import clipart_classification
- from multisource_121 import multisource_classification
- from painting_126 import painting_classification
- from sketch_126 import sketch_classification  # New import
-
- # Main classification function for multi-model classification.
- def classify(image, model_name):
-     if model_name == "gender":
-         return gender_classification(image)
-     elif model_name == "emotion":
-         return emotion_classification(image)
-     elif model_name == "dog breed":
-         return dog_breed_classification(image)
-     elif model_name == "deepfake":
-         return deepfake_classification(image)
-     elif model_name == "gym workout":
-         return workout_classification(image)
-     elif model_name == "waste":
-         return waste_classification(image)
-     elif model_name == "age":
-         return age_classification(image)
-     elif model_name == "mnist":
-         return classify_digit(image)
-     elif model_name == "fashion_mnist":
-         return fashion_mnist_classification(image)
-     elif model_name == "food":
-         return food_classification(image)
-     elif model_name == "bird":
-         return bird_classification(image)
-     elif model_name == "leaf disease":
-         return classify_leaf_disease(image)
-     elif model_name == "sign language":
-         return sign_language_classification(image)
-     elif model_name == "traffic density":
-         return traffic_density_classification(image)
-     elif model_name == "clip art":
-         return clipart_classification(image)
-     elif model_name == "multisource":
-         return multisource_classification(image)
-     elif model_name == "painting":
-         return painting_classification(image)
-     elif model_name == "sketch":  # New option
-         return sketch_classification(image)
-     else:
-         return {"Error": "No model selected"}
-
- # Function to update the selected model and button styles.
- def select_model(model_name):
-     model_variants = {
-         "gender": "secondary", "emotion": "secondary", "dog breed": "secondary", "deepfake": "secondary",
-         "gym workout": "secondary", "waste": "secondary", "age": "secondary", "mnist": "secondary",
-         "fashion_mnist": "secondary", "food": "secondary", "bird": "secondary", "leaf disease": "secondary",
-         "sign language": "secondary", "traffic density": "secondary", "clip art": "secondary",
-         "multisource": "secondary", "painting": "secondary", "sketch": "secondary"  # New model variant
-     }
-     model_variants[model_name] = "primary"
-     return (model_name, *(gr.update(variant=model_variants[key]) for key in model_variants))

- # Zero-Shot Classification Setup (SigLIP models)
- sg1_ckpt = "google/siglip-so400m-patch14-384"
- siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="cpu").eval()
- siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)

- sg2_ckpt = "google/siglip2-so400m-patch14-384"
- siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="cpu").eval()
- siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)

- @spaces.GPU
- def postprocess_siglip(sg1_probs, sg2_probs, labels):
-     sg1_output = {labels[i]: sg1_probs[0][i].item() for i in range(len(labels))}
-     sg2_output = {labels[i]: sg2_probs[0][i].item() for i in range(len(labels))}
-     return sg1_output, sg2_output

- def siglip_detector(image, texts):
-     sg1_inputs = siglip1_processor(
-         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
-     ).to("cpu")
-     sg2_inputs = siglip2_processor(
-         text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
-     ).to("cpu")
      with torch.no_grad():
-         sg1_outputs = siglip1_model(**sg1_inputs)
-         sg2_outputs = siglip2_model(**sg2_inputs)
-     sg1_logits_per_image = sg1_outputs.logits_per_image
-     sg2_logits_per_image = sg2_outputs.logits_per_image
-     sg1_probs = torch.sigmoid(sg1_logits_per_image)
-     sg2_probs = torch.sigmoid(sg2_logits_per_image)
-     return sg1_probs, sg2_probs
-
- def infer(image, candidate_labels):
-     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
-     sg1_probs, sg2_probs = siglip_detector(image, candidate_labels)
-     return postprocess_siglip(sg1_probs, sg2_probs, labels=candidate_labels)
-
- # Build the Gradio Interface with two tabs.
- with gr.Blocks(theme="YTheme/Minecraft") as demo:
-     gr.Markdown("# Multi-Domain & Zero-Shot Image Classification")

-     with gr.Tabs():
-         # Tab 1: Multi-Model Classification
-         with gr.Tab("Multi-Domain Classification"):
-             with gr.Sidebar():
-                 gr.Markdown("# Choose Domain")
-                 with gr.Row():
-                     age_btn = gr.Button("Age Classification", variant="primary")
-                     gender_btn = gr.Button("Gender Classification", variant="secondary")
-                     emotion_btn = gr.Button("Emotion Classification", variant="secondary")
-                     gym_workout_btn = gr.Button("Gym Workout Classification", variant="secondary")
-                     dog_breed_btn = gr.Button("Dog Breed Classification", variant="secondary")
-                     bird_btn = gr.Button("Bird Species Classification", variant="secondary")
-                     waste_btn = gr.Button("Waste Classification", variant="secondary")
-                     deepfake_btn = gr.Button("Deepfake Quality Test", variant="secondary")
-                     traffic_density_btn = gr.Button("Traffic Density", variant="secondary")
-                     sign_language_btn = gr.Button("Alphabet Sign Language", variant="secondary")
-                     clip_art_btn = gr.Button("Clip Art 126", variant="secondary")
-                     mnist_btn = gr.Button("Digit Classify (0-9)", variant="secondary")
-                     fashion_mnist_btn = gr.Button("Fashion MNIST (only cloth)", variant="secondary")
-                     food_btn = gr.Button("Indian/Western Food Type", variant="secondary")
-                     leaf_disease_btn = gr.Button("Rice Leaf Disease", variant="secondary")
-                     multisource_btn = gr.Button("Multi Source 121", variant="secondary")
-                     painting_btn = gr.Button("Painting 126", variant="secondary")
-                     sketch_btn = gr.Button("Sketch 126", variant="secondary")
-
-                 selected_model = gr.State("age")
-                 gr.Markdown("### Current Model:")
-                 model_display = gr.Textbox(value="age", interactive=False)
-                 selected_model.change(lambda m: m, selected_model, model_display)

-                 buttons = [
-                     gender_btn, emotion_btn, dog_breed_btn, deepfake_btn, gym_workout_btn, waste_btn,
-                     age_btn, mnist_btn, fashion_mnist_btn, food_btn, bird_btn, leaf_disease_btn,
-                     sign_language_btn, traffic_density_btn, clip_art_btn, multisource_btn, painting_btn, sketch_btn  # Include new button
-                 ]
-                 model_names = [
-                     "gender", "emotion", "dog breed", "deepfake", "gym workout", "waste",
-                     "age", "mnist", "fashion_mnist", "food", "bird", "leaf disease",
-                     "sign language", "traffic density", "clip art", "multisource", "painting", "sketch"  # New model name
-                 ]

-                 for btn, name in zip(buttons, model_names):
-                     btn.click(fn=lambda n=name: select_model(n), inputs=[], outputs=[selected_model] + buttons)
-
-             with gr.Row():
-                 with gr.Column():
-                     image_input = gr.Image(type="numpy", label="Upload Image")
-                     analyze_btn = gr.Button("Classify / Predict")
-                     output_label = gr.Label(label="Prediction Scores")
-                     analyze_btn.click(fn=classify, inputs=[image_input, selected_model], outputs=output_label)
-
-         # Tab 2: Zero-Shot Classification (SigLIP)
-         with gr.Tab("Zero-Shot Classification"):
-             gr.Markdown("## Compare SigLIP 1 and SigLIP 2 on Zero-Shot Classification")
-             with gr.Row():
-                 with gr.Column():
-                     zs_image_input = gr.Image(type="pil", label="Upload Image")
-                     zs_text_input = gr.Textbox(label="Input a list of labels (comma separated)")
-                     zs_run_button = gr.Button("Run")
-                 with gr.Column():
-                     siglip1_output = gr.Label(label="SigLIP 1 Output", num_top_classes=3)
-                     siglip2_output = gr.Label(label="SigLIP 2 Output", num_top_classes=3)
-             zs_run_button.click(fn=infer, inputs=[zs_image_input, zs_text_input], outputs=[siglip1_output, siglip2_output])

- demo.launch()
  import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+ from snac import SNAC

+ def redistribute_codes(row):
+     """
+     Convert a sequence of token codes into an audio waveform using SNAC.
+     The code assumes each 7 tokens represent one group of instructions.
+     """
+     row_length = row.size(0)
+     new_length = (row_length // 7) * 7
+     trimmed_row = row[:new_length]
+     code_list = [t - 128266 for t in trimmed_row]
+
+     layer_1, layer_2, layer_3 = [], [], []
+
+     for i in range((len(code_list) + 1) // 7):
+         layer_1.append(code_list[7 * i][None])
+         layer_2.append(code_list[7 * i + 1][None] - 4096)
+         layer_3.append(code_list[7 * i + 2][None] - (2 * 4096))
+         layer_3.append(code_list[7 * i + 3][None] - (3 * 4096))
+         layer_2.append(code_list[7 * i + 4][None] - (4 * 4096))
+         layer_3.append(code_list[7 * i + 5][None] - (5 * 4096))
+         layer_3.append(code_list[7 * i + 6][None] - (6 * 4096))
+
+     with torch.no_grad():
+         codes = [
+             torch.concat(layer_1),
+             torch.concat(layer_2),
+             torch.concat(layer_3)
+         ]
+         for i in range(len(codes)):
+             codes[i][codes[i] < 0] = 0
+             codes[i] = codes[i][None]
+
+         audio_hat = snac_model.decode(codes)
+     return audio_hat.cpu()[0, 0]

+ # Load the SNAC model (shared by all)
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to("cuda")

+ # Load all the single-speaker language models
+ models = {
+     "Luna": {
+         "tokenizer": AutoTokenizer.from_pretrained('prithivMLmods/Llama-3B-Mono-Luna'),
+         "model": AutoModelForCausalLM.from_pretrained('prithivMLmods/Llama-3B-Mono-Luna', torch_dtype=torch.bfloat16).cuda()
+     },
+     "Ceylia": {
+         "tokenizer": AutoTokenizer.from_pretrained('prithivMLmods/Llama-3B-Mono-Ceylia'),
+         "model": AutoModelForCausalLM.from_pretrained('prithivMLmods/Llama-3B-Mono-Ceylia', torch_dtype=torch.bfloat16).cuda()
+     },
+     "Cooper": {
+         "tokenizer": AutoTokenizer.from_pretrained('prithivMLmods/Llama-3B-Mono-Cooper'),
+         "model": AutoModelForCausalLM.from_pretrained('prithivMLmods/Llama-3B-Mono-Cooper', torch_dtype=torch.bfloat16).cuda()
+     },
+     "Jim": {
+         "tokenizer": AutoTokenizer.from_pretrained('prithivMLmods/Llama-3B-Mono-Jim'),
+         "model": AutoModelForCausalLM.from_pretrained('prithivMLmods/Llama-3B-Mono-Jim', torch_dtype=torch.bfloat16).cuda()
+     },
+ }

+ def generate_audio(text, temperature, top_p, max_new_tokens, model_name):
+     """
+     Given input text and model parameters, generate speech audio using the chosen model.
+     """
+     # Retrieve the chosen tokenizer and model
+     chosen = models[model_name]
+     tokenizer = chosen["tokenizer"]
+     model = chosen["model"]
+
+     prompt = f'<custom_token_3><|begin_of_text|>{text}<|eot_id|><custom_token_4><custom_token_5><custom_token_1>'
+     input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').to('cuda')
+
      with torch.no_grad():
+         generated_ids = model.generate(
+             **input_ids,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=temperature,
+             top_p=top_p,
+             repetition_penalty=1.1,
+             num_return_sequences=1,
+             eos_token_id=128258,
+         )

+     row = generated_ids[0, input_ids['input_ids'].shape[1]:]
+     y_tensor = redistribute_codes(row)
+     y_np = y_tensor.detach().cpu().numpy()
+     return (24000, y_np)

+ # Example texts with emotion tokens
+ example_texts = [
+     ["Hi, my name is Alex. <laugh> It's a wonderful day! <chuckle> I love coding."],
+     ["I woke up feeling sleepy. <yawn> I need coffee! <sniffle> But I'm ready to work."],
+     ["Oh no, I forgot my keys! <groan> <uhm> Maybe I'll try again later. <sigh>"],
+     ["This is amazing! <gasp> Really, it's fantastic. <giggles>"]
+ ]

+ # Gradio Interface
+ with gr.Blocks() as demo:
+     # Sidebar for model selection
+     with gr.Sidebar():
+         gr.Markdown("# Choose Model")
+         model_choice = gr.Dropdown(choices=list(models.keys()), value="Luna", label="Model")
+
+     gr.Markdown("# Single Speaker Audio Generation")
+     gr.Markdown("Generate speech audio using one of the single-speaker models. Use the examples below to see how emotion tokens like `<laugh>`, `<chuckle>`, `<sigh>`, etc. can be incorporated.")
+
+     with gr.Row():
+         text_input = gr.Textbox(lines=4, label="Input Text")
+
+     # Examples with emotion tokens
+     gr.Examples(
+         examples=example_texts,
+         inputs=text_input,
+         label="Emotion Examples",
+         cache_examples=False
+     )
+
+     with gr.Row():
+         temp_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=0.9, label="Temperature")
+         top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.8, label="Top-p")
+         tokens_slider = gr.Slider(minimum=100, maximum=3500, step=50, value=1200, label="Max New Tokens")
+
+     output_audio = gr.Audio(type="numpy", label="Generated Audio")
+     generate_button = gr.Button("Generate Audio")
+
+     # Pass the selected model name along with other parameters
+     generate_button.click(
+         fn=generate_audio,
+         inputs=[text_input, temp_slider, top_p_slider, tokens_slider, model_choice],
+         outputs=output_audio
+     )

+ if __name__ == "__main__":
+     demo.launch()
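
For a quick check of the new pipeline outside the Gradio UI, here is a minimal sketch (not part of the commit) that imports the updated app.py and calls `generate_audio` directly. It assumes the file is importable as `app`, a CUDA GPU is available, the checkpoints download successfully, and the `soundfile` package is installed for writing the output; the argument values simply mirror the defaults wired into the sliders and the model dropdown.

    # Hypothetical smoke test for the updated app.py (assumed module name: app).
    # Importing app eagerly loads SNAC plus all four Llama-3B-Mono checkpoints.
    import soundfile as sf

    from app import generate_audio

    sample_rate, waveform = generate_audio(
        text="Hi, my name is Alex. <laugh> It's a wonderful day!",
        temperature=0.9,      # Temperature slider default
        top_p=0.8,            # Top-p slider default
        max_new_tokens=1200,  # Max New Tokens slider default
        model_name="Luna",    # one of "Luna", "Ceylia", "Cooper", "Jim"
    )
    sf.write("luna_sample.wav", waveform, sample_rate)  # 24 kHz mono output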