updates
This commit is contained in:
		
							parent
							
								
									935bbfcc2b
								
							
						
					
					
						commit
						d92b75aec8
					
				|  | @ -86,7 +86,7 @@ def detect(save_txt=False, save_img=False): | ||||||
|             pred = pred.float() |             pred = pred.float() | ||||||
| 
 | 
 | ||||||
|         # Apply NMS |         # Apply NMS | ||||||
|         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres) |         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes) | ||||||
| 
 | 
 | ||||||
|         # Apply Classifier |         # Apply Classifier | ||||||
|         if classify: |         if classify: | ||||||
|  | @ -110,9 +110,6 @@ def detect(save_txt=False, save_img=False): | ||||||
|                     n = (det[:, -1] == c).sum()  # detections per class |                     n = (det[:, -1] == c).sum()  # detections per class | ||||||
|                     s += '%g %ss, ' % (n, names[int(c)])  # add to string |                     s += '%g %ss, ' % (n, names[int(c)])  # add to string | ||||||
| 
 | 
 | ||||||
|                 # Print time (inference + NMS) |  | ||||||
|                 print('%sDone. (%.3fs)' % (s, time.time() - t)) |  | ||||||
| 
 |  | ||||||
|                 # Write results |                 # Write results | ||||||
|                 for *xyxy, conf, cls in det: |                 for *xyxy, conf, cls in det: | ||||||
|                     if save_txt:  # Write to file |                     if save_txt:  # Write to file | ||||||
|  | @ -123,6 +120,9 @@ def detect(save_txt=False, save_img=False): | ||||||
|                         label = '%s %.2f' % (names[int(cls)], conf) |                         label = '%s %.2f' % (names[int(cls)], conf) | ||||||
|                         plot_one_box(xyxy, im0, label=label, color=colors[int(cls)]) |                         plot_one_box(xyxy, im0, label=label, color=colors[int(cls)]) | ||||||
| 
 | 
 | ||||||
|  |             # Print time (inference + NMS) | ||||||
|  |             print('%sDone. (%.3fs)' % (s, time.time() - t)) | ||||||
|  | 
 | ||||||
|             # Stream results |             # Stream results | ||||||
|             if view_img: |             if view_img: | ||||||
|                 cv2.imshow(p, im0) |                 cv2.imshow(p, im0) | ||||||
|  | @ -167,6 +167,7 @@ if __name__ == '__main__': | ||||||
|     parser.add_argument('--half', action='store_true', help='half precision FP16 inference') |     parser.add_argument('--half', action='store_true', help='half precision FP16 inference') | ||||||
|     parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu') |     parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu') | ||||||
|     parser.add_argument('--view-img', action='store_true', help='display results') |     parser.add_argument('--view-img', action='store_true', help='display results') | ||||||
|  |     parser.add_argument('--classes', nargs='+', type=int, help='filter by class') | ||||||
|     opt = parser.parse_args() |     opt = parser.parse_args() | ||||||
|     print(opt) |     print(opt) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -498,7 +498,7 @@ def build_targets(model, targets): | ||||||
|     return tcls, tbox, indices, av |     return tcls, tbox, indices, av | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch'): | def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=True, method='vision_batch', classes=None): | ||||||
|     """ |     """ | ||||||
|     Removes detections with lower object confidence score than 'conf_thres' |     Removes detections with lower object confidence score than 'conf_thres' | ||||||
|     Non-Maximum Suppression to further filter detections. |     Non-Maximum Suppression to further filter detections. | ||||||
|  | @ -537,6 +537,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru | ||||||
|             conf, j = pred[:, 5:].max(1) |             conf, j = pred[:, 5:].max(1) | ||||||
|             pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) |             pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) | ||||||
| 
 | 
 | ||||||
|  |         # Filter by class | ||||||
|  |         if classes: | ||||||
|  |             pred = pred[(j.view(-1, 1) == torch.Tensor(classes)).any(1)] | ||||||
|  | 
 | ||||||
|         # Apply finite constraint |         # Apply finite constraint | ||||||
|         if not torch.isfinite(pred).all(): |         if not torch.isfinite(pred).all(): | ||||||
|             pred = pred[torch.isfinite(pred).all(1)] |             pred = pred[torch.isfinite(pred).all(1)] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue