This commit is contained in:
Glenn Jocher 2018-09-02 12:59:39 +02:00
parent 058bb7f38d
commit e2a8f5bdce
1 changed files with 9 additions and 10 deletions

View File

@ -37,18 +37,18 @@ def create_modules(module_defs):
modules.add_module('upsample_%d' % i, upsample)
elif module_def['type'] == 'route':
layers = [int(x) for x in module_def["layers"].split(',')]
layers = [int(x) for x in module_def['layers'].split(',')]
filters = sum([output_filters[layer_i] for layer_i in layers])
modules.add_module('route_%d' % i, EmptyLayer())
elif module_def['type'] == 'shortcut':
filters = output_filters[int(module_def['from'])]
modules.add_module("shortcut_%d" % i, EmptyLayer())
modules.add_module('shortcut_%d' % i, EmptyLayer())
elif module_def["type"] == "yolo":
anchor_idxs = [int(x) for x in module_def["mask"].split(",")]
elif module_def['type'] == 'yolo':
anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
# Extract anchors
anchors = [float(x) for x in module_def["anchors"].split(",")]
anchors = [float(x) for x in module_def['anchors'].split(',')]
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
anchors = [anchors[i] for i in anchor_idxs]
num_classes = int(module_def['classes'])
@ -72,7 +72,6 @@ class EmptyLayer(nn.Module):
class YOLOLayer(nn.Module):
# YOLO Layer 0
def __init__(self, anchors, nC, img_dim, anchor_idxs):
super(YOLOLayer, self).__init__()
@ -104,15 +103,15 @@ class YOLOLayer(nn.Module):
def forward(self, p, targets=None, requestPrecision=False, epoch=None):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0]
nG = p.shape[2]
bs = p.shape[0] # batch size
nG = p.shape[2] # number of grid points
stride = self.img_dim / nG
if p.is_cuda and not self.grid_x.is_cuda:
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
# x.view(4, 650, 19, 19) -- > (4, 10, 19, 19, 65) # (bs, anchors, grid, grid, classes + xywh)
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
# Get outputs
@ -255,7 +254,7 @@ def load_weights(self, weights_path):
"""Parses and loads the weights stored in 'weights_path'"""
# Open the weights file
fp = open(weights_path, "rb")
fp = open(weights_path, 'rb')
header = np.fromfile(fp, dtype=np.int32, count=5) # First five are header values
# Needed to write header when saving weights