updates
This commit is contained in:
parent 573e8c2840
commit 5580694970
train.py | 8
@@ -143,15 +143,15 @@ def train(
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
 
     # Start training
-    t, t0 = time.time(), time.time()
     model.hyp = hyp  # attach hyperparameters to model
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
     model_info(model)
     nb = len(dataloader)
     results = (0, 0, 0, 0, 0)  # P, R, mAP, F1, test_loss
     n_burnin = min(round(nb / 5 + 1), 1000)  # burn-in batches
-    os.remove('train_batch0.jpg') if os.path.exists('train_batch0.jpg') else None
-    os.remove('test_batch0.jpg') if os.path.exists('test_batch0.jpg') else None
+    for f in glob.glob('train_batch*.jpg') + glob.glob('test_batch*.jpg'):
+        os.remove(f)
+    t, t0 = time.time(), time.time()
     for epoch in range(start_epoch, epochs):
         model.train()
         print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time'))
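Two side notes on this hunk. The new cleanup loop assumes glob is imported at the top of train.py. And since amp.initialize(..., opt_level='O1') wraps the model and optimizer here, the backward pass elsewhere in train() is normally routed through NVIDIA Apex's amp.scale_loss so gradients are computed on a scaled loss. The following is only a sketch of that standard Apex pattern, not code from this commit; amp_train_step and compute_loss are hypothetical names, and the model/optimizer/batch arguments stand in for whatever the surrounding scope provides.

from apex import amp

def amp_train_step(model, optimizer, compute_loss, imgs, targets):
    """One mixed-precision step; assumes amp.initialize() already wrapped model and optimizer."""
    loss = compute_loss(model(imgs), targets)   # hypothetical loss helper
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:  # scale loss for FP16-safe gradients
        scaled_loss.backward()
    optimizer.step()
    return loss.item()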
@@ -282,7 +282,7 @@ if __name__ == '__main__':
     parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
     parser.add_argument('--resume', action='store_true', help='resume training flag')
     parser.add_argument('--transfer', action='store_true', help='transfer learning flag')
-    parser.add_argument('--num-workers', type=int, default=2, help='number of Pytorch DataLoader workers')
+    parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers')
     parser.add_argument('--dist-url', default='tcp://127.0.0.1:9999', type=str, help='distributed training init method')
     parser.add_argument('--rank', default=0, type=int, help='distributed training node rank')
     parser.add_argument('--world-size', default=1, type=int, help='number of nodes for distributed training')
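For context (an assumption, not shown in this diff): the parsed --num-workers value typically feeds the num_workers argument of torch.utils.data.DataLoader, so raising the default from 2 to 4 spends more CPU worker processes on image loading and augmentation, which helps when the GPU would otherwise sit idle waiting for batches. A minimal sketch with a hypothetical build_dataloader helper:

from torch.utils.data import DataLoader

def build_dataloader(dataset, batch_size, num_workers=4):
    """Wire the parsed --num-workers value into a PyTorch DataLoader."""
    return DataLoader(dataset,
                      batch_size=batch_size,
                      num_workers=num_workers,  # matches the new default of 4
                      shuffle=True,
                      pin_memory=True)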
|
|
|
@@ -318,7 +318,7 @@ def letterbox(img, new_shape=416, color=(127.5, 127.5, 127.5), mode='auto'):
 
     top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
     left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
-    img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_AREA)  # resized, no border
+    img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)  # resized, no border
     img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # padded square
     return img, ratio, dw, dh
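On the interpolation change itself: cv2.INTER_AREA is mainly tailored to downscaling, while cv2.INTER_LINEAR is fast and behaves consistently whether the image is scaled up or down, which suits letterboxing arbitrary input sizes. Below is only a usage sketch of the letterbox() shown above, assuming the return values are the padded image, the resize ratio, and the per-side padding dw/dh; the file path and box values are made up.

import cv2
import numpy as np

img0 = cv2.imread('example.jpg')                      # hypothetical input image
img, ratio, dw, dh = letterbox(img0, new_shape=416)   # resize + pad toward 416

# Map a detection (x1, y1, x2, y2) from the letterboxed image back to img0 by
# subtracting the padding and dividing by the resize ratio.
box = np.array([100.0, 120.0, 200.0, 260.0])          # hypothetical box in padded coords
box[[0, 2]] = (box[[0, 2]] - dw) / ratio
box[[1, 3]] = (box[[1, 3]] - dh) / ratio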
|
|