updates
This commit is contained in:
parent
058bb7f38d
commit
e2a8f5bdce
19
models.py
19
models.py
|
@ -37,18 +37,18 @@ def create_modules(module_defs):
|
|||
modules.add_module('upsample_%d' % i, upsample)
|
||||
|
||||
elif module_def['type'] == 'route':
|
||||
layers = [int(x) for x in module_def["layers"].split(',')]
|
||||
layers = [int(x) for x in module_def['layers'].split(',')]
|
||||
filters = sum([output_filters[layer_i] for layer_i in layers])
|
||||
modules.add_module('route_%d' % i, EmptyLayer())
|
||||
|
||||
elif module_def['type'] == 'shortcut':
|
||||
filters = output_filters[int(module_def['from'])]
|
||||
modules.add_module("shortcut_%d" % i, EmptyLayer())
|
||||
modules.add_module('shortcut_%d' % i, EmptyLayer())
|
||||
|
||||
elif module_def["type"] == "yolo":
|
||||
anchor_idxs = [int(x) for x in module_def["mask"].split(",")]
|
||||
elif module_def['type'] == 'yolo':
|
||||
anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
|
||||
# Extract anchors
|
||||
anchors = [float(x) for x in module_def["anchors"].split(",")]
|
||||
anchors = [float(x) for x in module_def['anchors'].split(',')]
|
||||
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
|
||||
anchors = [anchors[i] for i in anchor_idxs]
|
||||
num_classes = int(module_def['classes'])
|
||||
|
@ -72,7 +72,6 @@ class EmptyLayer(nn.Module):
|
|||
|
||||
|
||||
class YOLOLayer(nn.Module):
|
||||
# YOLO Layer 0
|
||||
|
||||
def __init__(self, anchors, nC, img_dim, anchor_idxs):
|
||||
super(YOLOLayer, self).__init__()
|
||||
|
@ -104,15 +103,15 @@ class YOLOLayer(nn.Module):
|
|||
def forward(self, p, targets=None, requestPrecision=False, epoch=None):
|
||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||
|
||||
bs = p.shape[0]
|
||||
nG = p.shape[2]
|
||||
bs = p.shape[0] # batch size
|
||||
nG = p.shape[2] # number of grid points
|
||||
stride = self.img_dim / nG
|
||||
|
||||
if p.is_cuda and not self.grid_x.is_cuda:
|
||||
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
||||
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
|
||||
|
||||
# x.view(4, 650, 19, 19) -- > (4, 10, 19, 19, 65) # (bs, anchors, grid, grid, classes + xywh)
|
||||
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
|
||||
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||
|
||||
# Get outputs
|
||||
|
@ -255,7 +254,7 @@ def load_weights(self, weights_path):
|
|||
"""Parses and loads the weights stored in 'weights_path'"""
|
||||
|
||||
# Open the weights file
|
||||
fp = open(weights_path, "rb")
|
||||
fp = open(weights_path, 'rb')
|
||||
header = np.fromfile(fp, dtype=np.int32, count=5) # First five are header values
|
||||
|
||||
# Needed to write header when saving weights
|
||||
|
|
Loading…
Reference in New Issue