This commit is contained in:
Glenn Jocher 2019-07-30 17:51:19 +02:00
parent 65abb1c82f
commit b11bb91804
2 changed files with 8 additions and 6 deletions

View File

@ -185,7 +185,8 @@ def train(cfg,
batch_size, batch_size,
augment=True, augment=True,
hyp=hyp, # augmentation hyperparameters hyp=hyp, # augmentation hyperparameters
rect=opt.rect) # rectangular training rect=opt.rect, # rectangular training
image_weights=False)
# Dataloader # Dataloader
dataloader = torch.utils.data.DataLoader(dataset, dataloader = torch.utils.data.DataLoader(dataset,
@ -219,10 +220,11 @@ def train(cfg,
if int(name.split('.')[1]) < cutoff: # if layer < 75 if int(name.split('.')[1]) < cutoff: # if layer < 75
p.requires_grad = False if epoch == 0 else True p.requires_grad = False if epoch == 0 else True
# # Update image weights (optional) # Update image weights (optional)
# w = model.class_weights.cpu().numpy() * (1 - maps) # class weights if dataset.image_weights:
# image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w) w = model.class_weights.cpu().numpy() * (1 - maps) # class weights
# dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
mloss = torch.zeros(5).to(device) # mean losses mloss = torch.zeros(5).to(device) # mean losses
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar pbar = tqdm(enumerate(dataloader), total=nb) # progress bar

View File

@ -213,7 +213,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
self.imgs = [None] * n self.imgs = [None] * n
self.labels = [None] * n self.labels = [None] * n
preload_labels = False preload_labels = False
if preload_labels: if preload_labels or image_weights:
self.labels = [np.zeros((0, 5))] * n self.labels = [np.zeros((0, 5))] * n
extract_bounding_boxes = False extract_bounding_boxes = False
for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files): for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files):