From 9ffb40b0be8698dadce4852d9f9eceeee169d366 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 10 May 2019 16:29:37 +0200 Subject: [PATCH] updates --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 4fd25e3e..edef2a0d 100644 --- a/train.py +++ b/train.py @@ -119,7 +119,7 @@ def train( # plt.savefig('LR.png', dpi=300) # Dataset - dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, image_weighting=False) + dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, image_weights=False) # Initialize distributed training if torch.cuda.device_count() > 1: @@ -167,7 +167,8 @@ def train( p.requires_grad = False if epoch == 0 else True # Update image weights (optional) - image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=1 - maps) + w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights + image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index mloss = torch.zeros(5).to(device) # mean losses