updates
This commit is contained in:
parent
ad0860dbe2
commit
8a1d1b76c0
|
@ -11,7 +11,6 @@ device = torch.device('cuda:0' if cuda else 'cpu')
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
# Get data configuration
|
# Get data configuration
|
||||||
|
|
||||||
# cd yolo && python3 detect.py -secondary_classifier 1
|
|
||||||
parser.add_argument('-image_folder', type=str, default='data/samples', help='path to images')
|
parser.add_argument('-image_folder', type=str, default='data/samples', help='path to images')
|
||||||
parser.add_argument('-output_folder', type=str, default='output', help='path to outputs')
|
parser.add_argument('-output_folder', type=str, default='output', help='path to outputs')
|
||||||
parser.add_argument('-plot_flag', type=bool, default=True)
|
parser.add_argument('-plot_flag', type=bool, default=True)
|
||||||
|
|
22
models.py
22
models.py
|
@ -82,7 +82,7 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
self.anchors = anchors
|
self.anchors = anchors
|
||||||
self.nA = nA # number of anchors (3)
|
self.nA = nA # number of anchors (3)
|
||||||
self.nC = nC # number of classes (60)
|
self.nC = nC # number of classes (80)
|
||||||
self.bbox_attrs = 5 + nC
|
self.bbox_attrs = 5 + nC
|
||||||
self.img_dim = img_dim # from hyperparams in cfg file, NOT from parser
|
self.img_dim = img_dim # from hyperparams in cfg file, NOT from parser
|
||||||
|
|
||||||
|
@ -103,7 +103,6 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
def forward(self, p, targets=None, requestPrecision=False, epoch=None):
|
def forward(self, p, targets=None, requestPrecision=False, epoch=None):
|
||||||
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
|
||||||
# device = torch.device('cuda:0' if p.is_cuda else 'cpu')
|
|
||||||
|
|
||||||
bs = p.shape[0]
|
bs = p.shape[0]
|
||||||
nG = p.shape[2]
|
nG = p.shape[2]
|
||||||
|
@ -112,7 +111,6 @@ class YOLOLayer(nn.Module):
|
||||||
if p.is_cuda and not self.grid_x.is_cuda:
|
if p.is_cuda and not self.grid_x.is_cuda:
|
||||||
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
self.grid_x, self.grid_y = self.grid_x.cuda(), self.grid_y.cuda()
|
||||||
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
|
self.anchor_w, self.anchor_h = self.anchor_w.cuda(), self.anchor_h.cuda()
|
||||||
# self.scaled_anchors = self.scaled_anchors.cuda()
|
|
||||||
|
|
||||||
# x.view(4, 650, 19, 19) -- > (4, 10, 19, 19, 65) # (bs, anchors, grid, grid, classes + xywh)
|
# x.view(4, 650, 19, 19) -- > (4, 10, 19, 19, 65) # (bs, anchors, grid, grid, classes + xywh)
|
||||||
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
|
||||||
|
@ -132,11 +130,9 @@ class YOLOLayer(nn.Module):
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if targets is not None:
|
if targets is not None:
|
||||||
BCEWithLogitsLoss1 = nn.BCEWithLogitsLoss(size_average=False) # version 0.4.0
|
BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
|
||||||
BCEWithLogitsLoss0 = nn.BCEWithLogitsLoss()
|
MSELoss = nn.MSELoss() # version 0.4.0
|
||||||
# BCEWithLogitsLoss2 = nn.BCEWithLogitsLoss(size_average=True)
|
# CrossEntropyLoss = nn.CrossEntropyLoss()
|
||||||
MSELoss = nn.MSELoss(size_average=False) # version 0.4.0
|
|
||||||
CrossEntropyLoss = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
if requestPrecision:
|
if requestPrecision:
|
||||||
gx = self.grid_x[:, :, :nG, :nG]
|
gx = self.grid_x[:, :, :nG, :nG]
|
||||||
|
@ -154,21 +150,21 @@ class YOLOLayer(nn.Module):
|
||||||
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
|
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
|
||||||
|
|
||||||
# Mask outputs to ignore non-existing objects (but keep confidence predictions)
|
# Mask outputs to ignore non-existing objects (but keep confidence predictions)
|
||||||
nM = mask.sum().float()
|
nM = mask.sum()
|
||||||
nGT = sum([len(x) for x in targets])
|
nGT = sum([len(x) for x in targets])
|
||||||
if nM > 0:
|
if nM > 0:
|
||||||
lx = 5 * MSELoss(x[mask], tx[mask])
|
lx = 5 * MSELoss(x[mask], tx[mask])
|
||||||
ly = 5 * MSELoss(y[mask], ty[mask])
|
ly = 5 * MSELoss(y[mask], ty[mask])
|
||||||
lw = 5 * MSELoss(w[mask], tw[mask])
|
lw = 5 * MSELoss(w[mask], tw[mask])
|
||||||
lh = 5 * MSELoss(h[mask], th[mask])
|
lh = 5 * MSELoss(h[mask], th[mask])
|
||||||
lconf = 1.5 * BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float())
|
lconf = 1.5 * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float())
|
||||||
|
|
||||||
lcls = nM * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
|
# lcls = CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
|
||||||
# lcls = BCEWithLogitsLoss1(pred_cls[mask], tcls.float())
|
lcls = BCEWithLogitsLoss(pred_cls[mask], tcls.float())
|
||||||
else:
|
else:
|
||||||
lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
|
lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
|
||||||
|
|
||||||
lconf += nM * BCEWithLogitsLoss0(pred_conf[~mask], mask[~mask].float())
|
lconf += BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float())
|
||||||
|
|
||||||
loss = lx + ly + lw + lh + lconf + lcls
|
loss = lx + ly + lw + lh + lconf + lcls
|
||||||
i = torch.sigmoid(pred_conf[~mask]) > 0.99
|
i = torch.sigmoid(pred_conf[~mask]) > 0.99
|
||||||
|
|
2
train.py
2
train.py
|
@ -94,7 +94,7 @@ def main(opt):
|
||||||
epoch += start_epoch
|
epoch += start_epoch
|
||||||
|
|
||||||
# img_size = random.choice(range(10, 20)) * 32
|
# img_size = random.choice(range(10, 20)) * 32
|
||||||
# dataloader = ListDataset(train_path, batch_size=opt.batch_size, img_size=img_size, targets_path=targets_path)
|
# dataloader = ListDataset(train_path, batch_size=opt.batch_size, img_size=img_size)
|
||||||
# print('Running image size %g' % img_size)
|
# print('Running image size %g' % img_size)
|
||||||
|
|
||||||
# Update scheduler
|
# Update scheduler
|
||||||
|
|
Loading…
Reference in New Issue