updates
This commit is contained in:
parent
2c2d7bc63b
commit
a2ad00d6fc
7
train.py
7
train.py
|
@ -85,8 +85,9 @@ def train(
|
||||||
# Set scheduler
|
# Set scheduler
|
||||||
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
||||||
|
|
||||||
model_info(model)
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
model_info(model)
|
||||||
|
n_burnin = min(dataloader.nB / 5, 1000) # number of burn-in batches
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
epoch += start_epoch
|
epoch += start_epoch
|
||||||
|
|
||||||
|
@ -118,8 +119,8 @@ def train(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# SGD burn-in
|
# SGD burn-in
|
||||||
if (epoch == 0) & (i <= 1000):
|
if (epoch == 0) & (i <= n_burnin):
|
||||||
lr = lr0 * (i / 1000) ** 4
|
lr = lr0 * (i / n_burnin) ** 4
|
||||||
for g in optimizer.param_groups:
|
for g in optimizer.param_groups:
|
||||||
g['lr'] = lr
|
g['lr'] = lr
|
||||||
|
|
||||||
|
|
|
@ -379,7 +379,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
||||||
if prediction.is_cuda:
|
if prediction.is_cuda:
|
||||||
unique_labels = unique_labels.cuda(prediction.device)
|
unique_labels = unique_labels.cuda(prediction.device)
|
||||||
|
|
||||||
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||||
for c in unique_labels:
|
for c in unique_labels:
|
||||||
# Get the detections with class c
|
# Get the detections with class c
|
||||||
dc = detections[detections[:, -1] == c]
|
dc = detections[detections[:, -1] == c]
|
||||||
|
|
Loading…
Reference in New Issue