diff --git a/train.py b/train.py index abd5870f..707d1580 100644 --- a/train.py +++ b/train.py @@ -53,7 +53,7 @@ def train( if resume: # Load previously saved model if transfer: # Transfer learning - chkpt = torch.load(weights + 'yolov3.pt', map_location=device) + chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device) model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255}, strict=False) for p in model.parameters(): @@ -99,9 +99,9 @@ def train( sampler=sampler) # Start training - nB = len(dataloader) t = time.time() model_info(model) + nB = len(dataloader) n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches for epoch in range(start_epoch, epochs): model.train()