diff --git a/models.py b/models.py index 29a01eaf..9e962155 100755 --- a/models.py +++ b/models.py @@ -45,7 +45,7 @@ def create_modules(module_defs): elif module_def['type'] == 'upsample': # upsample = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') # WARNING: deprecated - upsample = Upsample(scale_factor=int(module_def['stride']), mode='nearest') + upsample = Upsample(scale_factor=int(module_def['stride'])) modules.add_module('upsample_%d' % i, upsample) elif module_def['type'] == 'route': @@ -131,6 +131,7 @@ class YOLOLayer(nn.Module): self.loss_means = torch.ones(6) self.yolo_layer = anchor_idxs[0] / nA # 2, 1, 0 self.stride = stride + self.nG = nG if ONNX_EXPORT: # use fully populated and reshaped tensors self.anchor_w = self.anchor_w.repeat((1, 1, nG, nG)).view(1, -1, 1) @@ -142,8 +143,8 @@ class YOLOLayer(nn.Module): def forward(self, p, targets=None, batch_report=False, var=None): FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor - bs = p.shape[0] # batch size - nG = p.shape[2] # number of grid points + bs = 1 if ONNX_EXPORT else p.shape[0] # batch size + nG = self.nG # number of grid points if p.is_cuda and not self.weights.is_cuda: self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda() @@ -285,7 +286,10 @@ class Darknet(nn.Module): x = module(x) elif module_def['type'] == 'route': layer_i = [int(x) for x in module_def['layers'].split(',')] - x = torch.cat([layer_outputs[i] for i in layer_i], 1) + if len(layer_i) == 1: + x = layer_outputs[layer_i[0]] + else: + x = torch.cat([layer_outputs[i] for i in layer_i], 1) elif module_def['type'] == 'shortcut': layer_i = int(module_def['from']) x = layer_outputs[-1] + layer_outputs[layer_i] @@ -328,7 +332,8 @@ class Darknet(nn.Module): if ONNX_EXPORT: # Produce a single-layer *.onnx model (upsample ops not working in PyTorch 1.0 export yet) - output = output[0] # first layer reshaped to 85 x 507 + output = output[1] # first layer reshaped to 85 x 507 + # output = torch.cat(output, 1) return output[5:85].t(), output[:4].t() # ONNX scores, boxes return sum(output) if is_training else torch.cat(output, 1)