diff --git a/models.py b/models.py index 804d3a4a..5b8a5207 100755 --- a/models.py +++ b/models.py @@ -81,9 +81,12 @@ def create_modules(module_defs, img_size): if arc == 'normal': bias[:, 4] -= 5.0 # obj bias[:, 5:] -= 4.0 # cls - elif arc == 'uCE': + elif arc == 'uCE': # unified CE (1 background + 80 classes) bias[:, 4] += 3.0 # obj bias[:, 5:] -= 4.0 # cls + elif arc == 'uBCE': # unified BCE (80 classes) + bias[:, 4] -= 5.0 # obj + bias[:, 5:] -= 4.0 # cls module_list[-1][0].bias = torch.nn.Parameter(bias.view(-1)) # for l in model.yolo_layers: # print pretrained biases