updates
This commit is contained in:
parent
6bd51b75ea
commit
5fb661b7d4
2
test.py
2
test.py
|
@ -38,7 +38,7 @@ def test(cfg,
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
load_darknet_weights(model, weights)
|
load_darknet_weights(model, weights)
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
else: # called by train.py
|
else: # called by train.py
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
|
|
Loading…
Reference in New Issue