Pseudo Labeling (#1149)

* Added pseudo labeling

* Delete print_test.py

* Refactor label generation

* Update detect.py

* Update detect.py

* Update utils.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
orcund 2020-05-16 21:09:57 +03:00 committed by GitHub
parent 3f27ef1253
commit 3a71daf4bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 2 deletions

View File

@ -102,7 +102,7 @@ def detect(save_img=False):
pred = apply_classifier(pred, modelc, img, im0s)
# Process detections
for i, det in enumerate(pred): # detections per image
for i, det in enumerate(pred): # detections for image i
if webcam: # batch_size >= 1
p, s, im0 = path[i], '%g: ' % i, im0s[i]
else:
@ -110,6 +110,7 @@ def detect(save_img=False):
save_path = str(Path(out) / Path(p).name)
s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0s.shape)[[1, 0, 1, 0]] #  normalization gain whwh
if det is not None and len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
@ -122,8 +123,9 @@ def detect(save_img=False):
# Write results
for *xyxy, conf, cls in det:
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
with open(save_path + '.txt', 'a') as file:
file.write(('%g ' * 6 + '\n') % (*xyxy, cls, conf))
file.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
if save_img or view_img: # Add bbox to image
label = '%s %.2f' % (names[int(cls)], conf)