updates
This commit is contained in:
parent
2856af5036
commit
be38caf284
5
train.py
5
train.py
|
@ -103,11 +103,10 @@ def train(
|
|||
rloss = defaultdict(float)
|
||||
for i, (imgs, targets, _, _) in enumerate(dataloader):
|
||||
|
||||
if targets.shape[1] == 100: # multithreaded forced to 100
|
||||
if targets.shape[1] == 100: # multithreaded 100-size block
|
||||
targets = targets.view((-1, 6))
|
||||
targets = targets[targets[:, 5].nonzero().squeeze()]
|
||||
|
||||
targets = targets.to(device)
|
||||
nT = targets.shape[0]
|
||||
if nT == 0: # if no targets continue
|
||||
continue
|
||||
|
@ -122,7 +121,7 @@ def train(
|
|||
pred = model(imgs.to(device))
|
||||
|
||||
# Build targets
|
||||
target_list = build_targets(model, targets, pred)
|
||||
target_list = build_targets(model, targets.to(device), pred)
|
||||
|
||||
# Compute loss
|
||||
loss, loss_dict = compute_loss(pred, target_list)
|
||||
|
|
Loading…
Reference in New Issue