updates
This commit is contained in:
parent
5ea92e7ee2
commit
50df252c4b
5
train.py
5
train.py
|
@ -121,10 +121,9 @@ def train(
|
||||||
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
||||||
imgs = imgs.to(device)
|
imgs = imgs.to(device)
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
|
|
||||||
nt = len(targets)
|
nt = len(targets)
|
||||||
if nt == 0: # if no targets continue
|
# if nt == 0: # if no targets continue
|
||||||
continue
|
# continue
|
||||||
|
|
||||||
# Plot images with bounding boxes
|
# Plot images with bounding boxes
|
||||||
if epoch == 0 and i == 0:
|
if epoch == 0 and i == 0:
|
||||||
|
|
|
@ -321,7 +321,7 @@ def build_targets(model, targets):
|
||||||
|
|
||||||
# Class
|
# Class
|
||||||
tcls.append(c)
|
tcls.append(c)
|
||||||
if nt:
|
if c.shape[0]:
|
||||||
assert c.max() <= layer.nC, 'Target classes exceed model classes'
|
assert c.max() <= layer.nC, 'Target classes exceed model classes'
|
||||||
|
|
||||||
return txy, twh, tcls, indices
|
return txy, twh, tcls, indices
|
||||||
|
|
Loading…
Reference in New Issue