This commit is contained in:
Glenn Jocher 2019-04-27 17:51:59 +02:00
parent 8f1becd55c
commit d25190e15b
2 changed files with 6 additions and 4 deletions

View File

@ -64,7 +64,9 @@ def train(
torch.backends.cudnn.benchmark = True # unsuitable for multiscale torch.backends.cudnn.benchmark = True # unsuitable for multiscale
# Configure run # Configure run
train_path = parse_data_cfg(data_cfg)['train'] data_cfg = parse_data_cfg(data_cfg)
train_path = data_cfg['train']
nc = data_cfg['classes'] # number of classes
# Initialize model # Initialize model
model = Darknet(cfg, img_size).to(device) model = Darknet(cfg, img_size).to(device)
@ -145,7 +147,7 @@ def train(
# Start training # Start training
t, t0 = time.time(), time.time() t, t0 = time.time(), time.time()
model.hyp = hyp # attach hyperparameters to model model.hyp = hyp # attach hyperparameters to model
model.class_weights = labels_to_class_weights(dataset.labels).to(device) # attach class weights model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model_info(model) model_info(model)
nb = len(dataloader) nb = len(dataloader)
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss

View File

@ -49,11 +49,11 @@ def model_info(model):
print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g)) print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
def labels_to_class_weights(labels): def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels # Get class weights (inverse frequency) from training labels
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int) classes = labels[:, 0].astype(np.int)
weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class
weights /= weights.sum() weights /= weights.sum()
return torch.Tensor(weights) return torch.Tensor(weights)