prithivMLmods commited on
Commit
31393b3
·
verified ·
1 Parent(s): e08c3aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -19
app.py CHANGED
@@ -23,7 +23,8 @@ from rice_leaf_disease import classify_leaf_disease
23
  from traffic_density import traffic_density_classification
24
  from clip_art import clipart_classification
25
  from multisource_121 import multisource_classification
26
- from painting_126 import painting_classification # New import
 
27
 
28
  # Gradio-Theme
29
  class Seafoam(Base):
@@ -32,7 +33,7 @@ class Seafoam(Base):
32
  *,
33
  primary_hue: colors.Color | str = colors.emerald,
34
  secondary_hue: colors.Color | str = colors.blue,
35
- neutral_hue: colors.Color | str = colors.gray,
36
  spacing_size: sizes.Size | str = sizes.spacing_md,
37
  radius_size: sizes.Size | str = sizes.radius_md,
38
  text_size: sizes.Size | str = sizes.text_lg,
@@ -61,6 +62,21 @@ class Seafoam(Base):
61
  font=font,
62
  font_mono=font_mono,
63
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  seafoam = Seafoam()
66
 
@@ -98,8 +114,10 @@ def classify(image, model_name):
98
  return clipart_classification(image)
99
  elif model_name == "multisource":
100
  return multisource_classification(image)
101
- elif model_name == "painting": # New option
102
  return painting_classification(image)
 
 
103
  else:
104
  return {"Error": "No model selected"}
105
 
@@ -110,13 +128,12 @@ def select_model(model_name):
110
  "gym workout": "secondary", "waste": "secondary", "age": "secondary", "mnist": "secondary",
111
  "fashion_mnist": "secondary", "food": "secondary", "bird": "secondary", "leaf disease": "secondary",
112
  "sign language": "secondary", "traffic density": "secondary", "clip art": "secondary",
113
- "multisource": "secondary", "painting": "secondary" # New model variant
114
  }
115
  model_variants[model_name] = "primary"
116
  return (model_name, *(gr.update(variant=model_variants[key]) for key in model_variants))
117
 
118
  # Zero-Shot Classification Setup (SigLIP models)
119
- # Load the SigLIP models and processors
120
  sg1_ckpt = "google/siglip-so400m-patch14-384"
121
  siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="cpu").eval()
122
  siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)
@@ -125,7 +142,6 @@ sg2_ckpt = "google/siglip2-so400m-patch14-384"
125
  siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="cpu").eval()
126
  siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
127
 
128
- # Utilities for zero-shot classification.
129
  @spaces.GPU
130
  def postprocess_siglip(sg1_probs, sg2_probs, labels):
131
  sg1_output = {labels[i]: sg1_probs[0][i].item() for i in range(len(labels))}
@@ -166,21 +182,22 @@ with gr.Blocks(theme=seafoam) as demo:
166
  age_btn = gr.Button("Age Classification", variant="primary")
167
  gender_btn = gr.Button("Gender Classification", variant="secondary")
168
  emotion_btn = gr.Button("Emotion Classification", variant="secondary")
169
- dog_breed_btn = gr.Button("Dog Breed Classification", variant="secondary")
170
- deepfake_btn = gr.Button("Deepfake vs Real", variant="secondary")
171
  gym_workout_btn = gr.Button("Gym Workout Classification", variant="secondary")
 
 
172
  waste_btn = gr.Button("Waste Classification", variant="secondary")
 
 
 
173
  mnist_btn = gr.Button("Digit Classify (0-9)", variant="secondary")
174
- fashion_mnist_btn = gr.Button("Fashion MNIST Classification", variant="secondary")
175
- food_btn = gr.Button("Indian/Western Food", variant="secondary")
176
- bird_btn = gr.Button("Bird Species Classification", variant="secondary")
177
  leaf_disease_btn = gr.Button("Rice Leaf Disease", variant="secondary")
178
- sign_language_btn = gr.Button("Alphabet Sign Language", variant="secondary")
179
- traffic_density_btn = gr.Button("Traffic Density", variant="secondary")
180
- clip_art_btn = gr.Button("Art Classification", variant="secondary")
181
- multisource_btn = gr.Button("Multi-Source Classification", variant="secondary")
182
- painting_btn = gr.Button("Painting Classification", variant="secondary") # New button
183
-
184
  selected_model = gr.State("age")
185
  gr.Markdown("### Current Model:")
186
  model_display = gr.Textbox(value="age", interactive=False)
@@ -189,12 +206,12 @@ with gr.Blocks(theme=seafoam) as demo:
189
  buttons = [
190
  gender_btn, emotion_btn, dog_breed_btn, deepfake_btn, gym_workout_btn, waste_btn,
191
  age_btn, mnist_btn, fashion_mnist_btn, food_btn, bird_btn, leaf_disease_btn,
192
- sign_language_btn, traffic_density_btn, clip_art_btn, multisource_btn, painting_btn # Include new button
193
  ]
194
  model_names = [
195
  "gender", "emotion", "dog breed", "deepfake", "gym workout", "waste",
196
  "age", "mnist", "fashion_mnist", "food", "bird", "leaf disease",
197
- "sign language", "traffic density", "clip art", "multisource", "painting" # New model name
198
  ]
199
 
200
  for btn, name in zip(buttons, model_names):
 
23
  from traffic_density import traffic_density_classification
24
  from clip_art import clipart_classification
25
  from multisource_121 import multisource_classification
