This commit is contained in:
Glenn Jocher 2019-06-24 13:43:17 +02:00
parent 57b616b8b1
commit 0005823d1f
2 changed files with 5 additions and 4 deletions

View File

@ -6,7 +6,7 @@ import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import test # Import test.py to get mAP after each epoch import test # import test.py to get mAP after each epoch
from models import * from models import *
from utils.datasets import * from utils.datasets import *
from utils.utils import * from utils.utils import *
@ -41,7 +41,6 @@ def train(
latest = weights + 'latest.pt' latest = weights + 'latest.pt'
best = weights + 'best.pt' best = weights + 'best.pt'
device = torch_utils.select_device() device = torch_utils.select_device()
torch.backends.cudnn.benchmark = True # possibly unsuitable for multiscale
img_size_test = img_size # image size for testing img_size_test = img_size # image size for testing
multi_scale = not opt.single_scale multi_scale = not opt.single_scale
@ -145,7 +144,7 @@ def train(
# Start training # Start training
model.hyp = hyp # attach hyperparameters to model model.hyp = hyp # attach hyperparameters to model
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model_info(model) model_info(model, report='summary') # 'full' or 'summary'
nb = len(dataloader) nb = len(dataloader)
maps = np.zeros(nc) # mAP per class maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss

View File

@ -5,6 +5,8 @@ def init_seeds(seed=0):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = True # set False for reproducible resuls
# torch.backends.cudnn.deterministic = True # https://pytorch.org/docs/stable/notes/randomness.html
def select_device(force_cpu=False): def select_device(force_cpu=False):