This commit is contained in:
Glenn Jocher 2019-08-05 16:59:32 +02:00
parent 0c845a2ff0
commit 2195bb0e89
2 changed files with 8 additions and 6 deletions

View File

@ -197,6 +197,7 @@ def train(cfg,
collate_fn=dataset.collate_fn) collate_fn=dataset.collate_fn)
# Start training # Start training
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model model.hyp = hyp # attach hyperparameters to model
if dataset.image_weights: if dataset.image_weights:
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights

View File

@ -315,10 +315,11 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
if model.nc > 1: # cls loss (only if multiple classes)
tclsm = torch.zeros_like(pi[..., 5:]) tclsm = torch.zeros_like(pi[..., 5:])
tclsm[range(nb), tcls[i]] = 1.0 tclsm[range(nb), tcls[i]] = 1.0
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # cls loss (BCE) lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # cls loss (CE) # lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE
# Append targets to text file # Append targets to text file
# with open('targets.txt', 'a') as file: # with open('targets.txt', 'a') as file:
@ -383,7 +384,7 @@ def build_targets(model, targets):
# Class # Class
tcls.append(c) tcls.append(c)
if c.shape[0]: if c.shape[0]: # if any targets
assert c.max() <= layer.nc, 'Target classes exceed model classes' assert c.max() <= layer.nc, 'Target classes exceed model classes'
return txy, twh, tcls, tbox, indices, anchor_vec return txy, twh, tcls, tbox, indices, anchor_vec