From b11bb91804bfe13d63fafed69dbbb8011198441a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 30 Jul 2019 17:51:19 +0200 Subject: [PATCH] updates --- train.py | 12 +++++++----- utils/datasets.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 79eb3b08..ae68dca2 100644 --- a/train.py +++ b/train.py @@ -185,7 +185,8 @@ def train(cfg, batch_size, augment=True, hyp=hyp, # augmentation hyperparameters - rect=opt.rect) # rectangular training + rect=opt.rect, # rectangular training + image_weights=False) # Dataloader dataloader = torch.utils.data.DataLoader(dataset, @@ -219,10 +220,11 @@ def train(cfg, if int(name.split('.')[1]) < cutoff: # if layer < 75 p.requires_grad = False if epoch == 0 else True - # # Update image weights (optional) - # w = model.class_weights.cpu().numpy() * (1 - maps) # class weights - # image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) - # dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index + # Update image weights (optional) + if dataset.image_weights: + w = model.class_weights.cpu().numpy() * (1 - maps) # class weights + image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) + dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx mloss = torch.zeros(5).to(device) # mean losses pbar = tqdm(enumerate(dataloader), total=nb) # progress bar diff --git a/utils/datasets.py b/utils/datasets.py index 09295fdc..f35671ba 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -213,7 +213,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing self.imgs = [None] * n self.labels = [None] * n preload_labels = False - if preload_labels: + if preload_labels or image_weights: self.labels = [np.zeros((0, 5))] * n extract_bounding_boxes = False for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files):