From be38caf2841c1614a850224b4861101a86f24ab3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 21 Mar 2019 12:13:09 +0200 Subject: [PATCH] updates --- train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index f37bae87..9efbb7f4 100644 --- a/train.py +++ b/train.py @@ -103,11 +103,10 @@ def train( rloss = defaultdict(float) for i, (imgs, targets, _, _) in enumerate(dataloader): - if targets.shape[1] == 100: # multithreaded forced to 100 + if targets.shape[1] == 100: # multithreaded 100-size block targets = targets.view((-1, 6)) targets = targets[targets[:, 5].nonzero().squeeze()] - targets = targets.to(device) nT = targets.shape[0] if nT == 0: # if no targets continue continue @@ -122,7 +121,7 @@ def train( pred = model(imgs.to(device)) # Build targets - target_list = build_targets(model, targets, pred) + target_list = build_targets(model, targets.to(device), pred) # Compute loss loss, loss_dict = compute_loss(pred, target_list)