atalaydenknalbant commited on
Commit
cc78067
·
verified ·
1 Parent(s): 3142e56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -86
app.py CHANGED
@@ -36,35 +36,29 @@ except Exception:
36
 
37
  def _gpu_duration_gallery(files: List[str], *_args, **_kwargs) -> int:
38
  """Return a booking duration for a gallery job based on file count.
39
-
40
  Args:
41
  files: List of file paths. Used to scale the reservation window.
42
-
43
  Returns:
44
  Number of seconds to reserve capped at 600.
45
  """
46
  n = max(1, len(files) if files else 1)
47
- return min(600, 35 * n + 30) # 35s per image plus 30s buffer capped at 10 minutes
48
-
49
 
50
  def _gpu_duration_classify(*_args, **_kwargs) -> int:
51
  """Return a small booking duration for classification runs.
52
-
53
  Returns:
54
  Number of seconds to reserve for classification.
55
  """
56
- return 90 # small buffer for 1 query plus a handful of centroids
57
 
58
  # ---------------------------
59
- # Model loading and embedding extraction (fp32 only)
60
  # ---------------------------
61
 
62
  def _load(model_id: str) -> Tuple[AutoImageProcessor, AutoModel]:
63
  """Load processor and model then move to CUDA eval in float32.
64
-
65
  Args:
66
  model_id: Hugging Face model id to load.
67
-
68
  Returns:
69
  Tuple of processor and model on CUDA in eval mode.
70
  """
@@ -77,13 +71,10 @@ def _load(model_id: str) -> Tuple[AutoImageProcessor, AutoModel]:
77
  model.to("cuda").to(torch.float32).eval()
78
  return processor, model
79
 
80
-
81
  def _to_cuda_batchfeature(bf):
82
  """Move a BatchFeature or dict of tensors to CUDA.
83
-
84
  Args:
85
  bf: Transformers BatchFeature or a dict of tensors.
86
-
87
  Returns:
88
  BatchFeature or dict on CUDA.
89
  """
@@ -91,15 +82,11 @@ def _to_cuda_batchfeature(bf):
91
  return bf.to("cuda")
92
  return {k: v.to("cuda") for k, v in bf.items()}
93
 
94
-
95
- def _embed(image: Image.Image, model_id: str, pooling: str) -> np.ndarray:
96
  """Extract a single-image DINOv3 embedding.
97
-
98
  Args:
99
  image: Input PIL image.
100
  model_id: Backbone id from MODELS.
101
- pooling: Either "CLS" or "Mean of patch tokens".
102
-
103
  Returns:
104
  1D NumPy vector in float32.
105
  """
@@ -110,19 +97,10 @@ def _embed(image: Image.Image, model_id: str, pooling: str) -> np.ndarray:
110
  with torch.inference_mode():
111
  out = model(**bf)
112
 
113
- if pooling == "CLS":
114
- if getattr(out, "pooler_output", None) is not None:
115
- emb = out.pooler_output[0]
116
- else:
117
- emb = out.last_hidden_state[0, 0]
118
  else:
119
- if out.last_hidden_state.ndim == 3: # ViT tokens
120
- num_regs = getattr(model.config, "num_register_tokens", 0)
121
- patch_tokens = out.last_hidden_state[0, 1 + num_regs :]
122
- emb = patch_tokens.mean(dim=0)
123
- else: # Conv/backbone feature map [C,H,W]
124
- feat = out.last_hidden_state[0]
125
- emb = feat.mean(dim=(1, 2))
126
 
127
  return emb.float().cpu().numpy().astype(np.float32)
128
 
@@ -132,10 +110,8 @@ def _embed(image: Image.Image, model_id: str, pooling: str) -> np.ndarray:
132
 
133
  def _open_images_from_paths(paths: List[str]) -> List[Image.Image]:
134
  """Open many image files as RGB PIL images.
135
-
136
  Args:
137
  paths: List of file paths.
138
-
139
  Returns:
140
  List of PIL images. Files that fail to open are skipped.
141
  """
@@ -148,14 +124,11 @@ def _open_images_from_paths(paths: List[str]) -> List[Image.Image]:
148
  pass
