multi_gpu multi_scale
This commit is contained in:
parent
feb5fcb16f
commit
76f555c108
6
train.py
6
train.py
|
@ -25,6 +25,8 @@ def train(
|
||||||
|
|
||||||
if multi_scale: # pass maximum multi_scale size
|
if multi_scale: # pass maximum multi_scale size
|
||||||
img_size = 608
|
img_size = 608
|
||||||
|
ms_index = -1
|
||||||
|
ms_sizes = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
|
||||||
else:
|
else:
|
||||||
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
||||||
|
|
||||||
|
@ -153,7 +155,9 @@ def train(
|
||||||
|
|
||||||
# Multi-Scale training (320 - 608 pixels) every 10 batches
|
# Multi-Scale training (320 - 608 pixels) every 10 batches
|
||||||
if multi_scale and (i + 1) % 10 == 0:
|
if multi_scale and (i + 1) % 10 == 0:
|
||||||
dataloader.img_size = random.choice(range(10, 20)) * 32
|
ms_index += 1
|
||||||
|
dataloader.img_size = ms_sizes[ms_index]
|
||||||
|
# dataloader.img_size = random.choice(range(10, 20)) * 32
|
||||||
print('multi_scale img_size = %g' % dataloader.img_size)
|
print('multi_scale img_size = %g' % dataloader.img_size)
|
||||||
|
|
||||||
# Update best loss
|
# Update best loss
|
||||||
|
|
Loading…
Reference in New Issue