vinhtruong3 committed
Commit f18cbce · verified · Parent(s): 8e93e59

Update app.py

Files changed (1)
  1. app.py +149 -87
app.py CHANGED
@@ -16,17 +16,16 @@ print(f"Using device: {device}")
 model_configs = {
     'gokaygokay/Florence-2-Flux': "<DESCRIPTION>",
     'gokaygokay/Florence-2-Flux-Large': "<DESCRIPTION>",
-    'yayayaaa/florence-2-large-ft-moredetailed': "<MORE_DETAILED_CAPTION>"
-    # Temporarily removed MiaoshouAI model due to compatibility issues
-    # 'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "<MORE_DETAILED_CAPTION>"
+    'yayayaaa/florence-2-large-ft-moredetailed': "<MORE_DETAILED_CAPTION>",
+    'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "<MORE_DETAILED_CAPTION>"
 }
 
 # Define a description for each model to be shown in UI
 model_descriptions = {
     'gokaygokay/Florence-2-Flux': "Faster version with good quality captions",
     'gokaygokay/Florence-2-Flux-Large': "Provides detailed captions with better image understanding",
-    'yayayaaa/florence-2-large-ft-moredetailed': "Fine-tuned specifically for more detailed captions"
-    # 'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "Memory efficient model with high quality detailed captions"
+    'yayayaaa/florence-2-large-ft-moredetailed': "Fine-tuned specifically for more detailed captions",
+    'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "Memory efficient model with high quality detailed captions"
 }
 
 # Load a single model to start with
@@ -56,6 +55,7 @@ title = """<h1 align="center">Florence-2 Caption Dataset Creator</h1>
 <a href="https://huggingface.co/gokaygokay/Florence-2-Flux-Large" target="_blank">[Florence-2 Flux Large]</a>
 <a href="https://huggingface.co/gokaygokay/Florence-2-Flux" target="_blank">[Florence-2 Flux Base]</a>
 <a href="https://huggingface.co/yayayaaa/florence-2-large-ft-moredetailed" target="_blank">[Florence-2 More Detailed]</a>
+<a href="https://huggingface.co/MiaoshouAI/Florence-2-large-PromptGen-v2.0" target="_blank">[MiaoshouAI PromptGen v2.0]</a>
 </center></p>"""
 
 # Function to clean caption text
@@ -100,6 +100,50 @@ def load_model(selected_model_name):
 
     return "Model loaded successfully"
 
+# Special function for MiaoshouAI model
+def generate_miaoshou_caption(image):
+    """Special handling for MiaoshouAI model"""
+    # Create inputs for MiaoshouAI model
+    inputs = processor(
+        text=task_prompt,
+        images=image,
+        return_tensors="pt"
+    )
+
+    # Move inputs to device
+    for key in inputs:
+        if isinstance(inputs[key], torch.Tensor):
+            inputs[key] = inputs[key].to(device)
+
+    # Generate using only input_ids and pixel_values
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=512,
+        do_sample=False,
+        num_beams=3
+    )
+
+    # Decode the generated text
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+
+    # Use the model's post-processing
+    try:
+        parsed_answer = processor.post_process_generation(
+            generated_text,
+            task=task_prompt,
+            image_size=(image.width, image.height)
+        )
+        # Get the generated text from parsed answer
+        if isinstance(parsed_answer, dict) and task_prompt in parsed_answer:
+            return parsed_answer[task_prompt]
+        else:
+            return str(parsed_answer)
+    except Exception as e:
+        print(f"Post-processing error: {str(e)}")
+        # Fallback to regular decoding if post-processing fails
+        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
 # Function to generate a caption for a single image
 @spaces.GPU
 def generate_caption(image, selected_model_name):
@@ -124,47 +168,54 @@ def generate_caption(image, selected_model_name):
     if image.mode != "RGB":
         image = image.convert("RGB")
 
-    # Create an appropriate prompt based on the model
-    prompt = task_prompt
-
     try:
-        # Process the image
-        inputs = processor(text=prompt, images=image, return_tensors="pt")
-
-        # Move inputs to the same device as the model
-        for key in inputs:
-            if isinstance(inputs[key], torch.Tensor):
-                inputs[key] = inputs[key].to(device)
-
-        # Generate the caption
-        with torch.no_grad():
-            generated_ids = model.generate(
-                **inputs,
-                max_new_tokens=512,  # Reduced for better memory usage
-                num_beams=3,
-                repetition_penalty=1.10,
-            )
-
-        # Decode the generated text
-        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        # Handle post-processing for different models
-        if task_prompt == "<DESCRIPTION>":
-            # Use the post processing for Florence-2-Flux models
-            try:
-                decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-                parsed_answer = processor.post_process_generation(
-                    decoded_text,
-                    task=task_prompt,
-                    image_size=(image.width, image.height)
-                )
-                caption = parsed_answer[task_prompt]
-            except Exception as e:
-                print(f"Error in post processing: {str(e)}")
-                caption = generated_text  # Fallback to direct output
+        # Special handling for MiaoshouAI model
+        if model_name == 'MiaoshouAI/Florence-2-large-PromptGen-v2.0':
+            caption = generate_miaoshou_caption(image)
         else:
-            # For other models, use the generated text directly
-            caption = generated_text
+            # Regular processing for other models
+            # Create an appropriate prompt based on the model
+            prompt = task_prompt
+            if prompt == "<DESCRIPTION>":
+                prompt = prompt + "Describe this image in great detail."
+
+            # Process the image
+            inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+            # Move inputs to the same device as the model
+            for key in inputs:
+                if isinstance(inputs[key], torch.Tensor):
+                    inputs[key] = inputs[key].to(device)
+
+            # Generate the caption
+            with torch.no_grad():
+                generated_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=512,
+                    num_beams=3,
+                    repetition_penalty=1.10,
+                )
+
+            # Decode the generated text
+            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+            # Handle post-processing for different models
+            if task_prompt == "<DESCRIPTION>":
+                # Use the post processing for Florence-2-Flux models
+                try:
+                    decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+                    parsed_answer = processor.post_process_generation(
+                        decoded_text,
+                        task=task_prompt,
+                        image_size=(image.width, image.height)
+                    )
+                    caption = parsed_answer[task_prompt]
+                except Exception as e:
+                    print(f"Error in post processing: {str(e)}")
+                    caption = generated_text  # Fallback to direct output
+            else:
+                # For other models, use the generated text directly
+                caption = generated_text
 
     # Clean the caption to remove padding tokens
     clean_text = clean_caption(caption)
@@ -211,51 +262,59 @@ def process_images(images, selected_model_name, add_trigger=True, trigger_word="
            results.append(f"⚠️ Skipped {base_name}: Unsupported format (only jpg, jpeg, png supported)")
            continue
 
-        # Generate caption for this specific image
+        # Generate caption
+        # Open the image once
         image = Image.open(img_path)
         if image.mode != "RGB":
             image = image.convert("RGB")
 
-        # Use the task prompt for the selected model
-        prompt = task_prompt
-
-        # Process the image
-        inputs = processor(text=prompt, images=image, return_tensors="pt")
-
-        # Move inputs to the same device as the model
-        for key in inputs:
-            if isinstance(inputs[key], torch.Tensor):
-                inputs[key] = inputs[key].to(device)
-
-        # Generate the caption
-        with torch.no_grad():
-            generated_ids = model.generate(
-                **inputs,
-                max_new_tokens=512,  # Reduced for better memory usage
-                num_beams=3,
-                repetition_penalty=1.10,
-            )
-
-        # Decode the generated text
-        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        # Handle post-processing for different models
-        if task_prompt == "<DESCRIPTION>":
-            # Use the post processing for Florence-2-Flux models
-            try:
-                decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-                parsed_answer = processor.post_process_generation(
-                    decoded_text,
-                    task=task_prompt,
-                    image_size=(image.width, image.height)
-                )
-                caption = parsed_answer[task_prompt]
-            except Exception as e:
-                print(f"Error in post processing: {str(e)}")
-                caption = generated_text  # Fallback to direct output
+        # Use the same caption generation logic as in generate_caption
+        if model_name == 'MiaoshouAI/Florence-2-large-PromptGen-v2.0':
+            caption = generate_miaoshou_caption(image)
         else:
-            # For other models, use the generated text directly
-            caption = generated_text
+            # Regular processing for other models
+            # Create an appropriate prompt based on the model
+            prompt = task_prompt
+            if prompt == "<DESCRIPTION>":
+                prompt = prompt + "Describe this image in great detail."
+
+            # Process the image
+            inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+            # Move inputs to the same device as the model
+            for key in inputs:
+                if isinstance(inputs[key], torch.Tensor):
+                    inputs[key] = inputs[key].to(device)
+
+            # Generate the caption
+            with torch.no_grad():
+                generated_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=512,
+                    num_beams=3,
+                    repetition_penalty=1.10,
+                )
+
+            # Decode the generated text
+            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+            # Handle post-processing for different models
+            if task_prompt == "<DESCRIPTION>":
+                # Use the post processing for Florence-2-Flux models
+                try:
+                    decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+                    parsed_answer = processor.post_process_generation(
+                        decoded_text,
+                        task=task_prompt,
+                        image_size=(image.width, image.height)
+                    )
+                    caption = parsed_answer[task_prompt]
+                except Exception as e:
+                    print(f"Error in post processing: {str(e)}")
+                    caption = generated_text  # Fallback to direct output
+            else:
+                # For other models, use the generated text directly
+                caption = generated_text
 
         # Clean caption and add trigger if needed
         caption = clean_caption(caption)
@@ -377,10 +436,13 @@ with gr.Blocks() as demo:
 
    gr.Markdown(model_md)
 
-    # Add note about MiaoshouAI model
+    # Add special note for MiaoshouAI model
    gr.Markdown("""
-    ### Note:
-    The MiaoshouAI/Florence-2-large-PromptGen-v2.0 model has been temporarily removed due to compatibility issues with the current setup. The available models still provide excellent captioning capabilities.
+    ### MiaoshouAI/Florence-2-large-PromptGen-v2.0 Features
+    - Improved caption quality for detailed captions
+    - Memory efficient (requires only ~1GB VRAM)
+    - Fast generation while maintaining high quality
+    - Supports multiple caption formats including detailed captions, tags, and analysis
 
    Supported image formats: JPG, JPEG, PNG
    """)