149
  return imgs
150
 
151
-
152
  def _to_html_table(S: np.ndarray, names: List[str]) -> str:
153
  """Render a cosine similarity matrix as an HTML table.
154
-
155
  Args:
156
  S: Square matrix of cosine similarities.
157
  names: File names for header and row labels.
158
-
159
  Returns:
160
  HTML string with a scrollable table.
161
  """
@@ -177,13 +150,10 @@ def _to_html_table(S: np.ndarray, names: List[str]) -> str:
177
  """
178
  return table
179
 
180
-
181
  def _normalize_rows(X: np.ndarray) -> np.ndarray:
182
  """Normalize rows to unit norm with safe clipping.
183
-
184
  Args:
185
  X: Matrix of shape N by D.
186
-
187
  Returns:
188
  Matrix with each row divided by its L2 norm.
189
  """
@@ -195,14 +165,11 @@ def _normalize_rows(X: np.ndarray) -> np.ndarray:
195
  # ---------------------------
196
 
197
  @spaces.GPU(duration=_gpu_duration_gallery)
198
- def batch_similarity(files: List[str], model_name: str, pooling: str):
199
  """Compute pairwise cosine similarities for many images.
200
-
201
  Args:
202
  files: List of image file paths.
203
  model_name: Key from MODELS.
204
- pooling: Either CLS or Mean of patch tokens.
205
-
206
  Returns:
207
  html_table: HTML table with cosine similarities.
208
  csv_path: Path to a CSV file of the matrix.
@@ -217,7 +184,7 @@ def batch_similarity(files: List[str], model_name: str, pooling: str):
217
  imgs = _open_images_from_paths(paths)
218
  embs = []
219
  for img in imgs:
220
- e = _embed(img, model_id, pooling)
221
  embs.append(e)
222
 
223
  if len(embs) < 2:
@@ -233,47 +200,38 @@ def batch_similarity(files: List[str], model_name: str, pooling: str):
233
  return html, csv_path
234
 
235
  # ---------------------------
236
- # Image Classification using DINOv3 embeddings
237
  # ---------------------------
238
 
239
  def _init_state() -> Dict:
240
  """Create an empty classifier state.
241
-
242
  Returns:
243
- Dict with model_id pooling and classes.
244
  """
245
- return {"model_id": "", "pooling": "", "classes": {}}
246
-
247
 
248
  def _summarize_state(state: Dict) -> Dict:
249
  """Summarize counts in the classifier state.
250
-
251
  Args:
252
  state: Current classifier state.
253
-
254
  Returns:
255
  Dict with counts for display.
256
  """
257
  return {
258
  "model_id": state.get("model_id", ""),
259
- "pooling": state.get("pooling", ""),
260
  "class_counts": {k: v.get("count", 0) for k, v in state.get("classes", {}).items()},
261
  "num_classes": len(state.get("classes", {})),
262
  "total_examples": int(sum(v.get("count", 0) for v in state.get("classes", {}).values())),
263
  }
264
 
265
-
266
  @spaces.GPU(duration=_gpu_duration_gallery)
267
- def add_class(class_name: str, files: List[str], model_name: str, pooling: str, state: Dict):
268
  """Add images to a labeled class and update embeddings.
269
-
270
  Args:
271
  class_name: Target class label.
272
  files: Image file paths to embed and store.
273
  model_name: Key from MODELS.
274
- pooling: Either CLS or Mean of patch tokens.
275
  state: Current classifier state.
276
-
277
  Returns:
278
  Summary dict and the updated state.
279
  """
