updates
This commit is contained in:
parent
b87bfa32c3
commit
3f06fe6b12
6
test.py
6
test.py
|
@ -24,6 +24,10 @@ def test(cfg,
|
||||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
||||||
verbose = True
|
verbose = True
|
||||||
|
|
||||||
|
# Remove previous
|
||||||
|
for f in glob.glob('test_batch*.jpg'):
|
||||||
|
os.remove(f)
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = Darknet(cfg, img_size).to(device)
|
model = Darknet(cfg, img_size).to(device)
|
||||||
|
|
||||||
|
@ -36,7 +40,7 @@ def test(cfg,
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
else:
|
else: # called by train.py
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
verbose = False
|
verbose = False
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue