This commit is contained in:
Glenn Jocher 2019-01-02 16:32:38 +01:00
parent 7283f26f6f
commit b181c61f4b
3 changed files with 37 additions and 37 deletions

View File

@ -24,7 +24,7 @@ def detect(
device = torch_utils.select_device() device = torch_utils.select_device()
print("Using device: \"{}\"".format(device)) print("Using device: \"{}\"".format(device))
# os.system('rm -rf ' + output) os.system('rm -rf ' + output)
os.makedirs(output, exist_ok=True) os.makedirs(output, exist_ok=True)
data_config = parse_data_config(data_config_path) data_config = parse_data_config(data_config_path)
@ -66,6 +66,7 @@ def detect(
# Get detections # Get detections
with torch.no_grad(): with torch.no_grad():
# cv2.imwrite('zidane_416.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # letterboxed
img = torch.from_numpy(img).unsqueeze(0).to(device) img = torch.from_numpy(img).unsqueeze(0).to(device)
# pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True); return # ONNX export # pred = torch.onnx._export(model, img, 'weights/model.onnx', verbose=True); return # ONNX export
pred = model(img) pred = model(img)

View File

@ -89,7 +89,7 @@ class Upsample(torch.nn.Module):
self.mode = mode self.mode = mode
def forward(self, x): def forward(self, x):
return nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
class YOLOLayer(nn.Module): class YOLOLayer(nn.Module):
@ -120,9 +120,10 @@ class YOLOLayer(nn.Module):
nG = int(self.img_dim / stride) # number grid points nG = int(self.img_dim / stride) # number grid points
self.grid_x = torch.arange(nG).repeat(nG, 1).view([1, 1, nG, nG]).float() self.grid_x = torch.arange(nG).repeat(nG, 1).view([1, 1, nG, nG]).float()
self.grid_y = torch.arange(nG).repeat(nG, 1).t().view([1, 1, nG, nG]).float() self.grid_y = torch.arange(nG).repeat(nG, 1).t().view([1, 1, nG, nG]).float()
self.scaled_anchors = torch.FloatTensor([(a_w / stride, a_h / stride) for a_w, a_h in anchors]) self.grid_y = torch.arange(nG).repeat(nG, 1).t().view([1, 1, nG, nG]).float()
self.anchor_w = self.scaled_anchors[:, 0:1].view((1, nA, 1, 1)) self.anchor_wh = torch.FloatTensor([(a_w / stride, a_h / stride) for a_w, a_h in anchors]) # scale anchors
self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1)) self.anchor_w = self.anchor_wh[:, 0:1].view((1, nA, 1, 1))
self.anchor_h = self.anchor_wh[:, 1:2].view((1, nA, 1, 1))
self.weights = class_weights() self.weights = class_weights()
self.loss_means = torch.ones(6) self.loss_means = torch.ones(6)
@ -177,7 +178,7 @@ class YOLOLayer(nn.Module):
gy + height / 2), 4) # x1y1x2y2 gy + height / 2), 4) # x1y1x2y2
tx, ty, tw, th, mask, tcls, TP, FP, FN, TC = \ tx, ty, tw, th, mask, tcls, TP, FP, FN, TC = \
build_targets(p_boxes, p_conf, p_cls, targets, self.scaled_anchors, self.nA, self.nC, nG, batch_report) build_targets(p_boxes, p_conf, p_cls, targets, self.anchor_wh, self.nA, self.nC, nG, batch_report)
tcls = tcls[mask] tcls = tcls[mask]
if x.is_cuda: if x.is_cuda:
@ -319,8 +320,8 @@ class Darknet(nn.Module):
if ONNX_export: if ONNX_export:
# Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet) # Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet)
output = output[0].squeeze().transpose(0, 1) # first layer reshaped to 85 x 507 output = output[0].squeeze().transpose(0, 1) # first layer reshaped to 85 x 507
output[5:] = torch.nn.functional.softmax(torch.sigmoid(output[5:]) * output[4:5], dim=0) # SSD-like conf output[5:85] = F.softmax(output[5:85], dim=0) * output[4:5] # SSD-like conf
return output[5:], output[:4] # ONNX scores, boxes return output[5:85], output[:4] # ONNX scores, boxes
return sum(output) if is_training else torch.cat(output, 1) return sum(output) if is_training else torch.cat(output, 1)

View File

