Samarth991 committed on
Commit 52932d2 · 1 Parent(s): 070106a

added maskformer based object extraction

Files changed (2)
  1. extract_tools.py +16 -30
  2. utils.py +5 -1
extract_tools.py CHANGED
@@ -124,39 +124,25 @@ def generate_bounding_box_tool(input_data:str)->str:
     object_data = yolo_world_model.run_yolo_infer(image_path,object_prompts)
     return object_data

+
 @tool
-def object_extraction(img_path:str)->str:
+def object_extraction(image_path:str)->str:
     "Use this tool to identify the objects within the image"
-
-    hf_model = "Salesforce/blip-image-captioning-base"
-    if img_path.startswith('https'):
-        image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
-    else:
-        image = Image.open(img_path).convert('RGB')
-    try:
-        processor = BlipProcessor.from_pretrained(hf_model)
-        caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
-    except:
-        logging.error("unable to load the Blip model ")
-
-    logging.info("Image Caption model loaded ! ")
-
-    # unconditional image captioning
-    inputs = processor(image, return_tensors ='pt').to(device)
-    output = caption_model.generate(**inputs, max_new_tokens=50)
-    llm = get_groq_model()
-    getobject_chain = create_object_extraction_chain(llm=llm)
-
-    extracted_objects = getobject_chain.invoke({
-        'context': processor.decode(output[0], skip_special_tokens=True)
-    }).objects
-
-    print("Extracted objects : ",extracted_objects)
-    ## clear the GPU cache
+    objects = []
+    maskformer_model.to(device)
+    image = cv2.imread(image_path)
+    image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
+    inputs = maskformer_processor(image, return_tensors="pt")
+    inputs.to(device)
     with torch.no_grad():
-        torch.cuda.empty_cache()
-
-    return extracted_objects.split(',')
+        outputs = maskformer_model(**inputs)
+    prediction = maskformer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.shape[:2]])[0]
+    segments_info = prediction['segments_info']
+    for segment in segments_info:
+        segment_label_id = segment['label_id']
+        segment_label = maskformer_model.config.id2label[segment_label_id]
+        objects.append(segment_label)
+    return "Detected objects are: "+ " ".join( objects)

 @tool
 def get_image_quality(image_path:str)->str:
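Note: the new object_extraction tool references module-level maskformer_model and maskformer_processor objects that are not defined in this diff. A minimal sketch of the setup it appears to assume, reusing the facebook/mask2former-swin-base-coco-panoptic checkpoint that the utils.py change loads (the import names and device handling here are assumptions, not part of the commit):

# Sketch only: assumed module-level setup for extract_tools.py (not shown in this diff).
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

device = "cuda" if torch.cuda.is_available() else "cpu"
# Same checkpoint as the helper added in utils.py below.
maskformer_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")

With that setup in place, the tool can be exercised like any other LangChain tool, e.g. object_extraction.invoke("sample.jpg"), assuming the @tool decorator used throughout the file is LangChain's.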
utils.py CHANGED
@@ -50,4 +50,8 @@ def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
     t_size = cv2.getTextSize(str(label), 0, fontScale=tl / 3, thickness=tf)[0]
     c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
     cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, [225, 0, 255], thickness=tf, lineType=cv2.LINE_AA)
-    return rgb_frame_copy
+    return rgb_frame_copy
+
+def object_extraction_using_maskformer(image_path):
+    processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+    model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
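The object_extraction_using_maskformer helper added to utils.py only loads the processor and model; the commit does not include the inference step. A hedged sketch of how the helper might be completed, mirroring the logic of the object_extraction tool above (everything after the two from_pretrained calls is an assumption, not part of the commit):

# Sketch only: a possible completion of object_extraction_using_maskformer (not in this commit).
import cv2
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

def object_extraction_using_maskformer(image_path):
    processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
    model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
    # Read the image and convert OpenCV's BGR ordering to RGB, as extract_tools.py does.
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Panoptic post-processing yields per-segment label ids; map them to readable class names.
    prediction = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.shape[:2]])[0]
    return [model.config.id2label[seg["label_id"]] for seg in prediction["segments_info"]]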