@@ -285,10 +243,9 @@ def add_class(class_name: str, files: List[str], model_name: str, pooling: str,
285
  return {"error": "No images uploaded for this class"}, state
286
 
287
  model_id = MODELS[model_name]
288
- if state.get("model_id") and (state["model_id"] != model_id or state.get("pooling") != pooling):
289
  state = _init_state()
290
  state["model_id"] = model_id
291
- state["pooling"] = pooling
292
 
293
  imgs = _open_images_from_paths(files)
294
  if not imgs:
@@ -296,7 +253,7 @@ def add_class(class_name: str, files: List[str], model_name: str, pooling: str,
296
 
297
  embs = []
298
  for im in imgs:
299
- e = _embed(im, model_id, pooling).astype(np.float32)
300
  embs.append(e)
301
  X = np.vstack(embs)
302
 
@@ -309,18 +266,14 @@ def add_class(class_name: str, files: List[str], model_name: str, pooling: str,
309
  state["classes"][class_name]["count"] = new.shape[0]
310
  return _summarize_state(state), state
311
 
312
-
313
  @spaces.GPU(duration=_gpu_duration_classify)
314
- def predict_class(image: Image.Image, model_name: str, pooling: str, state: Dict, top_k: int):
315
  """Predict a class using cosine to class centroids.
316
-
317
  Args:
318
  image: Query PIL image.
319
  model_name: Key from MODELS.
320
- pooling: Either CLS or Mean of patch tokens.
321
  state: Classifier state holding embeddings per class.
322
  top_k: Number of classes to report.
323
-
324
  Returns:
325
  Info dict with prediction, a label dict for display, and HTML with ranks.
326
  """
@@ -332,10 +285,10 @@ def predict_class(image: Image.Image, model_name: str, pooling: str, state: Dict
332
  return {"error": "No classes have been added yet"}, {}, None
333
 
334
  model_id = MODELS[model_name]
335
- if state.get("model_id") != model_id or state.get("pooling") != pooling:
336
- return {"error": "Model or pooling changed after building classes. Clear and rebuild."}, {}, None
337
 
338
- q = _embed(image, model_id, pooling).astype(np.float32)[None, :]
339
  qn = _normalize_rows(q)
340
 
341
  names = []
@@ -363,13 +316,10 @@ def predict_class(image: Image.Image, model_name: str, pooling: str, state: Dict
363
  ) + "</ol>"
364
  return {"top_k": top_k, "prediction": names[order[0]]}, result_dict, full_table
365
 
366
-
367
  def clear_classes(_state: Dict):
368
  """Reset the classifier state to empty.
369
-
370
  Args:
371
  _state: Previous state ignored.
372
-
373
  Returns:
374
  Fresh state and its summary.
375
  """
@@ -380,7 +330,7 @@ def clear_classes(_state: Dict):
380
  # ---------------------------
381
 
382
  with gr.Blocks() as app:
383
- gr.Markdown("# DINOv3 - Similarity, Classification")
384
 
385
  with gr.Accordion("Paper and Citation", open=False):
386
  gr.Markdown("""
@@ -391,7 +341,7 @@ with gr.Blocks() as app:
391
  ```
392
  @misc{simeoni2025dinov3,
393
  title={DINOv3},
394
- author={Oriane Siméoni and Huy V. Vo and Maximilian Seitzer and Federico Baldassarre and Maxime Oquab and Cijo Jose and Vasil Khalidov and Marc Szafraniec and Seungeun Yi and Michaël Ramamonjisoa and Francisco Massa and Daniel Haziza and Luca Wehrstedt and Jianyuan Wang and Timothée Darcet and Théo Moutakanni and Leonel Sentana and Claire Roberts and Andrea Vedaldi and Jamie Tolan and John Brandt and Camille Couprie and Julien Mairal and Hervé Jégou and Patrick Labatut and Piotr Bojanowski}},
395
  year={2025},
396
  eprint={2508.10104},
397
  archivePrefix={arXiv},
@@ -401,26 +351,17 @@ with gr.Blocks() as app:
401
  ``` """
402
  )
403
 
404
- # ------------- Similarity -------------
405
  with gr.Tab("Similarity"):
406
  gr.Markdown("Upload multiple images to compute a cosine similarity matrix and download a CSV.")
407
  files_in = gr.Files(label="Upload images", file_types=["image"], file_count="multiple", type="filepath")
408
  gallery_preview = gr.Gallery(label="Preview", columns=4, height=300)
409
  model_dd2 = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Backbone")
410
- pooling2 = gr.Radio(["CLS", "Mean of patch tokens"], value="CLS", label="Pooling")
411
  go = gr.Button("Compute cosine")
412
  table = gr.HTML(label="Cosine similarity")
413
  csv = gr.File(label="cosine_similarity_matrix.csv")
414
 
415
  def _preview(paths):
416
- """Preview images from file paths as a gallery.
417
-
418
- Args:
419
- paths: List of file paths from gr.Files.
420
-
421
- Returns:
422
- List of PIL images for the Gallery.
423
- """
424
  if not paths:
425
  return []
426
  imgs = []
@@ -432,15 +373,14 @@ with gr.Blocks() as app:
432
  return imgs
433
 
434
  files_in.change(_preview, inputs=files_in, outputs=gallery_preview)
435
- go.click(batch_similarity, [files_in, model_dd2, pooling2], [table, csv])
436
 
437
- # ------------- Image Classification -------------
438
  with gr.Tab("Image Classification"):
439
  st = gr.State(_init_state())
440
  with gr.Row():
441
  with gr.Column():
442
  model_dd3 = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Backbone")
443
- pooling3 = gr.Radio(["CLS", "Mean of patch tokens"], value="CLS", label="Pooling")
444
  gr.Markdown("Build your labeled set by adding a few images per class.")
445
  class_name = gr.Textbox(label="Class name")
446
  class_files = gr.Files(label="Upload images for this class", file_types=["image"], type="filepath", file_count="multiple")
@@ -456,7 +396,7 @@ with gr.Blocks() as app:
456
 
457
  add_btn.click(
458
  add_class,
459
- [class_name, class_files, model_dd3, pooling3, st],
460
  [state_view, st],
461
  )
462
  clear_btn.click(
@@ -466,7 +406,7 @@ with gr.Blocks() as app:
466
  )
467
  predict_btn.click(
468
  predict_class,
469
- [query_img, model_dd3, pooling3, st, topk],
470
  [gr.JSON(label="Info"), predicted, scores_html],
471
  )
472
 
 
36
 
37
  def _gpu_duration_gallery(files: List[str], *_args, **_kwargs) -> int:
38
  """Return a booking duration for a gallery job based on file count.
 
39
  Args:
40
  files: List of file paths. Used to scale the reservation window.
 
41
  Returns:
42
  Number of seconds to reserve capped at 600.
43
  """
44
  n = max(1, len(files) if files else 1)
45
+ return min(600, 35 * n + 30)
 
46
 
47
  def _gpu_duration_classify(*_args, **_kwargs) -> int:
48
  """Return a small booking duration for classification runs.
 
49
  Returns:
50
  Number of seconds to reserve for classification.
51
  """
52
+ return 90
53
 
54
  # ---------------------------
55
+ # Model loading and CLS embedding extraction
56
  # ---------------------------
57
 
58
  def _load(model_id: str) -> Tuple[AutoImageProcessor, AutoModel]:
59
  """Load processor and model then move to CUDA eval in float32.
 
60
  Args:
61
  model_id: Hugging Face model id to load.
 
62
  Returns:
63
  Tuple of processor and model on CUDA in eval mode.
64
  """
 
71
  model.to("cuda").to(torch.float32).eval()
72
  return processor, model
73
 
 
74
  def _to_cuda_batchfeature(bf):
75
  """Move a BatchFeature or dict of tensors to CUDA.
 
76
  Args:
77
  bf: Transformers BatchFeature or a dict of tensors.
 
78
  Returns:
79
  BatchFeature or dict on CUDA.
80
  """
 
82
  return bf.to("cuda")
83
  return {k: v.to("cuda") for k, v in bf.items()}
84
 
85
+ def _embed_cls(image: Image.Image, model_id: str) -> np.ndarray:
 
86
  """Extract a single-image DINOv3 embedding.
 
87
  Args:
88
  image: Input PIL image.
89
  model_id: Backbone id from MODELS.
 
 
90
  Returns:
91
  1D NumPy vector in float32.
92
  """
 
97
  with torch.inference_mode():
98
  out = model(**bf)
99
 
100
+ if getattr(out, "pooler_output", None) is not None:
101
+ emb = out.pooler_output[0]
 
 
 
102
  else:
103
+ emb = out.last_hidden_state[0, 0]
 
 
 
 
 
 
104
 
105
  return emb.float().cpu().numpy().astype(np.float32)
106
 
 
110
 
111
  def _open_images_from_paths(paths: List[str]) -> List[Image.Image]:
112
  """Open many image files as RGB PIL images.
 
113
  Args:
114
  paths: List of file paths.
 
115
  Returns:
116
  List of PIL images. Files that fail to open are skipped.
117
  """
 
124
  pass
125
  return imgs
126
 
 
127
  def _to_html_table(S: np.ndarray, names: List[str]) -> str:
128
  """Render a cosine similarity matrix as an HTML table.
 
129
  Args:
130
  S: Square matrix of cosine similarities.
131
  names: File names for header and row labels.
 
132
  Returns:
133
  HTML string with a scrollable table.
134
  """
 
150
  """
151
  return table
152
 
 
153
  def _normalize_rows(X: np.ndarray) -> np.ndarray:
154
  """Normalize rows to unit norm with safe clipping.
 
155
  Args:
156
  X: Matrix of shape N by D.
 
157
  Returns:
158
  Matrix with each row divided by its L2 norm.
159
  """
 
165
  # ---------------------------
166
 
167
  @spaces.GPU(duration=_gpu_duration_gallery)
168
+ def batch_similarity(files: List[str], model_name: str):
169
  """Compute pairwise cosine similarities for many images.
 
170
  Args:
171
  files: List of image file paths.
172
  model_name: Key from MODELS.
 
 
173
  Returns:
174
  html_table: HTML table with cosine similarities.
175
  csv_path: Path to a CSV file of the matrix.
 
184
  imgs = _open_images_from_paths(paths)
185
  embs = []
186
  for img in imgs:
187
+ e = _embed_cls(img, model_id)
188
  embs.append(e)
189
 
190
  if len(embs) < 2:
 
200
  return html, csv_path
201
 
202
  # ---------------------------
203
+ # Image Classification using DINOv3 CLS embeddings
204
  # ---------------------------
205
 
206
  def _init_state() -> Dict:
207
  """Create an empty classifier state.
 
208
  Returns:
209
+ Dict with model_id and classes.
210
  """
211
+ return {"model_id": "", "classes": {}}
 
212
 
213
  def _summarize_state(state: Dict) -> Dict:
214
  """Summarize counts in the classifier state.
 
215
  Args:
216
  state: Current classifier state.
 
217
  Returns:
218
  Dict with counts for display.
219
  """
220
  return {
221
  "model_id": state.get("model_id", ""),
 
222
  "class_counts": {k: v.get("count", 0) for k, v in state.get("classes", {}).items()},
223
  "num_classes": len(state.get("classes", {})),
224
  "total_examples": int(sum(v.get("count", 0) for v in state.get("classes", {}).values())),
225
  }
226
 
 
227
  @spaces.GPU(duration=_gpu_duration_gallery)
228
+ def add_class(class_name: str, files: List[str], model_name: str, state: Dict):
229
  """Add images to a labeled class and update embeddings.
 
230
  Args:
231
  class_name: Target class label.
232
  files: Image file paths to embed and store.
233
  model_name: Key from MODELS.
 
234
  state: Current classifier state.
 
235
  Returns:
236
  Summary dict and the updated state.
237
  """
 
243
  return {"error": "No images uploaded for this class"}, state
244
 
245
  model_id = MODELS[model_name]
246
+ if state.get("model_id") and state["model_id"] != model_id:
247
  state = _init_state()
248
  state["model_id"] = model_id
 
249
 
250
  imgs = _open_images_from_paths(files)
251
  if not imgs:
 
253
 
254
  embs = []
255
  for im in imgs:
256
+ e = _embed_cls(im, model_id).astype(np.float32)
257
  embs.append(e)
258
  X = np.vstack(embs)
259
 
 
266
  state["classes"][class_name]["count"] = new.shape[0]
267
  return _summarize_state(state), state
268
 
 
269
  @spaces.GPU(duration=_gpu_duration_classify)
270
+ def predict_class(image: Image.Image, model_name: str, state: Dict, top_k: int):
271
  """Predict a class using cosine to class centroids.
 
272
  Args:
273
  image: Query PIL image.
274
  model_name: Key from MODELS.
 
275
  state: Classifier state holding embeddings per class.
276
  top_k: Number of classes to report.
 
277
  Returns:
278
  Info dict with prediction, a label dict for display, and HTML with ranks.
279
  """
 
285
  return {"error": "No classes have been added yet"}, {}, None
286
 
287
  model_id = MODELS[model_name]
288
+ if state.get("model_id") != model_id:
289
+ return {"error": "Model changed after building classes. Clear and rebuild."}, {}, None
290
 
291
+ q = _embed_cls(image, model_id).astype(np.float32)[None, :]
292
  qn = _normalize_rows(q)
293
 
294
  names = []
 
316
  ) + "</ol>"
317
  return {"top_k": top_k, "prediction": names[order[0]]}, result_dict, full_table
318
 
 
319
  def clear_classes(_state: Dict):
320
  """Reset the classifier state to empty.
 
321
  Args:
322
  _state: Previous state ignored.
 
323
  Returns:
324
  Fresh state and its summary.
325
  """
 
330
  # ---------------------------
331
 
332
  with gr.Blocks() as app:
333
+ gr.Markdown("# DINOv3 Similarity and Classification")
334
 
335
  with gr.Accordion("Paper and Citation", open=False):
336
  gr.Markdown("""
 
341
  ```
342
  @misc{simeoni2025dinov3,
343
  title={DINOv3},
344
+ author={Oriane Siméoni and Huy V. Vo and Maximilian Seitzer and Federico Baldassarre and Maxime Oquab and Cijo Jose and Vasil Khalidov and Marc Szafraniec and Seungeun Yi and Michaël Ramamonjisoa and Francisco Massa and Daniel Haziza and Luca Wehrstedt and Jianyuan Wang and Timothée Darcet and Théo Moutakanni and Leonel Sentana and Claire Roberts and Andrea Vedaldi and Jamie Tolan and John Brandt and Camille Couprie and Julien Mairal and Hervé Jégou and Patrick Labatut and Piotr Bojanowski},
345
  year={2025},
346
  eprint={2508.10104},
347
  archivePrefix={arXiv},
 
351
  ``` """
352
  )
353
 
354
+ # Similarity
355
  with gr.Tab("Similarity"):
356
  gr.Markdown("Upload multiple images to compute a cosine similarity matrix and download a CSV.")
357
  files_in = gr.Files(label="Upload images", file_types=["image"], file_count="multiple", type="filepath")
358
  gallery_preview = gr.Gallery(label="Preview", columns=4, height=300)
359
  model_dd2 = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Backbone")
 
360
  go = gr.Button("Compute cosine")
361
  table = gr.HTML(label="Cosine similarity")
362
  csv = gr.File(label="cosine_similarity_matrix.csv")
363
 
364
  def _preview(paths):
 
 
 
 
 
 
 
 
365
  if not paths:
366
  return []
367
  imgs = []
 
373
  return imgs
374
 
375
  files_in.change(_preview, inputs=files_in, outputs=gallery_preview)
376
+ go.click(batch_similarity, [files_in, model_dd2], [table, csv])
377
 
378
+ # Image Classification
379
  with gr.Tab("Image Classification"):
380
  st = gr.State(_init_state())
381
  with gr.Row():
382
  with gr.Column():
383
  model_dd3 = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Backbone")
 
384
  gr.Markdown("Build your labeled set by adding a few images per class.")
385
  class_name = gr.Textbox(label="Class name")
386
  class_files = gr.Files(label="Upload images for this class", file_types=["image"], type="filepath", file_count="multiple")
 
396
 
397
  add_btn.click(
398
  add_class,
399
+ [class_name, class_files, model_dd3, st],
400
  [state_view, st],
401
  )
402
  clear_btn.click(
 
406
  )
407
  predict_btn.click(
408
  predict_class,
409
+ [query_img, model_dd3, st, topk],
410
  [gr.JSON(label="Info"), predicted, scores_html],
411
  )
412