updates
This commit is contained in:
parent
6130b70fe7
commit
204594f299
20
models.py
20
models.py
|
@ -169,7 +169,6 @@ class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
||||||
class YOLOLayer(nn.Module):
|
class YOLOLayer(nn.Module):
|
||||||
def __init__(self, anchors, nc, img_size, yolo_index, arc):
|
def __init__(self, anchors, nc, img_size, yolo_index, arc):
|
||||||
super(YOLOLayer, self).__init__()
|
super(YOLOLayer, self).__init__()
|
||||||
|
|
||||||
self.anchors = torch.Tensor(anchors)
|
self.anchors = torch.Tensor(anchors)
|
||||||
self.na = len(anchors) # number of anchors (3)
|
self.na = len(anchors) # number of anchors (3)
|
||||||
self.nc = nc # number of classes (80)
|
self.nc = nc # number of classes (80)
|
||||||
|
@ -213,27 +212,12 @@ class YOLOLayer(nn.Module):
|
||||||
return p_cls, xy * ng, wh
|
return p_cls, xy * ng, wh
|
||||||
|
|
||||||
else: # inference
|
else: # inference
|
||||||
# s = 1.5 # scale_xy (pxy = pxy * s - (s - 1) / 2)
|
|
||||||
io = p.clone() # inference output
|
io = p.clone() # inference output
|
||||||
io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid_xy # xy
|
io[..., :2] = torch.sigmoid(io[..., :2]) + self.grid_xy # xy
|
||||||
io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method
|
io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method
|
||||||
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method
|
|
||||||
io[..., :4] *= self.stride
|
io[..., :4] *= self.stride
|
||||||
|
torch.sigmoid_(io[..., 4:])
|
||||||
if 'default' in self.arc: # seperate obj and cls
|
return io.view(bs, -1, self.no), p # view [1, 3, 13, 13, 85] as [1, 507, 85]
|
||||||
torch.sigmoid_(io[..., 4:])
|
|
||||||
elif 'BCE' in self.arc: # unified BCE (80 classes)
|
|
||||||
torch.sigmoid_(io[..., 5:])
|
|
||||||
io[..., 4] = 1
|
|
||||||
elif 'CE' in self.arc: # unified CE (1 background + 80 classes)
|
|
||||||
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
|
|
||||||
io[..., 4] = 1
|
|
||||||
|
|
||||||
if self.nc == 1:
|
|
||||||
io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235
|
|
||||||
|
|
||||||
# reshape from [1, 3, 13, 13, 85] to [1, 507, 85]
|
|
||||||
return io.view(bs, -1, self.no), p
|
|
||||||
|
|
||||||
|
|
||||||
class Darknet(nn.Module):
|
class Darknet(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue