changed the criteria for the best weight file (#356)
* changed the criteria for the best weight file changed the criteria for the best weight file from loss to mAP I trained the model on my custom dataset. But I failed to get a good results when I load the weight file that has the lowest loss on test dataset. I thought that the loss used in YOLO is not proper criteria for detection performance. So I changed the criteria from loss to mAP. what do you think of this? * Update train.py
This commit is contained in:
parent
1fd871abd8
commit
ccf757b3ea
12
train.py
12
train.py
|
@ -77,7 +77,7 @@ def train(
|
||||||
|
|
||||||
cutoff = -1 # backbone reaches to cutoff layer
|
cutoff = -1 # backbone reaches to cutoff layer
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
best_loss = float('inf')
|
best_map = 0.
|
||||||
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||||
if opt.resume or opt.transfer: # Load previously saved model
|
if opt.resume or opt.transfer: # Load previously saved model
|
||||||
if opt.transfer: # Transfer learning
|
if opt.transfer: # Transfer learning
|
||||||
|
@ -256,17 +256,17 @@ def train(
|
||||||
with open('results.txt', 'a') as file:
|
with open('results.txt', 'a') as file:
|
||||||
file.write(s + '%11.3g' * 5 % results + '\n') # P, R, mAP, F1, test_loss
|
file.write(s + '%11.3g' * 5 % results + '\n') # P, R, mAP, F1, test_loss
|
||||||
|
|
||||||
# Update best loss
|
# Update best map
|
||||||
test_loss = results[4]
|
test_map = results[2]
|
||||||
if test_loss < best_loss:
|
if test_map > best_map:
|
||||||
best_loss = test_loss
|
best_map = test_map
|
||||||
|
|
||||||
# Save training results
|
# Save training results
|
||||||
save = (not opt.nosave) or (epoch == epochs - 1)
|
save = (not opt.nosave) or (epoch == epochs - 1)
|
||||||
if save:
|
if save:
|
||||||
# Create checkpoint
|
# Create checkpoint
|
||||||
chkpt = {'epoch': epoch,
|
chkpt = {'epoch': epoch,
|
||||||
'best_loss': best_loss,
|
'best_map': best_map,
|
||||||
'model': model.module.state_dict() if type(
|
'model': model.module.state_dict() if type(
|
||||||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||||
'optimizer': optimizer.state_dict()}
|
'optimizer': optimizer.state_dict()}
|
||||||
|
|
Loading…
Reference in New Issue