26
+ from painting_126 import painting_classification
27
+ from sketch_126 import sketch_classification # New import
28
 
29
  # Gradio-Theme
30
  class Seafoam(Base):
 
33
  *,
34
  primary_hue: colors.Color | str = colors.emerald,
35
  secondary_hue: colors.Color | str = colors.blue,
36
+ neutral_hue: colors.Color | str = colors.blue,
37
  spacing_size: sizes.Size | str = sizes.spacing_md,
38
  radius_size: sizes.Size | str = sizes.radius_md,
39
  text_size: sizes.Size | str = sizes.text_lg,
 
62
  font=font,
63
  font_mono=font_mono,
64
  )
65
+ super().set(
66
+ body_background_fill="repeating-linear-gradient(45deg, *primary_200, *primary_200 10px, *primary_50 10px, *primary_50 20px)",
67
+ body_background_fill_dark="repeating-linear-gradient(45deg, *primary_800, *primary_800 10px, *primary_900 10px, *primary_900 20px)",
68
+ button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
69
+ button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
70
+ button_primary_text_color="white",
71
+ button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
72
+ slider_color="*secondary_300",
73
+ slider_color_dark="*secondary_600",
74
+ block_title_text_weight="600",
75
+ block_border_width="3px",
76
+ block_shadow="*shadow_drop_lg",
77
+ button_primary_shadow="*shadow_drop_lg",
78
+ button_large_padding="32px",
79
+ )
80
 
81
  seafoam = Seafoam()
82
 
 
114
  return clipart_classification(image)
115
  elif model_name == "multisource":
116
  return multisource_classification(image)
117
+ elif model_name == "painting":
118
  return painting_classification(image)
119
+ elif model_name == "sketch": # New option
120
+ return sketch_classification(image)
121
  else:
122
  return {"Error": "No model selected"}
123
 
 
128
  "gym workout": "secondary", "waste": "secondary", "age": "secondary", "mnist": "secondary",
129
  "fashion_mnist": "secondary", "food": "secondary", "bird": "secondary", "leaf disease": "secondary",
130
  "sign language": "secondary", "traffic density": "secondary", "clip art": "secondary",
131
+ "multisource": "secondary", "painting": "secondary", "sketch": "secondary" # New model variant
132
  }
133
  model_variants[model_name] = "primary"
134
  return (model_name, *(gr.update(variant=model_variants[key]) for key in model_variants))
135
 
136
  # Zero-Shot Classification Setup (SigLIP models)
 
137
  sg1_ckpt = "google/siglip-so400m-patch14-384"
138
  siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="cpu").eval()
139
  siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)
 
142
  siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="cpu").eval()
143
  siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
144
 
 
145
  @spaces.GPU
146
  def postprocess_siglip(sg1_probs, sg2_probs, labels):
147
  sg1_output = {labels[i]: sg1_probs[0][i].item() for i in range(len(labels))}
 
182
  age_btn = gr.Button("Age Classification", variant="primary")
183
  gender_btn = gr.Button("Gender Classification", variant="secondary")
184
  emotion_btn = gr.Button("Emotion Classification", variant="secondary")
 
 
185
  gym_workout_btn = gr.Button("Gym Workout Classification", variant="secondary")
186
+ dog_breed_btn = gr.Button("Dog Breed Classification", variant="secondary")
187
+ bird_btn = gr.Button("Bird Species Classification", variant="secondary")
188
  waste_btn = gr.Button("Waste Classification", variant="secondary")
189
+ traffic_density_btn = gr.Button("Traffic Density", variant="secondary")
190
+ sign_language_btn = gr.Button("Alphabet Sign Language", variant="secondary")
191
+ clip_art_btn = gr.Button("Clip Art 126", variant="secondary")
192
  mnist_btn = gr.Button("Digit Classify (0-9)", variant="secondary")
193
+ fashion_mnist_btn = gr.Button("Fashion MNIST (only cloth)", variant="secondary")
194
+ food_btn = gr.Button("Indian/Western Food Type", variant="secondary")
 
195
  leaf_disease_btn = gr.Button("Rice Leaf Disease", variant="secondary")
196
+ multisource_btn = gr.Button("Multi Source 121", variant="secondary")
197
+ painting_btn = gr.Button("Painting 126", variant="secondary")
198
+ sketch_btn = gr.Button("Sketch 126", variant="secondary")
199
+ deepfake_btn = gr.Button("Deepfake vs Real", variant="secondary")
200
+
 
201
  selected_model = gr.State("age")
202
  gr.Markdown("### Current Model:")
203
  model_display = gr.Textbox(value="age", interactive=False)
 
206
  buttons = [
207
  gender_btn, emotion_btn, dog_breed_btn, deepfake_btn, gym_workout_btn, waste_btn,
208
  age_btn, mnist_btn, fashion_mnist_btn, food_btn, bird_btn, leaf_disease_btn,
209
+ sign_language_btn, traffic_density_btn, clip_art_btn, multisource_btn, painting_btn, sketch_btn # Include new button
210
  ]
211
  model_names = [
212
  "gender", "emotion", "dog breed", "deepfake", "gym workout", "waste",
213
  "age", "mnist", "fashion_mnist", "food", "bird", "leaf disease",
214
+ "sign language", "traffic density", "clip art", "multisource", "painting", "sketch" # New model name
215
  ]
216
 
217
  for btn, name in zip(buttons, model_names):