updates
This commit is contained in:
parent
58203e49c8
commit
677bdf236c
|
@ -11,7 +11,6 @@ from PIL import Image
|
|||
from tqdm import tqdm
|
||||
|
||||
from . import torch_utils
|
||||
from . import parse_config
|
||||
|
||||
matplotlib.rc('font', **{'size': 12})
|
||||
|
||||
|
@ -295,11 +294,9 @@ def compute_loss(p, targets, model, giou_loss=False): # predictions, targets, m
|
|||
tconf[b, a, gj, gi] = 1 # conf
|
||||
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
|
||||
|
||||
# Build GIoU boxes
|
||||
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)
|
||||
|
||||
if giou_loss:
|
||||
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
|
||||
else:
|
||||
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
|
||||
|
@ -490,14 +487,6 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
|
|||
torch.save(a, filename.replace('.pt', '_lite.pt'))
|
||||
|
||||
|
||||
def extract_bounding_boxes(data_cfg='data/coco_64img.data'): # from utils.utils import *; extract_bounding_boxes()
|
||||
# Extract bounding boxes into a new classification dataset
|
||||
data_dict = parse_config.parse_data_cfg(data_cfg)
|
||||
train_path = data_dict['train']
|
||||
nc = int(data_dict['classes']) # number of classes
|
||||
|
||||
|
||||
|
||||
def coco_class_count(path='../coco/labels/train2014/'):
|
||||
# Histogram of occurrences per class
|
||||
nc = 80 # number classes
|
||||
|
|
Loading…
Reference in New Issue