This commit is contained in:
Glenn Jocher 2019-03-21 12:13:09 +02:00
parent 2856af5036
commit be38caf284
1 changed files with 2 additions and 3 deletions

View File

@ -103,11 +103,10 @@ def train(
rloss = defaultdict(float)
for i, (imgs, targets, _, _) in enumerate(dataloader):
if targets.shape[1] == 100: # multithreaded forced to 100
if targets.shape[1] == 100: # multithreaded 100-size block
targets = targets.view((-1, 6))
targets = targets[targets[:, 5].nonzero().squeeze()]
targets = targets.to(device)
nT = targets.shape[0]
if nT == 0: # if no targets continue
continue
@ -122,7 +121,7 @@ def train(
pred = model(imgs.to(device))
# Build targets
target_list = build_targets(model, targets, pred)
target_list = build_targets(model, targets.to(device), pred)
# Compute loss
loss, loss_dict = compute_loss(pred, target_list)