This commit is contained in:
Glenn Jocher 2019-02-27 14:19:57 +01:00
parent 70339798c5
commit e094bb14ba
1 changed files with 4 additions and 7 deletions

View File

@ -1,6 +1,8 @@
import glob
import random import random
import cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -428,7 +430,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
def strip_optimizer_from_checkpoint(filename='weights/best.pt'): def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
# Strip optimizer from *.pt files for lighter files (reduced by 2/3 size) # Strip optimizer from *.pt files for lighter files (reduced by 2/3 size)
import torch
a = torch.load(filename, map_location='cpu') a = torch.load(filename, map_location='cpu')
a['optimizer'] = [] a['optimizer'] = []
torch.save(a, filename.replace('.pt', '_lite.pt')) torch.save(a, filename.replace('.pt', '_lite.pt'))
@ -436,7 +438,6 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
def coco_class_count(path='../coco/labels/train2014/'): def coco_class_count(path='../coco/labels/train2014/'):
# histogram of occurrences per class # histogram of occurrences per class
import glob
nC = 80 # number classes nC = 80 # number classes
x = np.zeros(nC, dtype='int32') x = np.zeros(nC, dtype='int32')
@ -449,7 +450,6 @@ def coco_class_count(path='../coco/labels/train2014/'):
def coco_only_people(path='../coco/labels/val2014/'): def coco_only_people(path='../coco/labels/val2014/'):
# find images with only people # find images with only people
import glob
files = sorted(glob.glob('%s/*.*' % path)) files = sorted(glob.glob('%s/*.*' % path))
for i, file in enumerate(files): for i, file in enumerate(files):
@ -460,10 +460,7 @@ def coco_only_people(path='../coco/labels/val2014/'):
def plot_results(): def plot_results():
# Plot YOLO training results file 'results.txt' # Plot YOLO training results file 'results.txt'
import glob # import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt'
import matplotlib.pyplot as plt
import numpy as np
# import os; os.system('rm -rf results.txt && wget https://storage.googleapis.com/ultralytics/results_v1_0.txt')
plt.figure(figsize=(14, 7)) plt.figure(figsize=(14, 7))
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision'] s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'mAP', 'Recall', 'Precision']