updates

parent 3f21d5bb2e
commit 3bea4da604
@@ -150,9 +150,6 @@ class YOLOLayer(nn.Module):
             p_conf = p[..., 4]  # Conf
             p_cls = p[..., 5:]  # Class
 
-            if p.is_cuda:
-                txy, twh, mask, tcls = build_targets(targets, self.anchor_vec.cuda(), self.nA, self.nC, nG)
-            else:
-                txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
+            txy, twh, mask, tcls = build_targets(targets, self.anchor_vec, self.nA, self.nC, nG)
 
             tcls = tcls[mask]
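The dropped p.is_cuda branch suggests the anchors are expected to already live on the right device. A minimal sketch of one way to get that behaviour, assuming anchor_vec can be stored as a module buffer (whether YOLOLayer actually does this is an assumption):

import torch
import torch.nn as nn

class AnchorHolder(nn.Module):
    # Sketch only: registering the anchors as a buffer makes them follow the module
    # when it is moved with .cuda() / .to(device).
    def __init__(self, anchors):
        super().__init__()
        self.register_buffer('anchor_vec', torch.tensor(anchors, dtype=torch.float32))

m = AnchorHolder([[1.25, 1.625], [2.0, 3.75], [4.125, 2.875]])
if torch.cuda.is_available():
    m = m.cuda()
print(m.anchor_vec.device)  # same device as the module, so one build_targets call suffices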
@@ -164,7 +161,6 @@ class YOLOLayer(nn.Module):
             nM = mask.sum().float()  # number of anchors (assigned to targets)
             k = 1  # nM / bs
             if nM > 0:
-                print(xy.shape, txy.shape, mask.shape)
                 lxy = k * MSELoss(xy[mask], txy[mask])
                 lwh = k * MSELoss(wh[mask], twh[mask])
 
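For orientation, a small self-contained sketch of the masked losses in this hunk, with made-up tensor shapes (batch 2, 3 anchors, 4x4 grid); the real shapes come from the network output:

import torch
import torch.nn as nn

MSELoss = nn.MSELoss()  # recreated here for the sketch; the repo builds its criteria elsewhere

xy = torch.rand(2, 3, 4, 4, 2)    # predicted x-y offsets (hypothetical shapes)
txy = torch.rand(2, 3, 4, 4, 2)   # target x-y offsets
wh = torch.rand(2, 3, 4, 4, 2)    # predicted w-h
twh = torch.rand(2, 3, 4, 4, 2)   # target w-h
mask = torch.zeros(2, 3, 4, 4, dtype=torch.bool)
mask[0, 1, 2, 2] = True           # pretend one anchor was assigned to a target

nM = mask.sum().float()           # number of anchors assigned to targets
k = 1
if nM > 0:
    lxy = k * MSELoss(xy[mask], txy[mask])  # loss computed only over assigned anchors
    lwh = k * MSELoss(wh[mask], twh[mask])
    print(lxy.item(), lwh.item())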
train.py (4 changed lines)
@@ -48,8 +48,8 @@ def train(
         # Load weights to resume from
         model.load_state_dict(checkpoint['model'])
 
-    # if torch.cuda.device_count() > 1:
-    #     model = nn.DataParallel(model)
+    if torch.cuda.device_count() > 1:
+        model = nn.DataParallel(model)
     model.to(device).train()
 
     # Transfer learning (train only YOLO layers)
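A hedged sketch of what the now-active DataParallel lines do, using a stand-in module rather than the repository's Darknet model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())  # stand-in; the real model is built in models.py
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.device_count() > 1:
    # Replicates the module on each visible GPU and splits the batch along dim 0.
    # Custom attributes then live on model.module, a common source of the
    # multi-GPU problems referenced in issue 21.
    model = nn.DataParallel(model)
model.to(device).train()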
@@ -17,7 +17,7 @@ def select_device(force_cpu=False):
 
     if torch.cuda.device_count() > 1:
         print('Found %g GPUs' % torch.cuda.device_count())
-        print('WARNING Using GPU0 Only: https://github.com/ultralytics/yolov3/issues/21')
+        print('WARNING Multi-GPU Issue: https://github.com/ultralytics/yolov3/issues/21')
         # torch.cuda.set_device(0)  # OPTIONAL: Set your GPU if multiple available
         # # print('Using ', torch.cuda.device_count(), ' GPUs')
 
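A minimal sketch of a select_device helper consistent with the lines above; the repository's version may differ beyond what this hunk shows:

import torch

def select_device(force_cpu=False):
    # Fall back to CPU when forced or when CUDA is unavailable.
    if force_cpu or not torch.cuda.is_available():
        return torch.device('cpu')
    if torch.cuda.device_count() > 1:
        print('Found %g GPUs' % torch.cuda.device_count())
        print('WARNING Multi-GPU Issue: https://github.com/ultralytics/yolov3/issues/21')
    return torch.device('cuda:0')

print(select_device())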
@@ -242,12 +242,6 @@ def build_targets(target, anchor_vec, nA, nC, nG):
     tconf = torch.ByteTensor(nB, nA, nG, nG).fill_(0)
     tcls = torch.ByteTensor(nB, nA, nG, nG, nC).fill_(0)  # nC = number of classes
 
-    if anchor_vec.is_cuda:
-        txy = txy.cuda()
-        twh = twh.cuda()
-        tconf = tconf.cuda()
-        tcls = tcls.cuda()
-
     for b in range(nB):
         t = target[b]
         nTb = len(t)  # number of targets
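One device-agnostic alternative to the deleted .cuda() block is to allocate the target tensors directly on anchor_vec's device. A sketch, with the 2-channel shapes for txy and twh assumed from the x-y and w-h pairs:

import torch

def allocate_targets(anchor_vec, nB, nA, nC, nG):
    device = anchor_vec.device  # create everything where the anchors already live
    txy = torch.zeros(nB, nA, nG, nG, 2, device=device)
    twh = torch.zeros(nB, nA, nG, nG, 2, device=device)
    tconf = torch.zeros(nB, nA, nG, nG, dtype=torch.uint8, device=device)
    tcls = torch.zeros(nB, nA, nG, nG, nC, dtype=torch.uint8, device=device)
    return txy, twh, tconf, tcls

anchors = torch.tensor([[1.25, 1.625], [2.0, 3.75], [4.125, 2.875]])
txy, twh, tconf, tcls = allocate_targets(anchors, nB=2, nA=3, nC=80, nG=13)
print(txy.shape, tconf.dtype)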
@@ -263,8 +257,6 @@ def build_targets(target, anchor_vec, nA, nC, nG):
             box1 = gwh
             box2 = anchor_vec.unsqueeze(1)
 
-            print(box1.device, box2.device)
-
             inter_area = torch.min(box1, box2).prod(2)
             iou = inter_area / (box1.prod(1) + box2.prod(2) - inter_area + 1e-16)
 
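A toy worked example of the width-height IoU in the surviving lines, with two made-up targets and two anchors in grid units:

import torch

gwh = torch.tensor([[2.0, 2.0], [4.0, 1.0]])         # 2 targets, (w, h) in grid units
anchor_vec = torch.tensor([[1.0, 1.0], [3.0, 3.0]])  # 2 anchors

box1 = gwh
box2 = anchor_vec.unsqueeze(1)                        # (nA, 1, 2) so it broadcasts over targets
inter_area = torch.min(box1, box2).prod(2)            # (nA, nT) overlap of the w-h boxes
iou = inter_area / (box1.prod(1) + box2.prod(2) - inter_area + 1e-16)
print(iou)  # anchor 0 vs target 0: intersection 1*1=1, union 4+1-1=4 -> 0.25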