multi_gpu multi_scale
This commit is contained in:
parent
00e181a55a
commit
dcdd1ae6b7
20
models.py
20
models.py
|
@ -106,22 +106,19 @@ class YOLOLayer(nn.Module):
|
|||
self.nC = nC # number of classes (80)
|
||||
self.img_size = 0
|
||||
|
||||
# if ONNX_EXPORT: # grids must be computed in __init__
|
||||
stride = [32, 16, 8][yolo_layer] # stride of this layer
|
||||
if cfg.endswith('yolov3-tiny.cfg'):
|
||||
stride *= 2
|
||||
if ONNX_EXPORT: # grids must be computed in __init__
|
||||
stride = [32, 16, 8][yolo_layer] # stride of this layer
|
||||
if cfg.endswith('yolov3-tiny.cfg'):
|
||||
stride *= 2
|
||||
|
||||
nG = int(img_size / stride) # number grid points
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
create_grids(self, img_size, nG, device)
|
||||
nG = int(img_size / stride) # number grid points
|
||||
create_grids(self, img_size, nG)
|
||||
|
||||
def forward(self, p, img_size, var=None):
|
||||
if ONNX_EXPORT:
|
||||
bs, nG = 1, self.nG # batch size, grid size
|
||||
else:
|
||||
bs, nG = p.shape[0], p.shape[-1]
|
||||
|
||||
if self.img_size != img_size:
|
||||
create_grids(self, img_size, nG, p.device)
|
||||
|
||||
|
@ -214,7 +211,8 @@ def get_yolo_layers(model):
|
|||
return [i for i, x in enumerate(a) if x] # [82, 94, 106] for yolov3
|
||||
|
||||
|
||||
def create_grids(self, img_size, nG, device):
|
||||
def create_grids(self, img_size, nG, device='cpu'):
|
||||
self.img_size = img_size
|
||||
self.stride = img_size / nG
|
||||
|
||||
# build xy offsets
|
||||
|
@ -225,7 +223,7 @@ def create_grids(self, img_size, nG, device):
|
|||
# build wh gains
|
||||
self.anchor_vec = self.anchors.to(device) / self.stride
|
||||
self.anchor_wh = self.anchor_vec.view(1, self.nA, 1, 1, 2).to(device)
|
||||
self.nG = torch.FloatTensor([nG]).to(device)
|
||||
self.nG = torch.Tensor([nG], device=device)
|
||||
|
||||
|
||||
def load_darknet_weights(self, weights, cutoff=-1):
|
||||
|
|
10
train.py
10
train.py
|
@ -23,10 +23,8 @@ def train(
|
|||
best = weights + 'best.pt'
|
||||
device = torch_utils.select_device()
|
||||
|
||||
if multi_scale: # pass maximum multi_scale size
|
||||
img_size = 608
|
||||
ms_index = -1
|
||||
ms_sizes = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
|
||||
if multi_scale:
|
||||
img_size = 608 # initiate with maximum multi_scale size
|
||||
else:
|
||||
torch.backends.cudnn.benchmark = True # unsuitable for multiscale
|
||||
|
||||
|
@ -155,9 +153,7 @@ def train(
|
|||
|
||||
# Multi-Scale training (320 - 608 pixels) every 10 batches
|
||||
if multi_scale and (i + 1) % 10 == 0:
|
||||
ms_index += 1
|
||||
dataloader.img_size = ms_sizes[ms_index]
|
||||
# dataloader.img_size = random.choice(range(10, 20)) * 32
|
||||
dataloader.img_size = random.choice(range(10, 20)) * 32
|
||||
print('multi_scale img_size = %g' % dataloader.img_size)
|
||||
|
||||
# Update best loss
|
||||
|
|
Loading…
Reference in New Issue