updates
This commit is contained in:
parent
65abb1c82f
commit
b11bb91804
12
train.py
12
train.py
|
@ -185,7 +185,8 @@ def train(cfg,
|
|||
batch_size,
|
||||
augment=True,
|
||||
hyp=hyp, # augmentation hyperparameters
|
||||
rect=opt.rect) # rectangular training
|
||||
rect=opt.rect, # rectangular training
|
||||
image_weights=False)
|
||||
|
||||
# Dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset,
|
||||
|
@ -219,10 +220,11 @@ def train(cfg,
|
|||
if int(name.split('.')[1]) < cutoff: # if layer < 75
|
||||
p.requires_grad = False if epoch == 0 else True
|
||||
|
||||
# # Update image weights (optional)
|
||||
# w = model.class_weights.cpu().numpy() * (1 - maps) # class weights
|
||||
# image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
||||
# dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # random weighted index
|
||||
# Update image weights (optional)
|
||||
if dataset.image_weights:
|
||||
w = model.class_weights.cpu().numpy() * (1 - maps) # class weights
|
||||
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
||||
dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
|
||||
|
||||
mloss = torch.zeros(5).to(device) # mean losses
|
||||
pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
|
||||
|
|
|
@ -213,7 +213,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
self.imgs = [None] * n
|
||||
self.labels = [None] * n
|
||||
preload_labels = False
|
||||
if preload_labels:
|
||||
if preload_labels or image_weights:
|
||||
self.labels = [np.zeros((0, 5))] * n
|
||||
extract_bounding_boxes = False
|
||||
for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files):
|
||||
|
|
Loading…
Reference in New Issue