check_git_status() to train.py

This commit is contained in:
Glenn Jocher 2020-04-22 11:02:09 -07:00
parent 2f636d5740
commit 77b3829d56
2 changed files with 9 additions and 9 deletions

View File

@ -23,7 +23,6 @@ best = wdir + 'best.pt'
results_file = 'results.txt'
# Hyperparameters https://github.com/ultralytics/yolov3/issues/310
hyp = {'giou': 3.54, # giou loss gain
'cls': 37.4, # cls loss gain
'cls_pw': 1.0, # cls BCELoss positive_weight
@ -54,7 +53,6 @@ if f:
if hyp['fl_gamma']:
print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
def train():
cfg = opt.cfg
data = opt.data
@ -408,6 +406,7 @@ if __name__ == '__main__':
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
opt = parser.parse_args()
opt.weights = last if opt.resume else opt.weights
check_git_status()
print(opt)
opt.img_size.extend([opt.img_size[-1]] * (3 - len(opt.img_size))) # extend to 3 sizes (min, max, test)
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)

View File

@ -17,16 +17,10 @@ from tqdm import tqdm
from . import torch_utils # , google_utils
matplotlib.rc('font', **{'size': 11})
# Suggest 'git pull'
s = subprocess.check_output('if [ -d .git ]; then git status -uno; fi', shell=True).decode('utf-8')
if 'Your branch is behind' in s:
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
# Set printoptions
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
matplotlib.rc('font', **{'size': 11})
# Prevent OpenCV from multithreading (to use PyTorch DataLoader)
cv2.setNumThreads(0)
@ -38,6 +32,13 @@ def init_seeds(seed=0):
torch_utils.init_seeds(seed=seed)
def check_git_status():
# Suggest 'git pull' if repo is out of date
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
if 'Your branch is behind' in s:
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
def load_classes(path):
# Loads *.names file at 'path'
with open(path, 'r') as f: