Spaces:
Sleeping
Sleeping
Commit
·
428e149
1
Parent(s):
d766b17
Added image detection code to display predicted bounding boxes
Browse files - extract_tools.py (+10 -9)
extract_tools.py
CHANGED
@@ -61,9 +61,10 @@ def panoptic_image_segemntation(image_path:str)->str:
|
|
61 |
labels = []
|
62 |
for segment in prediction['segments_info']:
|
63 |
label_names = maskformer_model.config.id2label[segment['label_id']]
|
64 |
-
|
65 |
labels.append(label_names)
|
66 |
-
|
|
|
67 |
|
68 |
@tool
|
69 |
def image_description(img_path:str)->str:
|
@@ -88,16 +89,16 @@ def image_description(img_path:str)->str:
|
|
88 |
output = caption_model.generate(**inputs, max_new_tokens=50)
|
89 |
caption = processor.decode(output[0], skip_special_tokens=True)
|
90 |
|
91 |
-
# conditional image captioning
|
92 |
-
obj_text = "Total number of objects in image "
|
93 |
-
inputs_2 = processor(image, obj_text ,return_tensors ='pt').to(device)
|
94 |
-
out_2 = caption_model.generate(**inputs_2,max_new_tokens=50)
|
95 |
-
object_caption = processor.decode(out_2[0], skip_special_tokens=True)
|
96 |
|
97 |
## clear the GPU cache
|
98 |
with torch.no_grad():
|
99 |
torch.cuda.empty_cache()
|
100 |
-
text = caption + " ."
|
101 |
return text
|
102 |
|
103 |
|
@@ -120,7 +121,7 @@ def generate_bounding_box_tool(input_data:str)->str:
|
|
120 |
data = input_data.split(",")
|
121 |
image_path = data[0]
|
122 |
object_prompts = data[1:]
|
123 |
-
object_data = yolo_world_model.
|
124 |
return object_data
|
125 |
|
126 |
@tool
|
|
|
61 |
labels = []
|
62 |
for segment in prediction['segments_info']:
|
63 |
label_names = maskformer_model.config.id2label[segment['label_id']]
|
64 |
+
|
65 |
labels.append(label_names)
|
66 |
+
labels = " ".join([label_name for label_name in labels])
|
67 |
+
return 'Panoptic Segmentation image {} Found labels {} in the image '.format(save_mask_path,labels)
|
68 |
|
69 |
@tool
|
70 |
def image_description(img_path:str)->str:
|
|
|
89 |
output = caption_model.generate(**inputs, max_new_tokens=50)
|
90 |
caption = processor.decode(output[0], skip_special_tokens=True)
|
91 |
|
92 |
+
# # conditional image captioning
|
93 |
+
# obj_text = "Total number of objects in image "
|
94 |
+
# inputs_2 = processor(image, obj_text ,return_tensors ='pt').to(device)
|
95 |
+
# out_2 = caption_model.generate(**inputs_2,max_new_tokens=50)
|
96 |
+
# object_caption = processor.decode(out_2[0], skip_special_tokens=True)
|
97 |
|
98 |
## clear the GPU cache
|
99 |
with torch.no_grad():
|
100 |
torch.cuda.empty_cache()
|
101 |
+
text = caption + " ."
|
102 |
return text
|
103 |
|
104 |
|
|
|
121 |
data = input_data.split(",")
|
122 |
image_path = data[0]
|
123 |
object_prompts = data[1:]
|
124 |
+
object_data = yolo_world_model.run_yolo_infer(image_path,object_prompts)
|
125 |
return object_data
|
126 |
|
127 |
@tool
|