commit
ad49e70f47
4
train.py
4
train.py
|
@ -32,7 +32,7 @@ def train(
|
|||
train_path = parse_data_cfg(data_cfg)['train']
|
||||
|
||||
# Initialize model
|
||||
model = Darknet(cfg, img_size)
|
||||
model = Darknet(cfg, img_size).to(device)
|
||||
|
||||
# Get dataloader
|
||||
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
|
||||
|
@ -43,7 +43,7 @@ def train(
|
|||
start_epoch = 0
|
||||
best_loss = float('inf')
|
||||
if resume:
|
||||
checkpoint = torch.load(latest, map_location='cpu')
|
||||
checkpoint = torch.load(latest, map_location=device)
|
||||
|
||||
# Load weights to resume from
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
|
|
Loading…
Reference in New Issue