updates
This commit is contained in:
parent
2f256ee274
commit
bbe22dd7b4
3
train.py
3
train.py
|
@ -191,6 +191,7 @@ def train():
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
model.nc = nc # attach number of classes to model
|
model.nc = nc # attach number of classes to model
|
||||||
|
model.arc = opt.arc # attach yolo architecture
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
model_info(model, report='summary') # 'full' or 'summary'
|
model_info(model, report='summary') # 'full' or 'summary'
|
||||||
|
@ -259,7 +260,7 @@ def train():
|
||||||
pred = model(imgs)
|
pred = model(imgs)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
loss, loss_items = compute_loss(pred, targets, model, arc=opt.arc)
|
loss, loss_items = compute_loss(pred, targets, model)
|
||||||
if torch.isnan(loss):
|
if torch.isnan(loss):
|
||||||
print('WARNING: nan loss detected, ending training')
|
print('WARNING: nan loss detected, ending training')
|
||||||
return results
|
return results
|
||||||
|
|
|
@ -312,11 +312,12 @@ class FocalLoss(nn.Module):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(p, targets, model, arc='default'): # predictions, targets, model
|
def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||||
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||||
h = model.hyp # hyperparameters
|
h = model.hyp # hyperparameters
|
||||||
|
arc = model.arc # # (default, uCE, uBCE) detection architectures
|
||||||
|
|
||||||
# Define criteria
|
# Define criteria
|
||||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
||||||
|
@ -354,7 +355,7 @@ def compute_loss(p, targets, model, arc='default'): # predictions, targets, mod
|
||||||
# with open('targets.txt', 'a') as file:
|
# with open('targets.txt', 'a') as file:
|
||||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||||
|
|
||||||
if arc == 'default': # (default, uCE, uBCE) detection architectures
|
if arc == 'default':
|
||||||
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
||||||
|
|
||||||
elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20
|
elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20
|
||||||
|
|
Loading…
Reference in New Issue