This commit is contained in:
Glenn Jocher 2019-07-30 17:51:19 +02:00
parent 65abb1c82f
commit b11bb91804
2 changed files with 8 additions and 6 deletions

View File

@ -185,7 +185,8 @@ def train(cfg,
batch_size,
augment=True,
hyp=hyp, # augmentation hyperparameters
rect=opt.rect) # rectangular training
rect=opt.rect, # rectangular training
image_weights=False)
# Dataloader
dataloader = torch.utils.data.DataLoader(dataset,
@ -219,10 +220,11 @@ def train(cfg,
if int(name.split('.')[1]) < cutoff: # if layer < 75
p.requires_grad = False if epoch == 0 else True
# # Update image weights (optional)
# w = model.class_weights.cpu().numpy() * (1 - maps) # class weights
# image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
# dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index
# Update image weights (optional)
if dataset.image_weights:
w = model.class_weights.cpu().numpy() * (1 - maps) # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
mloss = torch.zeros(5).to(device) # mean losses
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar

View File

@ -213,7 +213,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.imgs = [None] * n
self.labels = [None] * n
preload_labels = False
if preload_labels:
if preload_labels or image_weights:
self.labels = [np.zeros((0, 5))] * n
extract_bounding_boxes = False
for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files):