updates
This commit is contained in:
parent
58203e49c8
commit
677bdf236c
|
@ -11,7 +11,6 @@ from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import torch_utils
|
from . import torch_utils
|
||||||
from . import parse_config
|
|
||||||
|
|
||||||
matplotlib.rc('font', **{'size': 12})
|
matplotlib.rc('font', **{'size': 12})
|
||||||
|
|
||||||
|
@ -295,11 +294,9 @@ def compute_loss(p, targets, model, giou_loss=False): # predictions, targets, m
|
||||||
tconf[b, a, gj, gi] = 1 # conf
|
tconf[b, a, gj, gi] = 1 # conf
|
||||||
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
|
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
|
||||||
|
|
||||||
# Build GIoU boxes
|
|
||||||
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted box
|
|
||||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)
|
|
||||||
|
|
||||||
if giou_loss:
|
if giou_loss:
|
||||||
|
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
|
||||||
|
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||||
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
|
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
|
||||||
else:
|
else:
|
||||||
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
|
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
|
||||||
|
@ -490,14 +487,6 @@ def strip_optimizer_from_checkpoint(filename='weights/best.pt'):
|
||||||
torch.save(a, filename.replace('.pt', '_lite.pt'))
|
torch.save(a, filename.replace('.pt', '_lite.pt'))
|
||||||
|
|
||||||
|
|
||||||
def extract_bounding_boxes(data_cfg='data/coco_64img.data'): # from utils.utils import *; extract_bounding_boxes()
|
|
||||||
# Extract bounding boxes into a new classification dataset
|
|
||||||
data_dict = parse_config.parse_data_cfg(data_cfg)
|
|
||||||
train_path = data_dict['train']
|
|
||||||
nc = int(data_dict['classes']) # number of classes
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def coco_class_count(path='../coco/labels/train2014/'):
|
def coco_class_count(path='../coco/labels/train2014/'):
|
||||||
# Histogram of occurrences per class
|
# Histogram of occurrences per class
|
||||||
nc = 80 # number classes
|
nc = 80 # number classes
|
||||||
|
|
Loading…
Reference in New Issue