updates
This commit is contained in:
parent
0c845a2ff0
commit
2195bb0e89
1
train.py
1
train.py
|
@ -197,6 +197,7 @@ def train(cfg,
|
|||
collate_fn=dataset.collate_fn)
|
||||
|
||||
# Start training
|
||||
model.nc = nc # attach number of classes to model
|
||||
model.hyp = hyp # attach hyperparameters to model
|
||||
if dataset.image_weights:
|
||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||
|
|
|
@ -315,10 +315,11 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
|
|||
lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
|
||||
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
|
||||
|
||||
tclsm = torch.zeros_like(pi[..., 5:])
|
||||
tclsm[range(nb), tcls[i]] = 1.0
|
||||
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # cls loss (BCE)
|
||||
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # cls loss (CE)
|
||||
if model.nc > 1: # cls loss (only if multiple classes)
|
||||
tclsm = torch.zeros_like(pi[..., 5:])
|
||||
tclsm[range(nb), tcls[i]] = 1.0
|
||||
lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm) # BCE
|
||||
# lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i]) # CE
|
||||
|
||||
# Append targets to text file
|
||||
# with open('targets.txt', 'a') as file:
|
||||
|
@ -383,7 +384,7 @@ def build_targets(model, targets):
|
|||
|
||||
# Class
|
||||
tcls.append(c)
|
||||
if c.shape[0]:
|
||||
if c.shape[0]: # if any targets
|
||||
assert c.max() <= layer.nc, 'Target classes exceed model classes'
|
||||
|
||||
return txy, twh, tcls, tbox, indices, anchor_vec
|
||||
|
@ -777,7 +778,7 @@ def plot_results_overlay(start=1, stop=0): # from utils.utils import *; plot_re
|
|||
fig, ax = plt.subplots(1, 5, figsize=(14, 3.5))
|
||||
ax = ax.ravel()
|
||||
for i in range(5):
|
||||
for j in [i, i+5]:
|
||||
for j in [i, i + 5]:
|
||||
ax[i].plot(x, results[j, x], marker='.', label=s[j])
|
||||
ax[i].set_title(t[i])
|
||||
ax[i].legend()
|
||||
|
|
Loading…
Reference in New Issue