update test.py
This commit is contained in:
parent
207a17de31
commit
ca7794ed05
6
test.py
6
test.py
|
@ -23,6 +23,7 @@ def test(cfg,
|
||||||
multi_label=True):
|
multi_label=True):
|
||||||
# Initialize/load model and set device
|
# Initialize/load model and set device
|
||||||
if model is None:
|
if model is None:
|
||||||
|
is_training = False
|
||||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
||||||
verbose = opt.task == 'test'
|
verbose = opt.task == 'test'
|
||||||
|
|
||||||
|
@ -47,6 +48,7 @@ def test(cfg,
|
||||||
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
else: # called by train.py
|
else: # called by train.py
|
||||||
|
is_training = True
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
verbose = False
|
verbose = False
|
||||||
|
|
||||||
|
@ -61,7 +63,7 @@ def test(cfg,
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
if dataloader is None:
|
if dataloader is None:
|
||||||
dataset = LoadImagesAndLabels(path, imgsz, batch_size, rect=True, single_cls=opt.single_cls)
|
dataset = LoadImagesAndLabels(path, imgsz, batch_size, rect=True, single_cls=opt.single_cls, pad=0.5)
|
||||||
batch_size = min(batch_size, len(dataset))
|
batch_size = min(batch_size, len(dataset))
|
||||||
dataloader = DataLoader(dataset,
|
dataloader = DataLoader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
@ -91,7 +93,7 @@ def test(cfg,
|
||||||
t0 += torch_utils.time_synchronized() - t
|
t0 += torch_utils.time_synchronized() - t
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
if hasattr(model, 'hyp'): # if model has loss hyperparameters
|
if is_training: # if model has loss hyperparameters
|
||||||
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls
|
loss += compute_loss(train_out, targets, model)[1][:3] # GIoU, obj, cls
|
||||||
|
|
||||||
# Run NMS
|
# Run NMS
|
||||||
|
|
Loading…
Reference in New Issue