This commit is contained in:
Glenn Jocher 2019-05-08 17:29:23 +02:00
parent 573e8c2840
commit 5580694970
2 changed files with 5 additions and 5 deletions

View File

@ -143,15 +143,15 @@ def train(
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
# Start training
t, t0 = time.time(), time.time()
model.hyp = hyp # attach hyperparameters to model
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model_info(model)
nb = len(dataloader)
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
n_burnin = min(round(nb / 5 + 1), 1000) # burn-in batches
os.remove('train_batch0.jpg') if os.path.exists('train_batch0.jpg') else None
os.remove('test_batch0.jpg') if os.path.exists('test_batch0.jpg') else None
for f in glob.glob('train_batch*.jpg') + glob.glob('test_batch*.jpg'):
os.remove(f)
t, t0 = time.time(), time.time()
for epoch in range(start_epoch, epochs):
model.train()
print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
@ -282,7 +282,7 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--transfer', action='store_true', help='transfer learning flag')
parser.add_argument('--num-workers', type=int, default=2, help='number of Pytorch DataLoader workers')
parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method')
parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')

View File

@ -318,7 +318,7 @@ def letterbox(img, new_shape=416, color=(127.5, 127.5, 127.5), mode='auto'):
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) # resized, no border
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # padded square
return img, ratio, dw, dh