updates
This commit is contained in:
parent
8f1becd55c
commit
d25190e15b
6
train.py
6
train.py
|
@ -64,7 +64,9 @@ def train(
|
|||
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
||||
|
||||
# Configure run
|
||||
train_path = parse_data_cfg(data_cfg)['train']
|
||||
data_cfg = parse_data_cfg(data_cfg)
|
||||
train_path = data_cfg['train']
|
||||
nc = data_cfg['classes'] # number of classes
|
||||
|
||||
# Initialize model
|
||||
model = Darknet(cfg, img_size).to(device)
|
||||
|
@ -145,7 +147,7 @@ def train(
|
|||
# Start training
|
||||
t, t0 = time.time(), time.time()
|
||||
model.hyp = hyp # attach hyperparameters to model
|
||||
model.class_weights = labels_to_class_weights(dataset.labels).to(device) # attach class weights
|
||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||
model_info(model)
|
||||
nb = len(dataloader)
|
||||
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
||||
|
|
|
@ -49,11 +49,11 @@ def model_info(model):
|
|||
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
|
||||
|
||||
|
||||
def labels_to_class_weights(labels):
|
||||
def labels_to_class_weights(labels, nc=80):
|
||||
# Get class weights (inverse frequency) from training labels
|
||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||
classes = labels[:, 0].astype(np.int)
|
||||
weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class
|
||||
weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class
|
||||
weights /= weights.sum()
|
||||
return torch.Tensor(weights)
|
||||
|
||||
|
|
Loading…
Reference in New Issue