updates
This commit is contained in:
parent
0040c85b9a
commit
ff7f73b642
9
train.py
9
train.py
|
@ -112,9 +112,9 @@ def train(cfg,
|
||||||
cutoff = -1 # backbone reaches to cutoff layer
|
cutoff = -1 # backbone reaches to cutoff layer
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
best_fitness = 0.
|
best_fitness = 0.
|
||||||
|
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||||
if opt.resume or opt.transfer: # Load previously saved model
|
if opt.resume or opt.transfer: # Load previously saved model
|
||||||
if opt.transfer: # Transfer learning
|
if opt.transfer: # Transfer learning
|
||||||
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
|
||||||
chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
|
chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
|
||||||
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
||||||
strict=False)
|
strict=False)
|
||||||
|
@ -208,7 +208,8 @@ def train(cfg,
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for epoch in range(start_epoch, epochs):
|
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
print(('\n' + '%10s' * 9) %
|
print(('\n' + '%10s' * 9) %
|
||||||
('Epoch', 'gpu_mem', 'GIoU/xy', 'wh', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
('Epoch', 'gpu_mem', 'GIoU/xy', 'wh', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
||||||
|
@ -232,12 +233,12 @@ def train(cfg,
|
||||||
|
|
||||||
mloss = torch.zeros(5).to(device) # mean losses
|
mloss = torch.zeros(5).to(device) # mean losses
|
||||||
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
||||||
for i, (imgs, targets, paths, _) in pbar:
|
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
|
||||||
|
ni = (i + nb * epoch) # number integrated batches (since train start)
|
||||||
imgs = imgs.to(device)
|
imgs = imgs.to(device)
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
|
|
||||||
# Multi-Scale training
|
# Multi-Scale training
|
||||||
ni = (i + nb * epoch) # number integrated batches (since train start)
|
|
||||||
if multi_scale:
|
if multi_scale:
|
||||||
if ni / accumulate % 10 == 0: # adjust (67% - 150%) every 10 batches
|
if ni / accumulate % 10 == 0: # adjust (67% - 150%) every 10 batches
|
||||||
img_size = random.randrange(img_sz_min, img_sz_max + 1) * 32
|
img_size = random.randrange(img_sz_min, img_sz_max + 1) * 32
|
||||||
|
|
Loading…
Reference in New Issue