updates
This commit is contained in:
parent
22f75469ac
commit
fdd5afa229
9
train.py
9
train.py
|
@ -359,18 +359,15 @@ if __name__ == '__main__':
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if not opt.evolve: # Train normally
|
if not opt.evolve: # Train normally
|
||||||
# Tensorboard support,
|
# Tensorboard support,
|
||||||
# start with "tensorboard --logdir=runs"
|
# start with "tensorboard --logdir=runs" then go to localhost:6006
|
||||||
# go to localhost:6006
|
|
||||||
tensorboard_support = True
|
tensorboard_support = True
|
||||||
if version_to_tuple(torch.__version__) >= version_to_tuple("1.1.0"):
|
if version_to_tuple(torch.__version__) >= version_to_tuple("1.1.0"):
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
tb_train_name = time.time()
|
tb_train_name = time.time()
|
||||||
print('here')
|
|
||||||
tb_writer = SummaryWriter('runs/{}'.format(tb_train_name))
|
tb_writer = SummaryWriter('runs/{}'.format(tb_train_name))
|
||||||
except:
|
except:
|
||||||
tensorboard_support = False
|
tensorboard_support = False
|
||||||
|
@ -381,7 +378,7 @@ if __name__ == '__main__':
|
||||||
epochs=opt.epochs,
|
epochs=opt.epochs,
|
||||||
batch_size=opt.batch_size,
|
batch_size=opt.batch_size,
|
||||||
accumulate=opt.accumulate,
|
accumulate=opt.accumulate,
|
||||||
write_to_tensorboard=tensorboard_results)
|
write_to_tensorboard=tensorboard_support)
|
||||||
|
|
||||||
else: # Evolve hyperparameters (optional)
|
else: # Evolve hyperparameters (optional)
|
||||||
opt.notest = True # only test final epoch
|
opt.notest = True # only test final epoch
|
||||||
|
|
Loading…
Reference in New Issue