This commit is contained in:
Glenn Jocher 2019-03-21 12:13:09 +02:00
parent 2856af5036
commit be38caf284
1 changed files with 2 additions and 3 deletions

View File

@ -103,11 +103,10 @@ def train(
rloss = defaultdict(float) rloss = defaultdict(float)
for i, (imgs, targets, _, _) in enumerate(dataloader): for i, (imgs, targets, _, _) in enumerate(dataloader):
if targets.shape[1] == 100: # multithreaded forced to 100 if targets.shape[1] == 100: # multithreaded 100-size block
targets = targets.view((-1, 6)) targets = targets.view((-1, 6))
targets = targets[targets[:, 5].nonzero().squeeze()] targets = targets[targets[:, 5].nonzero().squeeze()]
targets = targets.to(device)
nT = targets.shape[0] nT = targets.shape[0]
if nT == 0: # if no targets continue if nT == 0: # if no targets continue
continue continue
@ -122,7 +121,7 @@ def train(
pred = model(imgs.to(device)) pred = model(imgs.to(device))
# Build targets # Build targets
target_list = build_targets(model, targets, pred) target_list = build_targets(model, targets.to(device), pred)
# Compute loss # Compute loss
loss, loss_dict = compute_loss(pred, target_list) loss, loss_dict = compute_loss(pred, target_list)