@ -309,8 +309,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# cross-class NMS (experimental) # cross-class NMS (experimental)
cross_class_nms = False cross_class_nms = False
if cross_class_nms: if cross_class_nms:
# thresh = 0.85
thresh = nms_thres
a = pred.clone() a = pred.clone()
_, indices = torch.sort(-a[:, 4], 0) # sort best to worst _, indices = torch.sort(-a[:, 4], 0) # sort best to worst
a = a[indices] a = a[indices]
@ -325,7 +323,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
if len(close) > 0: if len(close) > 0:
close = close + i + 1 close = close + i + 1
iou = bbox_iou(a[i:i + 1, :4], a[close.squeeze(), :4].reshape(-1, 4), x1y1x2y2=False) iou = bbox_iou(a[i:i + 1, :4], a[close.squeeze(), :4].reshape(-1, 4), x1y1x2y2=False)
bad = close[iou > thresh] bad = close[iou > nms_thres]
if len(bad) > 0: if len(bad) > 0:
mask = torch.ones(len(a)).type(torch.ByteTensor) mask = torch.ones(len(a)).type(torch.ByteTensor)
@ -333,13 +331,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
a = a[mask] a = a[mask]
pred = a pred = a
x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3] # Experiment: Prior class size rejection
a = w * h # area # x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
ar = w / (h + 1e-16) # aspect ratio # a = w * h # area
# ar = w / (h + 1e-16) # aspect ratio
log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
# n = len(w) # n = len(w)
# log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
# shape_likelihood = np.zeros((n, 60), dtype=np.float32) # shape_likelihood = np.zeros((n, 60), dtype=np.float32)
# x = np.concatenate((log_w.reshape(-1, 1), log_h.reshape(-1, 1)), 1) # x = np.concatenate((log_w.reshape(-1, 1), log_h.reshape(-1, 1)), 1)
# from scipy.stats import multivariate_normal # from scipy.stats import multivariate_normal
@ -348,7 +345,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1) class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
v = ((pred[:, 4] > conf_thres) & (class_prob > .3)) v = ((pred[:, 4] > conf_thres) & (class_prob > .3)) # TODO examine arbitrary 0.3 thres here
v = v.nonzero().squeeze() v = v.nonzero().squeeze()
if len(v.shape) == 0: if len(v.shape) == 0:
v = v.unsqueeze(0) v = v.unsqueeze(0)
@ -375,44 +372,43 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
nms_style = 'OR' # 'AND' or 'OR' (classical) nms_style = 'OR' # 'AND' or 'OR' (classical)
for c in unique_labels: for c in unique_labels:
# Get the detections with the particular class # Get the detections with the particular class
detections_class = detections[detections[:, -1] == c] det_class = detections[detections[:, -1] == c]
# Sort the detections by maximum objectness confidence # Sort the detections by maximum objectness confidence
_, conf_sort_index = torch.sort(detections_class[:, 4], descending=True) _, conf_sort_index = torch.sort(det_class[:, 4], descending=True)
detections_class = detections_class[conf_sort_index] det_class = det_class[conf_sort_index]
# Perform non-maximum suppression # Perform non-maximum suppression
max_detections = [] det_max = []
if nms_style == 'OR': # Classical NMS if nms_style == 'OR': # Classical NMS
while detections_class.shape[0]: while det_class.shape[0]:
# Get detection with highest confidence and save as max detection # Get detection with highest confidence and save as max detection
max_detections.append(detections_class[0].unsqueeze(0)) det_max.append(det_class[0].unsqueeze(0))
# Stop if we're at the last detection # Stop if we're at the last detection
if len(detections_class) == 1: if len(det_class) == 1:
break break
# Get the IOUs for all boxes with lower confidence # Get the IOUs for all boxes with lower confidence
ious = bbox_iou(max_detections[-1], detections_class[1:]) ious = bbox_iou(det_max[-1], det_class[1:])
# Remove detections with IoU >= NMS threshold # Remove detections with IoU >= NMS threshold
detections_class = detections_class[1:][ious < nms_thres] det_class = det_class[1:][ious < nms_thres]
elif nms_style == 'AND': # 'AND'-style NMS, at least two boxes must share commonality to pass, single boxes erased elif nms_style == 'AND': # 'AND'-style NMS: >=2 boxes must share commonality to pass, single boxes erased
while detections_class.shape[0]: while det_class.shape[0]:
if len(detections_class) == 1: if len(det_class) == 1:
break break
ious = bbox_iou(detections_class[:1], detections_class[1:]) ious = bbox_iou(det_class[:1], det_class[1:])
if ious.max() > 0.5: if ious.max() > 0.5:
max_detections.append(detections_class[0].unsqueeze(0)) det_max.append(det_class[0].unsqueeze(0))
# Remove detections with IoU >= NMS threshold # Remove detections with IoU >= NMS threshold
detections_class = detections_class[1:][ious < nms_thres] det_class = det_class[1:][ious < nms_thres]
if len(max_detections) > 0: if len(det_max) > 0:
max_detections = torch.cat(max_detections).data det_max = torch.cat(det_max).data
# Add max detections to outputs # Add max detections to outputs
output[image_i] = max_detections if output[image_i] is None else torch.cat( output[image_i] = det_max if output[image_i] is None else torch.cat((output[image_i], det_max))
(output[image_i], max_detections))
return output return output
@ -426,6 +422,7 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
def coco_class_count(path='../coco/labels/train2014/'): def coco_class_count(path='../coco/labels/train2014/'):
# histogram of occurrences per class
import glob import glob
nC = 80 # number classes nC = 80 # number classes
@ -443,6 +440,7 @@ def plot_results():
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt') # import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt')
plt.figure(figsize=(16, 8)) plt.figure(figsize=(16, 8))
s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP'] s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
files = sorted(glob.glob('results*.txt')) files = sorted(glob.glob('results*.txt'))