check_git_status() to train.py
This commit is contained in:
parent
2f636d5740
commit
77b3829d56
3
train.py
3
train.py
|
@ -23,7 +23,6 @@ best = wdir + 'best.pt'
|
||||||
results_file = 'results.txt'
|
results_file = 'results.txt'
|
||||||
|
|
||||||
# Hyperparameters https://github.com/ultralytics/yolov3/issues/310
|
# Hyperparameters https://github.com/ultralytics/yolov3/issues/310
|
||||||
|
|
||||||
hyp = {'giou': 3.54, # giou loss gain
|
hyp = {'giou': 3.54, # giou loss gain
|
||||||
'cls': 37.4, # cls loss gain
|
'cls': 37.4, # cls loss gain
|
||||||
'cls_pw': 1.0, # cls BCELoss positive_weight
|
'cls_pw': 1.0, # cls BCELoss positive_weight
|
||||||
|
@ -54,7 +53,6 @@ if f:
|
||||||
if hyp['fl_gamma']:
|
if hyp['fl_gamma']:
|
||||||
print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
|
print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
cfg = opt.cfg
|
cfg = opt.cfg
|
||||||
data = opt.data
|
data = opt.data
|
||||||
|
@ -408,6 +406,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
opt.weights = last if opt.resume else opt.weights
|
opt.weights = last if opt.resume else opt.weights
|
||||||
|
check_git_status()
|
||||||
print(opt)
|
print(opt)
|
||||||
opt.img_size.extend([opt.img_size[-1]] * (3 - len(opt.img_size))) # extend to 3 sizes (min, max, test)
|
opt.img_size.extend([opt.img_size[-1]] * (3 - len(opt.img_size))) # extend to 3 sizes (min, max, test)
|
||||||
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
|
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
|
||||||
|
|
|
@ -17,16 +17,10 @@ from tqdm import tqdm
|
||||||
|
|
||||||
from . import torch_utils # , google_utils
|
from . import torch_utils # , google_utils
|
||||||
|
|
||||||
matplotlib.rc('font', **{'size': 11})
|
|
||||||
|
|
||||||
# Suggest 'git pull'
|
|
||||||
s = subprocess.check_output('if [ -d .git ]; then git status -uno; fi', shell=True).decode('utf-8')
|
|
||||||
if 'Your branch is behind' in s:
|
|
||||||
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
|
|
||||||
|
|
||||||
# Set printoptions
|
# Set printoptions
|
||||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||||
|
matplotlib.rc('font', **{'size': 11})
|
||||||
|
|
||||||
# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
|
# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
|
||||||
cv2.setNumThreads(0)
|
cv2.setNumThreads(0)
|
||||||
|
@ -38,6 +32,13 @@ def init_seeds(seed=0):
|
||||||
torch_utils.init_seeds(seed=seed)
|
torch_utils.init_seeds(seed=seed)
|
||||||
|
|
||||||
|
|
||||||
|
def check_git_status():
|
||||||
|
# Suggest 'git pull' if repo is out of date
|
||||||
|
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
|
||||||
|
if 'Your branch is behind' in s:
|
||||||
|
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
|
||||||
|
|
||||||
|
|
||||||
def load_classes(path):
|
def load_classes(path):
|
||||||
# Loads *.names file at 'path'
|
# Loads *.names file at 'path'
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
|
|
Loading…
Reference in New Issue