diff --git a/train.py b/train.py index b791d7e4..7ec5a46c 100644 --- a/train.py +++ b/train.py @@ -32,7 +32,7 @@ def train( train_path = parse_data_cfg(data_cfg)['train'] # Initialize model - model = Darknet(cfg, img_size) + model = Darknet(cfg, img_size).to(device) # Get dataloader dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True) @@ -43,7 +43,7 @@ def train( start_epoch = 0 best_loss = float('inf') if resume: - checkpoint = torch.load(latest, map_location='cpu') + checkpoint = torch.load(latest, map_location=device) # Load weights to resume from model.load_state_dict(checkpoint['model'])