diff --git a/models.py b/models.py index aee33b65..804d3a4a 100755 --- a/models.py +++ b/models.py @@ -4,6 +4,7 @@ from utils.parse_config import * from utils.utils import * ONNX_EXPORT = False +arc = 'normal' # (normal, uCE, uBCE, uBCEs) detection architectures def create_modules(module_defs, img_size): @@ -77,8 +78,12 @@ def create_modules(module_defs, img_size): # Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3) bias = module_list[-1][0].bias.view(len(mask), -1) # 255 to 3x85 - bias[:, 4] -= 5.0 # obj - bias[:, 5:] -= 4.0 # cls + if arc == 'normal': + bias[:, 4] -= 5.0 # obj + bias[:, 5:] -= 4.0 # cls + elif arc == 'uCE': + bias[:, 4] += 3.0 # obj + bias[:, 5:] -= 4.0 # cls module_list[-1][0].bias = torch.nn.Parameter(bias.view(-1)) # for l in model.yolo_layers: # print pretrained biases @@ -168,7 +173,6 @@ class YOLOLayer(nn.Module): # io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method io[..., :4] *= self.stride - arc = 'normal' # (normal, uCE, uBCE, uBCEs) detection architectures if arc == 'normal': torch.sigmoid_(io[..., 4:]) elif arc == 'uCE': # unified CE (1 background + 80 classes)