This commit is contained in:
Glenn Jocher 2020-03-09 16:44:26 -07:00
parent 6130b70fe7
commit 204594f299
1 changed files with 2 additions and 18 deletions

View File

@ -169,7 +169,6 @@ class Mish(nn.Module): # https://github.com/digantamisra98/Mish
class YOLOLayer(nn.Module): class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, yolo_index, arc): def __init__(self, anchors, nc, img_size, yolo_index, arc):
super(YOLOLayer, self).__init__() super(YOLOLayer, self).__init__()
self.anchors = torch.Tensor(anchors) self.anchors = torch.Tensor(anchors)
self.na = len(anchors) # number of anchors (3) self.na = len(anchors) # number of anchors (3)
self.nc = nc # number of classes (80) self.nc = nc # number of classes (80)
@ -213,27 +212,12 @@ class YOLOLayer(nn.Module):
return p_cls, xy * ng, wh return p_cls, xy * ng, wh
else: # inference else: # inference
# s = 1.5 # scale_xy (pxy = pxy * s - (s - 1) / 2)
io = p.clone() # inference output io = p.clone() # inference output
io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid_xy # xy io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid_xy # xy
io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method
io[..., :4] *= self.stride io[..., :4] *= self.stride
if 'default' in self.arc: # seperate obj and cls
torch.sigmoid_(io[..., 4:]) torch.sigmoid_(io[..., 4:])
elif 'BCE' in self.arc: # unified BCE (80 classes) return io.view(bs, -1, self.no), p # view [1, 3, 13, 13, 85] as [1, 507, 85]
torch.sigmoid_(io[..., 5:])
io[..., 4] = 1
elif 'CE' in self.arc: # unified CE (1 background + 80 classes)
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
io[..., 4] = 1
if self.nc == 1:
io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235
# reshape from [1, 3, 13, 13, 85] to [1, 507, 85]
return io.view(bs, -1, self.no), p
class Darknet(nn.Module): class Darknet(nn.Module):