updates
This commit is contained in:
parent
57b616b8b1
commit
0005823d1f
7
train.py
7
train.py
|
@ -6,7 +6,7 @@ import torch.optim as optim
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
import test # Import test.py to get mAP after each epoch
|
import test # import test.py to get mAP after each epoch
|
||||||
from models import *
|
from models import *
|
||||||
from utils.datasets import *
|
from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
@ -41,7 +41,6 @@ def train(
|
||||||
latest = weights + 'latest.pt'
|
latest = weights + 'latest.pt'
|
||||||
best = weights + 'best.pt'
|
best = weights + 'best.pt'
|
||||||
device = torch_utils.select_device()
|
device = torch_utils.select_device()
|
||||||
torch.backends.cudnn.benchmark = True # possibly unsuitable for multiscale
|
|
||||||
img_size_test = img_size # image size for testing
|
img_size_test = img_size # image size for testing
|
||||||
multi_scale = not opt.single_scale
|
multi_scale = not opt.single_scale
|
||||||
|
|
||||||
|
@ -145,7 +144,7 @@ def train(
|
||||||
# Start training
|
# Start training
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
model_info(model)
|
model_info(model, report='summary') # 'full' or 'summary'
|
||||||
nb = len(dataloader)
|
nb = len(dataloader)
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss
|
||||||
|
@ -330,7 +329,7 @@ if __name__ == '__main__':
|
||||||
# Mutate hyperparameters
|
# Mutate hyperparameters
|
||||||
old_hyp = hyp.copy()
|
old_hyp = hyp.copy()
|
||||||
init_seeds(seed=int(time.time()))
|
init_seeds(seed=int(time.time()))
|
||||||
s = [.4, .4, .4, .4, .4, .4, .4, .4*0, .4*0, .04*0, .4*0] # fractional sigmas
|
s = [.4, .4, .4, .4, .4, .4, .4, .4 * 0, .4 * 0, .04 * 0, .4 * 0] # fractional sigmas
|
||||||
for i, k in enumerate(hyp.keys()):
|
for i, k in enumerate(hyp.keys()):
|
||||||
x = (np.random.randn(1) * s[i] + 1) ** 1.1 # plt.hist(x.ravel(), 100)
|
x = (np.random.randn(1) * s[i] + 1) ** 1.1 # plt.hist(x.ravel(), 100)
|
||||||
hyp[k] = hyp[k] * float(x) # vary by about 30% 1sigma
|
hyp[k] = hyp[k] * float(x) # vary by about 30% 1sigma
|
||||||
|
|
|
@ -5,6 +5,8 @@ def init_seeds(seed=0):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.benchmark = True # set False for reproducible resuls
|
||||||
|
# torch.backends.cudnn.deterministic = True # https://pytorch.org/docs/stable/notes/randomness.html
|
||||||
|
|
||||||
|
|
||||||
def select_device(force_cpu=False):
|
def select_device(force_cpu=False):
|
||||||
|
|
Loading…
Reference in New Issue