This commit is contained in:
Glenn Jocher 2019-05-10 16:29:37 +02:00
parent ddd0474111
commit 9ffb40b0be
1 changed files with 3 additions and 2 deletions

View File

@ -119,7 +119,7 @@ def train(
# plt.savefig('LR.png', dpi=300) # plt.savefig('LR.png', dpi=300)
# Dataset # Dataset
dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, image_weighting=False) dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, image_weights=False)
# Initialize distributed training # Initialize distributed training
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
@ -167,7 +167,8 @@ def train(
p.requires_grad = False if epoch == 0 else True p.requires_grad = False if epoch == 0 else True
# Update image weights (optional) # Update image weights (optional)
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=1 - maps) w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index
mloss = torch.zeros(5).to(device) # mean losses mloss = torch.zeros(5).to(device) # mean losses