updates
This commit is contained in:
parent
8f1becd55c
commit
d25190e15b
6
train.py
6
train.py
|
@ -64,7 +64,9 @@ def train(
|
||||||
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
train_path = parse_data_cfg(data_cfg)['train']
|
data_cfg = parse_data_cfg(data_cfg)
|
||||||
|
train_path = data_cfg['train']
|
||||||
|
nc = data_cfg['classes'] # number of classes
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = Darknet(cfg, img_size).to(device)
|
model = Darknet(cfg, img_size).to(device)
|
||||||
|
@ -145,7 +147,7 @@ def train(
|
||||||
# Start training
|
# Start training
|
||||||
t, t0 = time.time(), time.time()
|
t, t0 = time.time(), time.time()
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
model_info(model)
|
model_info(model)
|
||||||
nb = len(dataloader)
|
nb = len(dataloader)
|
||||||
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
||||||
|
|
|
@ -49,11 +49,11 @@ def model_info(model):
|
||||||
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
|
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
|
||||||
|
|
||||||
|
|
||||||
def labels_to_class_weights(labels):
|
def labels_to_class_weights(labels, nc=80):
|
||||||
# Get class weights (inverse frequency) from training labels
|
# Get class weights (inverse frequency) from training labels
|
||||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||||
classes = labels[:, 0].astype(np.int)
|
classes = labels[:, 0].astype(np.int)
|
||||||
weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class
|
weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class
|
||||||
weights /= weights.sum()
|
weights /= weights.sum()
|
||||||
return torch.Tensor(weights)
|
return torch.Tensor(weights)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue