This commit is contained in:
Glenn Jocher 2019-06-18 15:34:35 +02:00
parent 58203e49c8
commit 677bdf236c
1 changed files with 2 additions and 13 deletions

View File

@ -11,7 +11,6 @@ from PIL import Image
from tqdm import tqdm
from . import torch_utils
from . import parse_config
matplotlib.rc('font', **{'size': 12})
@ -295,11 +294,9 @@ def compute_loss(p, targets, model, giou_loss=False): # predictions, targets, m
tconf[b, a, gj, gi] = 1 # conf
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
# Build GIoU boxes
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)
if giou_loss:
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
else:
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
@ -490,14 +487,6 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
torch.save(a, filename.replace('.pt', '_lite.pt'))
def extract_bounding_boxes(data_cfg='data/coco_64img.data'): # from utils.utils import *; extract_bounding_boxes()
# Extract bounding boxes into a new classification dataset
data_dict = parse_config.parse_data_cfg(data_cfg)
train_path = data_dict['train']
nc = int(data_dict['classes']) # number of classes
def coco_class_count(path='../coco/labels/train2014/'):
# Histogram of occurrences per class
nc = 80 # number classes