Merge pull request #145 from perry0418/master

Update train.py
This commit is contained in:
Glenn Jocher 2019-03-21 12:04:50 +02:00 committed by GitHub
commit ad49e70f47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -32,7 +32,7 @@ def train(
train_path = parse_data_cfg(data_cfg)['train'] train_path = parse_data_cfg(data_cfg)['train']
# Initialize model # Initialize model
model = Darknet(cfg, img_size) model = Darknet(cfg, img_size).to(device)
# Get dataloader # Get dataloader
dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True) dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True)
@ -43,7 +43,7 @@ def train(
start_epoch = 0 start_epoch = 0
best_loss = float('inf') best_loss = float('inf')
if resume: if resume:
checkpoint = torch.load(latest, map_location='cpu') checkpoint = torch.load(latest, map_location=device)
# Load weights to resume from # Load weights to resume from
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])