updates
This commit is contained in:
parent
1613d1c396
commit
9a9224cfe6
|
@ -333,11 +333,10 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
|
||||||
|
|
||||||
def build_targets(model, targets):
|
def build_targets(model, targets):
|
||||||
# targets = [image, class, x, y, w, h]
|
# targets = [image, class, x, y, w, h]
|
||||||
iou_thres = model.hyp['iou_t'] # hyperparameter
|
|
||||||
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
|
||||||
|
|
||||||
nt = len(targets)
|
nt = len(targets)
|
||||||
txy, twh, tcls, tbox, indices, av = [], [], [], [], [], []
|
txy, twh, tcls, tbox, indices, av = [], [], [], [], [], []
|
||||||
|
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
||||||
for i in model.yolo_layers:
|
for i in model.yolo_layers:
|
||||||
# get number of grid points and anchor vec for this yolo layer
|
# get number of grid points and anchor vec for this yolo layer
|
||||||
if multi_gpu:
|
if multi_gpu:
|
||||||
|
@ -364,7 +363,7 @@ def build_targets(model, targets):
|
||||||
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R)
|
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R)
|
||||||
reject = True
|
reject = True
|
||||||
if reject:
|
if reject:
|
||||||
j = iou > iou_thres
|
j = iou > model.hyp['iou_t'] # iou threshold hyperparameter
|
||||||
t, a, gwh = t[j], a[j], gwh[j]
|
t, a, gwh = t[j], a[j], gwh[j]
|
||||||
|
|
||||||
# Indices
|
# Indices
|
||||||
|
|
Loading…
Reference in New Issue