Zengyf-CVer commited on
Commit
a047b00
·
1 Parent(s): 4326e02

add inference size

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -74,6 +74,9 @@ def parse_args(known=False):
74
  type=str,
75
  help="cuda or cpu, hugging face only cpu",
76
  )
 
 
 
77
 
78
  args = parser.parse_known_args()[0] if known else parser.parse_args()
79
  return args
@@ -115,7 +118,7 @@ def export_json(results, model, img_size):
115
 
116
 
117
  # YOLOv5图片检测函数
118
- def yolo_det(img, device, model_name, conf, iou, label_opt, model_cls):
119
 
120
  global model, model_name_tmp, device_tmp
121
 
@@ -133,7 +136,7 @@ def yolo_det(img, device, model_name, conf, iou, label_opt, model_cls):
133
  model.max_det = 1000 # 最大检测框数
134
  model.classes = model_cls # 模型类别
135
 
136
- results = model(img) # 检测
137
  results.render(labels=label_opt) # 渲染
138
 
139
  det_img = Image.fromarray(results.imgs[0]) # 检测图片
@@ -178,6 +181,7 @@ def main(args):
178
  model_cfg = args.model_cfg
179
  cls_name = args.cls_name
180
  device = args.device
 
181
 
182
  # 模型加载
183
  model = model_loading(model_name, device)
@@ -193,6 +197,9 @@ def main(args):
193
  inputs_model = gr.inputs.Dropdown(
194
  choices=model_names, default=model_name, type="value", label="模型"
195
  )
 
 
 
196
  input_conf = gr.inputs.Slider(
197
  0, 1, step=slider_step, default=nms_conf, label="置信度阈值"
198
  )
@@ -209,6 +216,7 @@ def main(args):
209
  inputs_img, # 输入图片
210
  device, # 设备
211
  inputs_model, # 模型
 
212
  input_conf, # 置信度阈值
213
  inputs_iou, # IoU阈值
214
  inputs_label, # 标签显示
@@ -229,6 +237,7 @@ def main(args):
229
  "./img_example/bus.jpg",
230
  "cpu",
231
  "yolov5s",
 
232
  0.6,
233
  0.5,
234
  True,
@@ -238,6 +247,7 @@ def main(args):
238
  "./img_example/Millenial-at-work.jpg",
239
  "cpu",
240
  "yolov5l",
 
241
  0.5,
242
  0.45,
243
  True,
@@ -247,6 +257,7 @@ def main(args):
247
  "./img_example/zidane.jpg",
248
  "cpu",
249
  "yolov5m",
 
250
  0.25,
251
  0.5,
252
  False,
 
74
  type=str,
75
  help="cuda or cpu, hugging face only cpu",
76
  )
77
+ parser.add_argument(
78
+ "--inference_size", "-isz", default=640, type=int, help="model inference size"
79
+ )
80
 
81
  args = parser.parse_known_args()[0] if known else parser.parse_args()
82
  return args
 
118
 
119
 
120
  # YOLOv5图片检测函数
121
+ def yolo_det(img, device, model_name, inference_size, conf, iou, label_opt, model_cls):
122
 
123
  global model, model_name_tmp, device_tmp
124
 
 
136
  model.max_det = 1000 # 最大检测框数
137
  model.classes = model_cls # 模型类别
138
 
139
+ results = model(img, size=inference_size) # 检测
140
  results.render(labels=label_opt) # 渲染
141
 
142
  det_img = Image.fromarray(results.imgs[0]) # 检测图片
 
181
  model_cfg = args.model_cfg
182
  cls_name = args.cls_name
183
  device = args.device
184
+ inference_size = args.inference_size
185
 
186
  # 模型加载
187
  model = model_loading(model_name, device)
 
197
  inputs_model = gr.inputs.Dropdown(
198
  choices=model_names, default=model_name, type="value", label="模型"
199
  )
200
+ inputs_size = gr.inputs.Radio(
201
+ choices=[320, 640], default=inference_size, label="推理尺寸"
202
+ )
203
  input_conf = gr.inputs.Slider(
204
  0, 1, step=slider_step, default=nms_conf, label="置信度阈值"
205
  )
 
216
  inputs_img, # 输入图片
217
  device, # 设备
218
  inputs_model, # 模型
219
+ inputs_size, # 推理尺寸
220
  input_conf, # 置信度阈值
221
  inputs_iou, # IoU阈值
222
  inputs_label, # 标签显示
 
237
  "./img_example/bus.jpg",
238
  "cpu",
239
  "yolov5s",
240
+ 640,
241
  0.6,
242
  0.5,
243
  True,
 
247
  "./img_example/Millenial-at-work.jpg",
248
  "cpu",
249
  "yolov5l",
250
+ 320,
251
  0.5,
252
  0.45,
253
  True,
 
257
  "./img_example/zidane.jpg",
258
  "cpu",
259
  "yolov5m",
260
+ 640,
261
  0.25,
262
  0.5,
263
  False,