Samarth991 committed
Commit 428e149 · 1 Parent(s): d766b17

added image detection code to display predicted bboxes

Files changed (1)
  1. extract_tools.py +10 -9
extract_tools.py CHANGED
```diff
@@ -61,9 +61,10 @@ def panoptic_image_segemntation(image_path:str)->str:
     labels = []
     for segment in prediction['segments_info']:
         label_names = maskformer_model.config.id2label[segment['label_id']]
-        print(label_names)
+
         labels.append(label_names)
-    return 'Panoptic Segmentation image {} created with labels {} '.format(save_mask_path,labels)
+    labels = " ".join([label_name for label_name in labels])
+    return 'Panoptic Segmentation image {} Found labels {} in the image '.format(save_mask_path,labels)
 
 @tool
 def image_description(img_path:str)->str:
```
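This hunk drops the debug `print` and joins the per-segment class names into one space-separated string, so the tool's observation is a single readable sentence rather than a Python list repr. A minimal sketch of the surrounding flow, assuming a `transformers` MaskFormer panoptic pipeline (the checkpoint name and the `save_mask_path` handling are assumptions; neither appears in this diff):

```python
# Minimal sketch, assuming a transformers MaskFormer panoptic pipeline;
# the checkpoint name is illustrative, not necessarily the repo's.
from PIL import Image
import torch
from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation

processor = MaskFormerImageProcessor.from_pretrained("facebook/maskformer-swin-large-coco")
maskformer_model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-large-coco")

image = Image.open("sample.jpg")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = maskformer_model(**inputs)

# post-processing returns one dict per image with 'segmentation' and 'segments_info'
prediction = processor.post_process_panoptic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]

labels = []
for segment in prediction['segments_info']:
    labels.append(maskformer_model.config.id2label[segment['label_id']])
labels = " ".join(labels)  # single space-separated string, as in the new return message
```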
```diff
@@ -88,16 +89,16 @@ def image_description(img_path:str)->str:
     output = caption_model.generate(**inputs, max_new_tokens=50)
     caption = processor.decode(output[0], skip_special_tokens=True)
 
-    # conditional image captioning
-    obj_text = "Total number of objects in image "
-    inputs_2 = processor(image, obj_text ,return_tensors ='pt').to(device)
-    out_2 = caption_model.generate(**inputs_2,max_new_tokens=50)
-    object_caption = processor.decode(out_2[0], skip_special_tokens=True)
+    # # conditional image captioning
+    # obj_text = "Total number of objects in image "
+    # inputs_2 = processor(image, obj_text ,return_tensors ='pt').to(device)
+    # out_2 = caption_model.generate(**inputs_2,max_new_tokens=50)
+    # object_caption = processor.decode(out_2[0], skip_special_tokens=True)
 
     ## clear the GPU cache
     with torch.no_grad():
         torch.cuda.empty_cache()
-    text = caption + " ."+ object_caption+" ."
+    text = caption + " ."
     return text
 
 
```
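With the conditional-captioning branch commented out, `image_description` now returns only the unconditional caption. A sketch of the retained path, assuming a BLIP captioning model (the checkpoint name and how `processor`, `caption_model`, and `device` are constructed are assumptions; they are not shown in this diff):

```python
# Sketch of the retained captioning path, assuming BLIP; the checkpoint
# name is an assumption, not taken from this diff.
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)

image = Image.open("sample.jpg").convert("RGB")
inputs = processor(image, return_tensors="pt").to(device)
output = caption_model.generate(**inputs, max_new_tokens=50)
caption = processor.decode(output[0], skip_special_tokens=True)
text = caption + " ."  # matches the simplified return value above
```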
 
```diff
@@ -120,7 +121,7 @@ def generate_bounding_box_tool(input_data:str)->str:
     data = input_data.split(",")
     image_path = data[0]
     object_prompts = data[1:]
-    object_data = yolo_world_model.run_inference(image_path,object_prompts)
+    object_data = yolo_world_model.run_yolo_infer(image_path,object_prompts)
     return object_data
 
 @tool
```
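The tool now calls `run_yolo_infer` instead of `run_inference`; the wrapper class itself is not part of this diff. A hypothetical sketch of what such a method could look like on top of the `ultralytics` YOLO-World API (the class name, weights file, and return format here are all assumptions):

```python
# Hypothetical sketch only: the repo's yolo_world_model wrapper is not
# shown in this diff. Assumes the ultralytics YOLO-World API.
from ultralytics import YOLO

class YoloWorldDetector:
    def __init__(self, weights: str = "yolov8s-world.pt"):
        self.model = YOLO(weights)

    def run_yolo_infer(self, image_path: str, object_prompts: list) -> str:
        # restrict the open-vocabulary detector to the requested classes;
        # strip() because the tool's comma-split prompts may carry spaces
        self.model.set_classes([p.strip() for p in object_prompts])
        result = self.model.predict(image_path)[0]
        boxes = result.boxes.xyxy.tolist()  # [x1, y1, x2, y2] per detection
        names = [result.names[int(c)] for c in result.boxes.cls]
        return "Detected {} with bounding boxes {}".format(names, boxes)
```

Usage would then mirror the tool's comma-split input, e.g. `yolo_world_model.run_yolo_infer("street.jpg", ["car", " person"])`.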
 