This commit is contained in:
Glenn Jocher 2019-08-17 17:10:57 +02:00
parent b72fb74ad0
commit 926447e8c4
2 changed files with 9 additions and 4 deletions

View File

@ -155,9 +155,14 @@ class YOLOLayer(nn.Module):
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method # io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method
io[..., :4] *= self.stride io[..., :4] *= self.stride
io[..., 4:] = torch.sigmoid(io[..., 4:]) # p_conf, p_cls arc = 'normal' # (normal, uCE uBCE) architecture types
# io[..., 4:] = F.softmax(io[..., 4:], dim=4) # unified detection CE if arc == 'normal':
# io[..., 4] = io[..., 5:].max(4)[0] # unified detection BCE io[..., 4:] = torch.sigmoid(io[..., 4:])
elif arc == 'uCE':
io[..., 4:] = F.softmax(io[..., 4:], dim=4) # unified detection CE
io[..., 4] = 1
elif arc == 'uBCE':
io[..., 4] = io[..., 5:].max(4)[0] # unified detection BCE
if self.nc == 1: if self.nc == 1:
io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235 io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235

View File

@ -63,7 +63,7 @@ def labels_to_class_weights(labels, nc=80):
# Prepend gridpoint count (for uCE trianing) # Prepend gridpoint count (for uCE trianing)
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
# weights = np.hstack([gpi * ni, weights]) # prepend gridpoints to start # weights = np.hstack([gpi * ni - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
weights[weights == 0] = 1 # replace empty bins with 1 weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class weights = 1 / weights # number of targets per class