This commit is contained in:
Glenn Jocher 2019-03-05 17:10:34 +01:00
parent 2c2d7bc63b
commit a2ad00d6fc
2 changed files with 5 additions and 4 deletions

View File

@ -85,8 +85,9 @@ def train(
# Set scheduler # Set scheduler
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1) # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
model_info(model)
t0 = time.time() t0 = time.time()
model_info(model)
n_burnin = min(dataloader.nB / 5, 1000) # number of burn-in batches
for epoch in range(epochs): for epoch in range(epochs):
epoch += start_epoch epoch += start_epoch
@ -118,8 +119,8 @@ def train(
continue continue
# SGD burn-in # SGD burn-in
if (epoch == 0) & (i <= 1000): if (epoch == 0) & (i <= n_burnin):
lr = lr0 * (i / 1000) ** 4 lr = lr0 * (i / n_burnin) ** 4
for g in optimizer.param_groups: for g in optimizer.param_groups:
g['lr'] = lr g['lr'] = lr

View File

@ -379,7 +379,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
if prediction.is_cuda: if prediction.is_cuda:
unique_labels = unique_labels.cuda(prediction.device) unique_labels = unique_labels.cuda(prediction.device)
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental) nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
for c in unique_labels: for c in unique_labels:
# Get the detections with class c # Get the detections with class c
dc = detections[detections[:, -1] == c] dc = detections[detections[:, -1] == c]