Commit 52932d2 · 1 Parent(s): 070106a
added maskformer based object extraction
- extract_tools.py +16 -30
- utils.py +5 -1
extract_tools.py
CHANGED
@@ -124,39 +124,25 @@ def generate_bounding_box_tool(input_data:str)->str:
     object_data = yolo_world_model.run_yolo_infer(image_path,object_prompts)
     return object_data
 
+
 @tool
-def object_extraction(
+def object_extraction(image_path:str)->str:
     "Use this tool to identify the objects within the image"
-
-
-
-
-
-
-
-    try:
-        processor = BlipProcessor.from_pretrained(hf_model)
-        caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
-    except:
-        logging.error("unable to load the Blip model ")
-
-    logging.info("Image Caption model loaded ! ")
-
-    # unconditional image captioning
-    inputs = processor(image, return_tensors ='pt').to(device)
-    output = caption_model.generate(**inputs, max_new_tokens=50)
-    llm = get_groq_model()
-    getobject_chain = create_object_extraction_chain(llm=llm)
-
-    extracted_objects = getobject_chain.invoke({
-        'context': processor.decode(output[0], skip_special_tokens=True)
-    }).objects
-
-    print("Extracted objects : ",extracted_objects)
-    ## clear the GPU cache
+    objects = []
+    maskformer_model.to(device)
+    image = cv2.imread(image_path)
+    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
+    inputs = maskformer_processor(image, return_tensors="pt")
+    inputs.to(device)
     with torch.no_grad():
-
-
-
+        outputs = maskformer_model(**inputs)
+    prediction = maskformer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.shape[:2]])[0]
+    segments_info = prediction['segments_info']
+    for segment in segments_info:
+        segment_label_id = segment['label_id']
+        segment_label = maskformer_model.config.id2label[segment_label_id]
+        objects.append(segment_label)
+    return "Detected objects are: "+ " ".join( objects)
 
 @tool
 def get_image_quality(image_path:str)->str:
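The rewritten object_extraction tool relies on maskformer_processor, maskformer_model, and device at module level, none of which are defined inside this hunk. Below is a minimal sketch of the setup it presumably assumes, reusing the facebook/mask2former-swin-base-coco-panoptic checkpoint that utils.py loads further down; the variable names, checkpoint, and device choice are assumptions rather than something shown in this diff.

    import torch
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    # Assumed module-level globals used by object_extraction; the real extract_tools.py
    # may load a different MaskFormer/Mask2Former checkpoint or pick the device differently.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    maskformer_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
    maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")

If the @tool decorator here is LangChain's, the tool can then be called like any other runnable, e.g. object_extraction.invoke("samples/kitchen.jpg") (hypothetical path), returning a string of the form "Detected objects are: ...".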
utils.py
CHANGED
@@ -50,4 +50,8 @@ def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
         t_size = cv2.getTextSize(str(label), 0, fontScale=tl / 3, thickness=tf)[0]
         c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
         cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, [225, 0, 255], thickness=tf, lineType=cv2.LINE_AA)
-    return rgb_frame_copy
+    return rgb_frame_copy
+
+def object_extraction_using_maskformer(image_path):
+    processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+    model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")