Samarth991 commited on
Commit
d766b17
·
1 Parent(s): b60841d

added image detection code to display predicted bboxes

Browse files
Files changed (2) hide show
  1. tool_utils/yolo_world.py +53 -0
  2. utils.py +28 -1
tool_utils/yolo_world.py CHANGED
@@ -1,7 +1,10 @@
1
  import os
2
  import logging
 
3
  import numpy as np
4
  from typing import List
 
 
5
  from ultralytics import YOLOWorld
6
 
7
  class YoloWorld:
@@ -27,4 +30,54 @@ class YoloWorld:
27
  }
28
  object_details.append(object_data)
29
  return object_details
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
1
  import os
2
  import logging
3
+ import cv2
4
  import numpy as np
5
  from typing import List
6
+ import torch
7
+ import random
8
  from ultralytics import YOLOWorld
9
 
10
  class YoloWorld:
 
30
  }
31
  object_details.append(object_data)
32
  return object_details
33
+
34
+ @staticmethod
35
+ def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
36
+ rgb_frame = cv2.imread(rgb_frame)
37
+ rgb_frame = cv2.cvtColor(rgb_frame,cv2.COLOR_BGR2RGB)
38
+
39
+ tl = line_thickness or round(0.002 * (rgb_frame.shape[0] + rgb_frame.shape[1]) / 2) + 1 # line/font thickness
40
+ rgb_frame_copy = rgb_frame.copy()
41
+ if color is None :
42
+ color = color or [random.randint(0, 255) for _ in range(3)]
43
+ for box,label in zip(boxes,labels):
44
+ if box.type() == 'torch.IntTensor':
45
+ box = box.numpy()
46
+ # extract coordinates
47
+ x1,y1,x2,y2 = box
48
+ c1,c2 = (x1,y1),(x2,y2)
49
+ # Draw rectangle
50
+ cv2.rectangle(rgb_frame_copy, c1,c2, color, thickness=tl, lineType=cv2.LINE_AA)
51
+
52
+ tf = max(tl - 1, 1) # font thickness
53
+ # label = label2id[int(label.numpy())]
54
+ t_size = cv2.getTextSize(str(label), 0, fontScale=tl / 3, thickness=tf)[0]
55
+ c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
56
+ cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
57
+ return rgb_frame_copy
58
+
59
+ def run_yolo_infer(self,image_path:str,object_prompts:List):
60
+ self.model.set_classes(object_prompts)
61
+ results = self.model.predict(image_path)
62
+ processed_predictions = []
63
+ bounding_boxes = []
64
+ labels = []
65
+ scores = []
66
+ for result in results:
67
+ for i,box in enumerate(result.boxes):
68
+ x1, y1, x2, y2 = np.array(box.xyxy.cpu(), dtype=np.int32).squeeze()
69
+ bounding_boxes.append([x1,y1,x2,y2])
70
+ labels.append(int(box.cls.cpu()))
71
+ scores.append(round(float(box.conf.cpu()),2))
72
+
73
+ processed_predictions.append(dict(boxes= torch.tensor(bounding_boxes),
74
+ labels= torch.IntTensor(labels),
75
+ scores=torch.tensor(scores))
76
+ )
77
+ detected_image = self.draw_bboxes(rgb_frame=image_path,
78
+ boxes=processed_predictions[0]['boxes'],
79
+ labels=processed_predictions[0]['labels']
80
+ )
81
+ cv2.imwrite('final_mask.jpg',detected_image)
82
+ return "Predicted image : final_mask.jpg . Details :{}".format(processed_predictions[0])
83
 
utils.py CHANGED
@@ -3,6 +3,8 @@ import matplotlib.pyplot as plt
3
  import matplotlib.patches as mpatches
4
  from matplotlib import cm
5
  import torch
 
 
6
 
7
  def draw_panoptic_segmentation(model,segmentation, segments_info):
8
  # get the used color map
@@ -23,4 +25,29 @@ def draw_panoptic_segmentation(model,segmentation, segments_info):
23
 
24
  # ax.legend(handles=handles)
25
  fig.savefig('final_mask.png')
26
- return 'final_mask.png'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import matplotlib.patches as mpatches
4
  from matplotlib import cm
5
  import torch
6
+ import cv2
7
+ import random
8
 
9
  def draw_panoptic_segmentation(model,segmentation, segments_info):
10
  # get the used color map
 
25
 
26
  # ax.legend(handles=handles)
27
  fig.savefig('final_mask.png')
28
+ return 'final_mask.png'
29
+
30
+
31
+ def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
32
+ rgb_frame = cv2.imread(rgb_frame)
33
+ rgb_frame = cv2.cvtColor(rgb_frame,cv2.COLOR_BGR2RGB)
34
+
35
+ tl = line_thickness or round(0.002 * (rgb_frame.shape[0] + rgb_frame.shape[1]) / 2) + 1 # line/font thickness
36
+ rgb_frame_copy = rgb_frame.copy()
37
+ if color is None :
38
+ color = color or [random.randint(0, 255) for _ in range(3)]
39
+ for box,label in zip(boxes,labels):
40
+ if box.type() == 'torch.IntTensor':
41
+ box = box.numpy()
42
+ # extract coordinates
43
+ x1,y1,x2,y2 = box
44
+ c1,c2 = (x1,y1),(x2,y2)
45
+ # Draw rectangle
46
+ cv2.rectangle(rgb_frame_copy, c1,c2, color, thickness=tl, lineType=cv2.LINE_AA)
47
+
48
+ tf = max(tl - 1, 1) # font thickness
49
+ # label = label2id[int(label.numpy())]
50
+ t_size = cv2.getTextSize(str(label), 0, fontScale=tl / 3, thickness=tf)[0]
51
+ c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
52
+ cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
53
+ return rgb_frame_copy