updates
This commit is contained in:
parent
ad0860dbe2
commit
8a1d1b76c0
|
@ -11,7 +11,6 @@ device = torch.device('cuda:0' if cuda else 'cpu')
|
|||
parser = argparse.ArgumentParser()
|
||||
# Get data configuration
|
||||
|
||||
# cd yolo && python3 detect.py -secondary_classifier 1
|
||||
parser.add_argument('-image_folder', type=str, default='data/samples', help='path to images')
|
||||
parser.add_argument('-output_folder', type=str, default='output', help='path to outputs')
|
||||
parser.add_argument('-plot_flag', type=bool, default=True)
|
||||
|
|
22
models.py
22
models.py
|
@ -82,7 +82,7 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
self.anchors = anchors
|
||||
self.nA = nA # number of anchors (3)
|
||||
self.nC = nC # number of classes (60)
|
||||
self.nC = nC # number of classes (80)
|
||||
self.bbox_attrs = 5 + nC
|
||||
self.img_dim = img_dim # from hyperparams in cfg file, NOT from parser
|
||||
|
||||
|
@ -103,7 +103,6 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
def forward(self, p, targets=None, requestPrecision=False, epoch=None):
|
||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||
# device = torch.device('cuda:0' if p.is_cuda else 'cpu')
|
||||
|
||||
bs = p.shape[0]
|
||||
nG = p.shape[2]
|
||||
|
@ -112,7 +111,6 @@ class YOLOLayer(nn.Module):
|
|||
if p.is_cuda and not self.grid_x.is_cuda:
|
||||
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
||||
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
|
||||
# self.scaled_anchors = self.scaled_anchors.cuda()
|
||||
|
||||
# x.view(4, 650, 19, 19) -- > (4, 10, 19, 19, 65) # (bs, anchors, grid, grid, classes + xywh)
|
||||
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||
|
@ -132,11 +130,9 @@ class YOLOLayer(nn.Module):
|
|||
|
||||
# Training
|
||||
if targets is not None:
|
||||
BCEWithLogitsLoss1 = nn.BCEWithLogitsLoss(size_average=False) # version 0.4.0
|
||||
BCEWithLogitsLoss0 = nn.BCEWithLogitsLoss()
|
||||
# BCEWithLogitsLoss2 = nn.BCEWithLogitsLoss(size_average=True)
|
||||
MSELoss = nn.MSELoss(size_average=False) # version 0.4.0
|
||||
CrossEntropyLoss = nn.CrossEntropyLoss()
|
||||
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
|
||||
MSELoss = nn.MSELoss() # version 0.4.0
|
||||
# CrossEntropyLoss = nn.CrossEntropyLoss()
|
||||
|
||||
if requestPrecision:
|
||||
gx = self.grid_x[:, :, :nG, :nG]
|
||||
|
@ -154,21 +150,21 @@ class YOLOLayer(nn.Module):
|
|||
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
|
||||
|
||||
# Mask outputs to ignore non-existing objects (but keep confidence predictions)
|
||||
nM = mask.sum().float()
|
||||
nM = mask.sum()
|
||||
nGT = sum([len(x) for x in targets])
|
||||
if nM > 0:
|
||||
lx = 5 * MSELoss(x[mask], tx[mask])
|
||||
ly = 5 * MSELoss(y[mask], ty[mask])
|
||||
lw = 5 * MSELoss(w[mask], tw[mask])
|
||||
lh = 5 * MSELoss(h[mask], th[mask])
|
||||
lconf = 1.5 * BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float())
|
||||
lconf = 1.5 * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float())
|
||||
|
||||
lcls = nM * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
|
||||
# lcls = BCEWithLogitsLoss1(pred_cls[mask], tcls.float())
|
||||
# lcls = CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
|
||||
lcls = BCEWithLogitsLoss(pred_cls[mask], tcls.float())
|
||||
else:
|
||||
lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
|
||||
|
||||
lconf += nM * BCEWithLogitsLoss0(pred_conf[~mask], mask[~mask].float())
|
||||
lconf += BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float())
|
||||
|
||||
loss = lx + ly + lw + lh + lconf + lcls
|
||||
i = torch.sigmoid(pred_conf[~mask]) > 0.99
|
||||
|
|
2
train.py
2
train.py
|
@ -94,7 +94,7 @@ def main(opt):
|
|||
epoch += start_epoch
|
||||
|
||||
# img_size = random.choice(range(10, 20)) * 32
|
||||
# dataloader = ListDataset(train_path, batch_size=opt.batch_size, img_size=img_size, targets_path=targets_path)
|
||||
# dataloader = ListDataset(train_path, batch_size=opt.batch_size, img_size=img_size)
|
||||
# print('Running image size %g' % img_size)
|
||||
|
||||
# Update scheduler
|
||||
|
|
Loading…
Reference in New Issue