This commit is contained in:
Glenn Jocher 2019-01-09 11:48:04 +01:00
parent acfe4aaf94
commit 646a573740
1 changed files with 1 additions and 16 deletions

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from utils.parse_config import *
from utils.utils import *
ONNX_EXPORT = False
ONNX_EXPORT = True
def create_modules(module_defs):
@ -128,7 +128,6 @@ class YOLOLayer(nn.Module):
self.weights = class_weights()
self.loss_means = torch.ones(6)
self.tx, self.ty, self.tw, self.th = [], [], [], []
self.yolo_layer = anchor_idxs[0] / nA # 2, 1, 0
self.stride = stride
@ -205,25 +204,11 @@ class YOLOLayer(nn.Module):
lw = k * MSELoss(w[mask], tw[mask])
lh = k * MSELoss(h[mask], th[mask])
# self.tx.extend(tx[mask].data.numpy())
# self.ty.extend(ty[mask].data.numpy())
# self.tw.extend(tw[mask].data.numpy())
# self.th.extend(th[mask].data.numpy())
# print([np.mean(self.tx), np.std(self.tx)],[np.mean(self.ty), np.std(self.ty)],[np.mean(self.tw), np.std(self.tw)],[np.mean(self.th), np.std(self.th)])
# [0.5040668, 0.2885492] [0.51384246, 0.28328574] [-0.4754091, 0.57951087] [-0.25998235, 0.44858757]
# [0.50184494, 0.2858976] [0.51747805, 0.2896323] [0.12962963, 0.6263085] [-0.2722081, 0.61574113]
# [0.5032071, 0.28825334] [0.5063132, 0.2808862] [0.21124361, 0.44760725] [0.35445485, 0.6427766]
# import matplotlib.pyplot as plt
# plt.hist(self.x)
# lconf = k * BCEWithLogitsLoss(p_conf[mask], mask[mask].float())
lcls = (k / 4) * CrossEntropyLoss(p_cls[mask], torch.argmax(tcls, 1))
# lcls = (k * 10) * BCEWithLogitsLoss(p_cls[mask], tcls.float())
else:
lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
# lconf += k * BCEWithLogitsLoss(p_conf[~mask], mask[~mask].float())
lconf = (k * 64) * BCEWithLogitsLoss(p_conf, mask.float())
# Sum loss components