diff --git a/train.py b/train.py index 8234a0ff..d24a027d 100644 --- a/train.py +++ b/train.py @@ -197,6 +197,7 @@ def train(cfg, collate_fn=dataset.collate_fn) # Start training + model.nc = nc # attach number of classes to model model.hyp = hyp # attach hyperparameters to model if dataset.image_weights: model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights diff --git a/utils/utils.py b/utils/utils.py index b4d52f41..65dbf917 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -315,10 +315,11 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss - tclsm = torch.zeros_like(pi[..., 5:]) - tclsm[range(nb), tcls[i]] = 1.0 - lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # cls loss (BCE) - # lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # cls loss (CE) + if model.nc > 1: # cls loss (only if multiple classes) + tclsm = torch.zeros_like(pi[..., 5:]) + tclsm[range(nb), tcls[i]] = 1.0 + lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE + # lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE # Append targets to text file # with open('targets.txt', 'a') as file: @@ -383,7 +384,7 @@ def build_targets(model, targets): # Class tcls.append(c) - if c.shape[0]: + if c.shape[0]: # if any targets assert c.max() <= layer.nc, 'Target classes exceed model classes' return txy, twh, tcls, tbox, indices, anchor_vec @@ -777,7 +778,7 @@ def plot_results_overlay(start=1, stop=0): # from utils.utils import *; plot_re fig, ax = plt.subplots(1, 5, figsize=(14, 3.5)) ax = ax.ravel() for i in range(5): - for j in [i, i+5]: + for j in [i, i + 5]: ax[i].plot(x, results[j, x], marker='.', label=s[j]) ax[i].set_title(t[i]) ax[i].legend()