update test.py
This commit is contained in:
parent
207a17de31
commit
ca7794ed05
6
test.py
6
test.py
|
@ -23,6 +23,7 @@ def test(cfg,
|
|||
multi_label=True):
|
||||
# Initialize/load model and set device
|
||||
if model is None:
|
||||
is_training = False
|
||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
||||
verbose = opt.task == 'test'
|
||||
|
||||
|
@ -47,6 +48,7 @@ def test(cfg,
|
|||
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
else: # called by train.py
|
||||
is_training = True
|
||||
device = next(model.parameters()).device # get model device
|
||||
verbose = False
|
||||
|
||||
|
@ -61,7 +63,7 @@ def test(cfg,
|
|||
|
||||
# Dataloader
|
||||
if dataloader is None:
|
||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size, rect=True, single_cls=opt.single_cls)
|
||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size, rect=True, single_cls=opt.single_cls, pad=0.5)
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
dataloader = DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
|
@ -91,7 +93,7 @@ def test(cfg,
|
|||
t0 += torch_utils.time_synchronized() - t
|
||||
|
||||
# Compute loss
|
||||
if hasattr(model, 'hyp'): # if model has loss hyperparameters
|
||||
if is_training: # if model has loss hyperparameters
|
||||
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls
|
||||
|
||||
# Run NMS
|
||||
|
|
Loading…
Reference in New Issue