This commit is contained in:
Glenn Jocher 2019-07-25 18:18:40 +02:00
parent bfbc54666e
commit 88eea8f147
3 changed files with 12 additions and 8 deletions

View File

@@ -139,6 +139,7 @@ class YOLOLayer(nn.Module):
return torch.cat((xy / ngu, wh, p_conf, p_cls), 2).squeeze().t() return torch.cat((xy / ngu, wh, p_conf, p_cls), 2).squeeze().t()
else: # inference else: # inference
# s = 1.5 # scale_xy (pxy = pxy * s - (s - 1) / 2)
io = p.clone() # inference output io = p.clone() # inference output
io[..., 0:2] = torch.sigmoid(io[..., 0:2]) + self.grid_xy # xy io[..., 0:2] = torch.sigmoid(io[..., 0:2]) + self.grid_xy # xy
io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method

View File

@@ -40,8 +40,6 @@ def exif_size(img):
class LoadImages: # for inference class LoadImages: # for inference
def __init__(self, path, img_size=416): def __init__(self, path, img_size=416):
self.height = img_size
files = [] files = []
if os.path.isdir(path): if os.path.isdir(path):
files = sorted(glob.glob('%s/*.*' % path)) files = sorted(glob.glob('%s/*.*' % path))
@@ -52,6 +50,7 @@ class LoadImages: # for inference
videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats] videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
nI, nV = len(images), len(videos) nI, nV = len(images), len(videos)
self.img_size = img_size
self.files = images + videos self.files = images + videos
self.nF = nI + nV # number of files self.nF = nI + nV # number of files
self.video_flag = [False] * nI + [True] * nV self.video_flag = [False] * nI + [True] * nV
@@ -96,7 +95,7 @@ class LoadImages: # for inference
print('image %g/%g %s: ' % (self.count, self.nF, path), end='') print('image %g/%g %s: ' % (self.count, self.nF, path), end='')
# Padded resize # Padded resize
img, *_ = letterbox(img0, new_shape=self.height) img, *_ = letterbox(img0, new_shape=self.img_size)
# Normalize RGB # Normalize RGB
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB
@@ -117,8 +116,10 @@ class LoadImages: # for inference
class LoadWebcam: # for inference class LoadWebcam: # for inference
def __init__(self, img_size=416): def __init__(self, img_size=416):
self.cam = cv2.VideoCapture(0) self.img_size = img_size
self.height = img_size self.cam = cv2.VideoCapture(0) # local camera
# self.cam = cv2.VideoCapture('rtsp://192.168.1.64/1') # IP camera
# self.cam = cv2.VideoCapture('rtsp://username:password@192.168.1.64/1') # IP camera with login
def __iter__(self): def __iter__(self):
self.count = -1 self.count = -1
@@ -138,7 +139,7 @@ class LoadWebcam: # for inference
print('webcam %g: ' % self.count, end='') print('webcam %g: ' % self.count, end='')
# Padded resize # Padded resize
img, *_ = letterbox(img0, new_shape=self.height) img, *_ = letterbox(img0, new_shape=self.img_size)
# Normalize RGB # Normalize RGB
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB

View File

@@ -304,12 +304,14 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
tobj[b, a, gj, gi] = 1.0 # obj tobj[b, a, gj, gi] = 1.0 # obj
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment) # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
# s = 1.5 # scale_xy
pxy = torch.sigmoid(pi[..., 0:2]) # * s - (s - 1) / 2
if giou_loss: if giou_loss:
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted pbox = torch.cat((pxy, torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
else: else:
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
tclsm = torch.zeros_like(pi[..., 5:]) tclsm = torch.zeros_like(pi[..., 5:])