This commit is contained in:
Glenn Jocher 2019-08-05 16:59:32 +02:00
parent 0c845a2ff0
commit 2195bb0e89
2 changed files with 8 additions and 6 deletions

View File

@ -197,6 +197,7 @@ def train(cfg,
collate_fn=dataset.collate_fn)
# Start training
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
if dataset.image_weights:
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights

View File

@ -315,10 +315,11 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
if model.nc > 1: # cls loss (only if multiple classes)
tclsm = torch.zeros_like(pi[..., 5:])
tclsm[range(nb), tcls[i]] = 1.0
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # cls loss (BCE)
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # cls loss (CE)
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE
# Append targets to text file
# with open('targets.txt', 'a') as file:
@ -383,7 +384,7 @@ def build_targets(model, targets):
# Class
tcls.append(c)
if c.shape[0]:
if c.shape[0]: # if any targets
assert c.max() <= layer.nc, 'Target classes exceed model classes'
return txy, twh, tcls, tbox, indices, anchor_vec
@ -777,7 +778,7 @@ def plot_results_overlay(start=1, stop=0): # from utils.utils import *; plot_re
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5))
ax = ax.ravel()
for i in range(5):
for j in [i, i+5]:
for j in [i, i + 5]:
ax[i].plot(x, results[j, x], marker='.', label=s[j])
ax[i].set_title(t[i])
ax[i].legend()