This commit is contained in:
Glenn Jocher 2019-02-09 22:14:07 +01:00
parent 1cd907c59b
commit f908f845ae
1 changed files with 10 additions and 5 deletions

View File

@ -45,7 +45,7 @@ def create_modules(module_defs):
elif module_def['type'] == 'upsample':
# upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') # WARNING: deprecated
upsample = Upsample(scale_factor=int(module_def['stride']), mode='nearest')
upsample = Upsample(scale_factor=int(module_def['stride']))
modules.add_module('upsample_%d' % i, upsample)
elif module_def['type'] == 'route':
@ -131,6 +131,7 @@ class YOLOLayer(nn.Module):
self.loss_means = torch.ones(6)
self.yolo_layer = anchor_idxs[0] / nA # 2, 1, 0
self.stride = stride
self.nG = nG
if ONNX_EXPORT: # use fully populated and reshaped tensors
self.anchor_w = self.anchor_w.repeat((1, 1, nG, nG)).view(1, -1, 1)
@ -142,8 +143,8 @@ class YOLOLayer(nn.Module):
def forward(self, p, targets=None, batch_report=False, var=None):
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
bs = p.shape[0] # batch size
nG = p.shape[2] # number of grid points
bs = 1 if ONNX_EXPORT else p.shape[0] # batch size
nG = self.nG # number of grid points
if p.is_cuda and not self.weights.is_cuda:
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
@ -285,6 +286,9 @@ class Darknet(nn.Module):
x = module(x)
elif module_def['type'] == 'route':
layer_i = [int(x) for x in module_def['layers'].split(',')]
if len(layer_i) == 1:
x = layer_outputs[layer_i[0]]
else:
x = torch.cat([layer_outputs[i] for i in layer_i], 1)
elif module_def['type'] == 'shortcut':
layer_i = int(module_def['from'])
@ -328,7 +332,8 @@ class Darknet(nn.Module):
if ONNX_EXPORT:
# Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet)
output = output[0] # first layer reshaped to 85 x 507
output = output[1] # first layer reshaped to 85 x 507
# output = torch.cat(output, 1)
return output[5:85].t(), output[:4].t() # ONNX scores, boxes
return sum(output) if is_training else torch.cat(output, 1)