From a59350852bce0d4f7a0a1e65bce2a569b389358a Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 10 Oct 2019 22:54:20 +0200
Subject: [PATCH] updates

---
 detect.py      | 18 ++++++++++++++++--
 utils/utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/detect.py b/detect.py
index 9416c184..52990a5a 100644
--- a/detect.py
+++ b/detect.py
@@ -27,6 +27,12 @@ def detect(save_txt=False, save_img=False):
     else:  # darknet format
         _ = load_darknet_weights(model, weights)
 
+    # Second-stage classifier
+    classify = False
+    if classify:
+        modelc = torch_utils.load_classifier(name='resnet101', n=2)  # initialize
+        modelc.load_state_dict(torch.load('resnet101.pt', map_location=device)['model'])  # load weights
+
     # Fuse Conv2d + BatchNorm2d layers
     # model.fuse()
 
@@ -67,12 +73,20 @@ def detect(save_txt=False, save_img=False):
         img = torch.from_numpy(img).to(device)
         if img.ndimension() == 3:
             img = img.unsqueeze(0)
-        pred, _ = model(img)
+        pred = model(img)[0]
 
         if opt.half:
             pred = pred.float()
 
-        for i, det in enumerate(non_max_suppression(pred, opt.conf_thres, opt.nms_thres)):  # detections per image
+        # Apply NMS
+        pred = non_max_suppression(pred, opt.conf_thres, opt.nms_thres)
+
+        # Apply second-stage classifier
+        if classify:
+            pred = apply_classifier(pred, modelc, img, im0s)
+
+        # Process detections
+        for i, det in enumerate(pred):  # detections per image
             if webcam:  # batch_size >= 1
                 p, s, im0 = path[i], '%g: ' % i, im0s[i]
             else:
diff --git a/utils/utils.py b/utils/utils.py
index 322ce246..3746ca98 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -720,6 +720,49 @@ def print_mutation(hyp, results, bucket=''):
         os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt
 
 
+def apply_classifier(x, model, img, im0):
+    # Applies a second-stage classifier to YOLO outputs
+    # x: list of per-image detections (entries may be None); columns 0-3 are xyxy boxes, column 6 the class index
+    # model: classifier called on an Nx3x128x128 float32 RGB batch scaled to 0.0 - 1.0; argmax(1) is its class
+    # img: network input batch; only img.shape[2:] is used, as the source size for rescaling boxes
+    # im0: original BGR image, or a list of them (one per entry of x) as detect.py passes in webcam mode
+    # Returns x unchanged for now: the class-agreement filter at the end of the loop is still disabled
+
+    im0 = [im0] if isinstance(im0, np.ndarray) else im0  # single image to list, so im0[i] pairs with x[i]
+    for i, d in enumerate(x):  # per image
+        if d is not None and len(d):
+            d = d.clone()  # work on a copy so the boxes in x keep their original coordinates
+
+            # Reshape and pad cutouts
+            b = xyxy2xywh(d[:, :4])  # boxes
+            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
+            b[:, 2:] = b[:, 2:] * 1.0 + 0  # pad
+            d[:, :4] = xywh2xyxy(b).long()
+
+            # Rescale boxes from img_size to im0 size
+            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
+
+            # Classes
+            pred_cls1 = d[:, 6].long()
+            ims = []
+            for a in d:  # per item
+                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
+                im = cv2.resize(cutout, (128, 128))  # BGR
+
+                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x128x128
+                im = np.expand_dims(im, axis=0)  # add batch dim
+                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
+                im /= 255.0  # 0 - 255 to 0.0 - 1.0
+                ims.append(im)
+
+            ims = torch.Tensor(np.concatenate(ims, 0))  # to torch
+            pred_cls2 = model(ims).argmax(1)  # classifier prediction
+
+            # x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections
+
+    return x
+
+
 def fitness(x):
     # Returns fitness (for use with results.txt or evolve.txt)
     return x[:, 2] * 0.8 + x[:, 3] * 0.2  # weighted mAP and F1 combination