updates
This commit is contained in:
parent
bfbc54666e
commit
88eea8f147
|
@ -139,6 +139,7 @@ class YOLOLayer(nn.Module):
|
|||
return torch.cat((xy / ngu, wh, p_conf, p_cls), 2).squeeze().t()
|
||||
|
||||
else: # inference
|
||||
# s = 1.5 # scale_xy (pxy = pxy * s - (s - 1) / 2)
|
||||
io = p.clone() # inference output
|
||||
io[..., 0:2] = torch.sigmoid(io[..., 0:2]) + self.grid_xy # xy
|
||||
io[..., 2:4] = torch.exp(io[..., 2:4]) * self.anchor_wh # wh yolo method
|
||||
|
|
|
@ -40,8 +40,6 @@ def exif_size(img):
|
|||
|
||||
class LoadImages: # for inference
|
||||
def __init__(self, path, img_size=416):
|
||||
self.height = img_size
|
||||
|
||||
files = []
|
||||
if os.path.isdir(path):
|
||||
files = sorted(glob.glob('%s/*.*' % path))
|
||||
|
@ -52,6 +50,7 @@ class LoadImages: # for inference
|
|||
videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
|
||||
nI, nV = len(images), len(videos)
|
||||
|
||||
self.img_size = img_size
|
||||
self.files = images + videos
|
||||
self.nF = nI + nV # number of files
|
||||
self.video_flag = [False] * nI + [True] * nV
|
||||
|
@ -96,7 +95,7 @@ class LoadImages: # for inference
|
|||
print('image %g/%g %s: ' % (self.count, self.nF, path), end='')
|
||||
|
||||
# Padded resize
|
||||
img, *_ = letterbox(img0, new_shape=self.height)
|
||||
img, *_ = letterbox(img0, new_shape=self.img_size)
|
||||
|
||||
# Normalize RGB
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB
|
||||
|
@ -117,8 +116,10 @@ class LoadImages: # for inference
|
|||
|
||||
class LoadWebcam: # for inference
|
||||
def __init__(self, img_size=416):
|
||||
self.cam = cv2.VideoCapture(0)
|
||||
self.height = img_size
|
||||
self.img_size = img_size
|
||||
self.cam = cv2.VideoCapture(0) # local camera
|
||||
# self.cam = cv2.VideoCapture('rtsp://192.168.1.64/1') # IP camera
|
||||
# self.cam = cv2.VideoCapture('rtsp://username:password@192.168.1.64/1') # IP camera with login
|
||||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
|
@ -138,7 +139,7 @@ class LoadWebcam: # for inference
|
|||
print('webcam %g: ' % self.count, end='')
|
||||
|
||||
# Padded resize
|
||||
img, *_ = letterbox(img0, new_shape=self.height)
|
||||
img, *_ = letterbox(img0, new_shape=self.img_size)
|
||||
|
||||
# Normalize RGB
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB
|
||||
|
|
|
@ -304,12 +304,14 @@ def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, mo
|
|||
tobj[b, a, gj, gi] = 1.0 # obj
|
||||
# pi[..., 2:4] = torch.sigmoid(pi[..., 2:4]) # wh power loss (uncomment)
|
||||
|
||||
# s = 1.5 # scale_xy
|
||||
pxy = torch.sigmoid(pi[..., 0:2]) # * s - (s - 1) / 2
|
||||
if giou_loss:
|
||||
pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
|
||||
pbox = torch.cat((pxy, torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1) # predicted
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||
lxy += (k * h['giou']) * (1.0 - giou).mean() # giou loss
|
||||
else:
|
||||
lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i]) # xy loss
|
||||
lxy += (k * h['xy']) * MSE(pxy, txy[i]) # xy loss
|
||||
lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i]) # wh yolo loss
|
||||
|
||||
tclsm = torch.zeros_like(pi[..., 5:])
|
||||
|
|
Loading…
Reference in New Issue