updates
This commit is contained in:
parent
65abb1c82f
commit
b11bb91804
12
train.py
12
train.py
|
@ -185,7 +185,8 @@ def train(cfg,
|
||||||
batch_size,
|
batch_size,
|
||||||
augment=True,
|
augment=True,
|
||||||
hyp=hyp, # augmentation hyperparameters
|
hyp=hyp, # augmentation hyperparameters
|
||||||
rect=opt.rect) # rectangular training
|
rect=opt.rect, # rectangular training
|
||||||
|
image_weights=False)
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
dataloader = torch.utils.data.DataLoader(dataset,
|
dataloader = torch.utils.data.DataLoader(dataset,
|
||||||
|
@ -219,10 +220,11 @@ def train(cfg,
|
||||||
if int(name.split('.')[1]) < cutoff: # if layer < 75
|
if int(name.split('.')[1]) < cutoff: # if layer < 75
|
||||||
p.requires_grad = False if epoch == 0 else True
|
p.requires_grad = False if epoch == 0 else True
|
||||||
|
|
||||||
# # Update image weights (optional)
|
# Update image weights (optional)
|
||||||
# w = model.class_weights.cpu().numpy() * (1 - maps) # class weights
|
if dataset.image_weights:
|
||||||
# image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
w = model.class_weights.cpu().numpy() * (1 - maps) # class weights
|
||||||
# dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index
|
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
||||||
|
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
|
||||||
|
|
||||||
mloss = torch.zeros(5).to(device) # mean losses
|
mloss = torch.zeros(5).to(device) # mean losses
|
||||||
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
||||||
|
|
|
@ -213,7 +213,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
self.imgs = [None] * n
|
self.imgs = [None] * n
|
||||||
self.labels = [None] * n
|
self.labels = [None] * n
|
||||||
preload_labels = False
|
preload_labels = False
|
||||||
if preload_labels:
|
if preload_labels or image_weights:
|
||||||
self.labels = [np.zeros((0, 5))] * n
|
self.labels = [np.zeros((0, 5))] * n
|
||||||
extract_bounding_boxes = False
|
extract_bounding_boxes = False
|
||||||
for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files):
|
for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files):
|
||||||
|
|
Loading…
Reference in New Issue