Update train.py
solve the Multi-GPU --resume Error #138 https://github.com/ultralytics/yolov3/issues/138
This commit is contained in:
parent
2cd6805063
commit
35396adc9c
4
train.py
4
train.py
|
@ -32,7 +32,7 @@ def train(
|
||||||
train_path = parse_data_cfg(data_cfg)['train']
|
train_path = parse_data_cfg(data_cfg)['train']
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = Darknet(cfg, img_size)
|
model = Darknet(cfg, img_size).to(device)
|
||||||
|
|
||||||
# Get dataloader
|
# Get dataloader
|
||||||
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
|
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
|
||||||
|
@ -43,7 +43,7 @@ def train(
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
if resume:
|
if resume:
|
||||||
checkpoint = torch.load(latest, map_location='cpu')
|
checkpoint = torch.load(latest, map_location=device)
|
||||||
|
|
||||||
# Load weights to resume from
|
# Load weights to resume from
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
|
|
Loading…
Reference in New Issue