updates
This commit is contained in:
parent
2856af5036
commit
be38caf284
5
train.py
5
train.py
|
@ -103,11 +103,10 @@ def train(
|
||||||
rloss = defaultdict(float)
|
rloss = defaultdict(float)
|
||||||
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
||||||
|
|
||||||
if targets.shape[1] == 100: # multithreaded forced to 100
|
if targets.shape[1] == 100: # multithreaded 100-size block
|
||||||
targets = targets.view((-1, 6))
|
targets = targets.view((-1, 6))
|
||||||
targets = targets[targets[:, 5].nonzero().squeeze()]
|
targets = targets[targets[:, 5].nonzero().squeeze()]
|
||||||
|
|
||||||
targets = targets.to(device)
|
|
||||||
nT = targets.shape[0]
|
nT = targets.shape[0]
|
||||||
if nT == 0: # if no targets continue
|
if nT == 0: # if no targets continue
|
||||||
continue
|
continue
|
||||||
|
@ -122,7 +121,7 @@ def train(
|
||||||
pred = model(imgs.to(device))
|
pred = model(imgs.to(device))
|
||||||
|
|
||||||
# Build targets
|
# Build targets
|
||||||
target_list = build_targets(model, targets, pred)
|
target_list = build_targets(model, targets.to(device), pred)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
loss, loss_dict = compute_loss(pred, target_list)
|
loss, loss_dict = compute_loss(pred, target_list)
|
||||||
|
|
Loading…
Reference in New Issue