updates
This commit is contained in:
parent
feeaf734f2
commit
0bd763f528
12
train.py
12
train.py
|
@ -72,6 +72,8 @@ def train(
|
||||||
p.requires_grad = True if p.shape[0] == nf else False
|
p.requires_grad = True if p.shape[0] == nf else False
|
||||||
|
|
||||||
else: # resume from latest.pt
|
else: # resume from latest.pt
|
||||||
|
if opt.bucket:
|
||||||
|
os.system('gsutil cp gs://%s/latest.pt %s' % (opt.bucket, latest)) # download from bucket
|
||||||
chkpt = torch.load(latest, map_location=device) # load checkpoint
|
chkpt = torch.load(latest, map_location=device) # load checkpoint
|
||||||
model.load_state_dict(chkpt['model'])
|
model.load_state_dict(chkpt['model'])
|
||||||
|
|
||||||
|
@ -262,6 +264,8 @@ def train(
|
||||||
|
|
||||||
# Save latest checkpoint
|
# Save latest checkpoint
|
||||||
torch.save(chkpt, latest)
|
torch.save(chkpt, latest)
|
||||||
|
if opt.bucket:
|
||||||
|
os.system('gsutil cp %s gs://%s' % (latest, opt.bucket)) # upload to bucket
|
||||||
|
|
||||||
# Save best checkpoint
|
# Save best checkpoint
|
||||||
if best_fitness == fitness:
|
if best_fitness == fitness:
|
||||||
|
@ -284,11 +288,11 @@ def print_mutation(hyp, results):
|
||||||
c = '%11.3g' * len(results) % results # results (P, R, mAP, F1, test_loss)
|
c = '%11.3g' * len(results) % results # results (P, R, mAP, F1, test_loss)
|
||||||
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
|
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
|
||||||
|
|
||||||
if opt.cloud:
|
if opt.bucket:
|
||||||
os.system('gsutil cp gs://yolov4/evolve.txt .') # download evolve.txt
|
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt
|
||||||
with open('evolve.txt', 'a') as f: # append result
|
with open('evolve.txt', 'a') as f: # append result
|
||||||
f.write(c + b + '\n')
|
f.write(c + b + '\n')
|
||||||
os.system('gsutil cp evolve.txt gs://yolov4') # upload evolve.txt
|
os.system('gsutil cp evolve.txt gs://%s' % opt.bucket) # upload evolve.txt
|
||||||
else:
|
else:
|
||||||
with open('evolve.txt', 'a') as f:
|
with open('evolve.txt', 'a') as f:
|
||||||
f.write(c + b + '\n')
|
f.write(c + b + '\n')
|
||||||
|
@ -311,7 +315,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
||||||
parser.add_argument('--xywh', action='store_true', help='use xywh loss instead of GIoU loss')
|
parser.add_argument('--xywh', action='store_true', help='use xywh loss instead of GIoU loss')
|
||||||
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
||||||
parser.add_argument('--cloud', action='store_true', help='train/evolve to a cloud source')
|
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
|
||||||
parser.add_argument('--var', default=0, type=int, help='debug variable')
|
parser.add_argument('--var', default=0, type=int, help='debug variable')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
Loading…
Reference in New Issue