updates
This commit is contained in:
parent
0c845a2ff0
commit
2195bb0e89
1
train.py
1
train.py
|
@ -197,6 +197,7 @@ def train(cfg,
|
||||||
collate_fn=dataset.collate_fn)
|
collate_fn=dataset.collate_fn)
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
|
model.nc = nc # attach number of classes to model
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
if dataset.image_weights:
|
if dataset.image_weights:
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
|
|
|
@ -315,10 +315,11 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
|
||||||
lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
|
lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
|
||||||
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
|
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
|
||||||
|
|
||||||
|
if model.nc > 1: # cls loss (only if multiple classes)
|
||||||
tclsm = torch.zeros_like(pi[..., 5:])
|
tclsm = torch.zeros_like(pi[..., 5:])
|
||||||
tclsm[range(nb), tcls[i]] = 1.0
|
tclsm[range(nb), tcls[i]] = 1.0
|
||||||
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # cls loss (BCE)
|
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE
|
||||||
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # cls loss (CE)
|
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE
|
||||||
|
|
||||||
# Append targets to text file
|
# Append targets to text file
|
||||||
# with open('targets.txt', 'a') as file:
|
# with open('targets.txt', 'a') as file:
|
||||||
|
@ -383,7 +384,7 @@ def build_targets(model, targets):
|
||||||
|
|
||||||
# Class
|
# Class
|
||||||
tcls.append(c)
|
tcls.append(c)
|
||||||
if c.shape[0]:
|
if c.shape[0]: # if any targets
|
||||||
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
||||||
|
|
||||||
return txy, twh, tcls, tbox, indices, anchor_vec
|
return txy, twh, tcls, tbox, indices, anchor_vec
|
||||||
|
|
Loading…
Reference in New Issue