This commit is contained in:
Glenn Jocher 2020-01-21 17:23:35 -08:00
parent 6ccf19038d
commit 5d73b190b0
2 changed files with 3 additions and 11 deletions

View File

@ -81,17 +81,13 @@ def create_modules(module_defs, img_size, arc):
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3) # Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
try: try:
if arc == 'defaultpw' or arc == 'Fdefaultpw': # default with positive weights if arc == 'default' or arc == 'Fdefault': # default
b = [-5.0, -5.0] # obj, cls b = [-5.0, -5.0] # obj, cls
elif arc == 'default': # default no pw (40 cls, 80 obj)
b = [-5.0, -5.0]
elif arc == 'uBCE': # unified BCE (80 classes) elif arc == 'uBCE': # unified BCE (80 classes)
b = [0, -9.0] b = [0, -9.0]
elif arc == 'uCE': # unified CE (1 background + 80 classes) elif arc == 'uCE': # unified CE (1 background + 80 classes)
b = [10, -0.1] b = [10, -0.1]
elif arc == 'Fdefault': # Focal default no pw (28 cls, 21 obj, no pw) elif arc == 'uFBCE': # unified FocalBCE (5120 obj, 80 classes)
b = [-2.1, -1.8]
elif arc == 'uFBCE' or arc == 'uFBCEpw': # unified FocalBCE (5120 obj, 80 classes)
b = [0, -6.5] b = [0, -6.5]
elif arc == 'uFCE': # unified FocalCE (64 cls, 1 background + 80 classes) elif arc == 'uFCE': # unified FocalCE (64 cls, 1 background + 80 classes)
b = [7.7, -1.1] b = [7.7, -1.1]

View File

@ -58,10 +58,6 @@ def train():
accumulate = opt.accumulate # effective bs = batch_size * accumulate = 16 * 4 = 64 accumulate = opt.accumulate # effective bs = batch_size * accumulate = 16 * 4 = 64
weights = opt.weights # initial training weights weights = opt.weights # initial training weights
if 'pw' not in opt.arc: # remove BCELoss positive weights
hyp['cls_pw'] = 1.
hyp['obj_pw'] = 1.
# Initialize # Initialize
init_seeds() init_seeds()
if opt.multi_scale: if opt.multi_scale:
@ -413,7 +409,7 @@ if __name__ == '__main__':
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
parser.add_argument('--weights', type=str, default='weights/ultralytics68.pt', help='initial weights') parser.add_argument('--weights', type=str, default='weights/ultralytics68.pt', help='initial weights')
parser.add_argument('--arc', type=str, default='default', help='yolo architecture') # defaultpw, uCE, uBCE parser.add_argument('--arc', type=str, default='default', help='yolo architecture') # default, uCE, uBCE
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)') parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--adam', action='store_true', help='use adam optimizer') parser.add_argument('--adam', action='store_true', help='use adam optimizer')