updates
This commit is contained in:
parent
058bb7f38d
commit
e2a8f5bdce
19
models.py
19
models.py
|
@ -37,18 +37,18 @@ def create_modules(module_defs):
|
||||||
modules.add_module('upsample_%d' % i, upsample)
|
modules.add_module('upsample_%d' % i, upsample)
|
||||||
|
|
||||||
elif module_def['type'] == 'route':
|
elif module_def['type'] == 'route':
|
||||||
layers = [int(x) for x in module_def["layers"].split(',')]
|
layers = [int(x) for x in module_def['layers'].split(',')]
|
||||||
filters = sum([output_filters[layer_i] for layer_i in layers])
|
filters = sum([output_filters[layer_i] for layer_i in layers])
|
||||||
modules.add_module('route_%d' % i, EmptyLayer())
|
modules.add_module('route_%d' % i, EmptyLayer())
|
||||||
|
|
||||||
elif module_def['type'] == 'shortcut':
|
elif module_def['type'] == 'shortcut':
|
||||||
filters = output_filters[int(module_def['from'])]
|
filters = output_filters[int(module_def['from'])]
|
||||||
modules.add_module("shortcut_%d" % i, EmptyLayer())
|
modules.add_module('shortcut_%d' % i, EmptyLayer())
|
||||||
|
|
||||||
elif module_def["type"] == "yolo":
|
elif module_def['type'] == 'yolo':
|
||||||
anchor_idxs = [int(x) for x in module_def["mask"].split(",")]
|
anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
|
||||||
# Extract anchors
|
# Extract anchors
|
||||||
anchors = [float(x) for x in module_def["anchors"].split(",")]
|
anchors = [float(x) for x in module_def['anchors'].split(',')]
|
||||||
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
|
anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
|
||||||
anchors = [anchors[i] for i in anchor_idxs]
|
anchors = [anchors[i] for i in anchor_idxs]
|
||||||
num_classes = int(module_def['classes'])
|
num_classes = int(module_def['classes'])
|
||||||
|
@ -72,7 +72,6 @@ class EmptyLayer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class YOLOLayer(nn.Module):
|
class YOLOLayer(nn.Module):
|
||||||
# YOLO Layer 0
|
|
||||||
|
|
||||||
def __init__(self, anchors, nC, img_dim, anchor_idxs):
|
def __init__(self, anchors, nC, img_dim, anchor_idxs):
|
||||||
super(YOLOLayer, self).__init__()
|
super(YOLOLayer, self).__init__()
|
||||||
|
@ -104,15 +103,15 @@ class YOLOLayer(nn.Module):
|
||||||
def forward(self, p, targets=None, requestPrecision=False, epoch=None):
|
def forward(self, p, targets=None, requestPrecision=False, epoch=None):
|
||||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||||
|
|
||||||
bs = p.shape[0]
|
bs = p.shape[0] # batch size
|
||||||
nG = p.shape[2]
|
nG = p.shape[2] # number of grid points
|
||||||
stride = self.img_dim / nG
|
stride = self.img_dim / nG
|
||||||
|
|
||||||
if p.is_cuda and not self.grid_x.is_cuda:
|
if p.is_cuda and not self.grid_x.is_cuda:
|
||||||
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
||||||
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
|
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
|
||||||
|
|
||||||
# x.view(4, 650, 19, 19) -- > (4, 10, 19, 19, 65) # (bs, anchors, grid, grid, classes + xywh)
|
# p.view(12, 255, 13, 13) -- > (12, 3, 13, 13, 80) # (bs, anchors, grid, grid, classes + xywh)
|
||||||
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||||
|
|
||||||
# Get outputs
|
# Get outputs
|
||||||
|
@ -255,7 +254,7 @@ def load_weights(self, weights_path):
|
||||||
"""Parses and loads the weights stored in 'weights_path'"""
|
"""Parses and loads the weights stored in 'weights_path'"""
|
||||||
|
|
||||||
# Open the weights file
|
# Open the weights file
|
||||||
fp = open(weights_path, "rb")
|
fp = open(weights_path, 'rb')
|
||||||
header = np.fromfile(fp, dtype=np.int32, count=5) # First five are header values
|
header = np.fromfile(fp, dtype=np.int32, count=5) # First five are header values
|
||||||
|
|
||||||
# Needed to write header when saving weights
|
# Needed to write header when saving weights
|
||||||
|
|
Loading…
Reference in New Issue