changed the criteria for the best weight file (#356)

* changed the criteria for the best weight file

changed the criteria for the best weight file from loss to mAP

I trained the model on my custom dataset. But I failed to get a good results when I load the weight file that has the lowest loss on test dataset. 

I thought that the loss used in YOLO is not proper criteria for detection performance. So I changed the criteria from loss to mAP.

what do you think of this?

* Update train.py
This commit is contained in:
Yonghye Kwon 2019-07-02 19:24:18 +09:00 committed by Glenn Jocher
parent 1fd871abd8
commit ccf757b3ea
1 changed files with 6 additions and 6 deletions

View File

@ -77,7 +77,7 @@ def train(
cutoff = -1 # backbone reaches to cutoff layer cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0 start_epoch = 0
best_loss = float('inf') best_map = 0.
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255) nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
if opt.resume or opt.transfer: # Load previously saved model if opt.resume or opt.transfer: # Load previously saved model
if opt.transfer: # Transfer learning if opt.transfer: # Transfer learning
@ -256,17 +256,17 @@ def train(
with open('results.txt', 'a') as file: with open('results.txt', 'a') as file:
file.write(s + '%11.3g' * 5 % results + '\n') # P, R, mAP, F1, test_loss file.write(s + '%11.3g' * 5 % results + '\n') # P, R, mAP, F1, test_loss
# Update best loss # Update best map
test_loss = results[4] test_map = results[2]
if test_loss < best_loss: if test_map > best_map:
best_loss = test_loss best_map = test_map
# Save training results # Save training results
save = (not opt.nosave) or (epoch == epochs - 1) save = (not opt.nosave) or (epoch == epochs - 1)
if save: if save:
# Create checkpoint # Create checkpoint
chkpt = {'epoch': epoch, chkpt = {'epoch': epoch,
'best_loss': best_loss, 'best_map': best_map,
'model': model.module.state_dict() if type( 'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(), model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()} 'optimizer': optimizer.state_dict()}