updates
This commit is contained in:
		
							parent
							
								
									be01fc357b
								
							
						
					
					
						commit
						eb81c0b9ae
					
				
							
								
								
									
										11
									
								
								test.py
								
								
								
								
							
							
						
						
									
										11
									
								
								test.py
								
								
								
								
							|  | @ -74,6 +74,7 @@ def test(cfg, | |||
|         imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0 | ||||
|         targets = targets.to(device) | ||||
|         _, _, height, width = imgs.shape  # batch size, channels, height, width | ||||
|         whwh = torch.Tensor([width, height, width, height]).to(device) | ||||
| 
 | ||||
|         # Plot images with bounding boxes | ||||
|         f = 'test_batch%g.png' % batch_i  # filename | ||||
|  | @ -126,13 +127,13 @@ def test(cfg, | |||
|                                   'score': floatn(d[4], 5)}) | ||||
| 
 | ||||
|             # Assign all predictions as incorrect | ||||
|             correct = torch.zeros(len(pred), niou, dtype=torch.bool) | ||||
|             correct = torch.zeros(len(pred), niou, dtype=torch.bool, device=device) | ||||
|             if nl: | ||||
|                 detected = []  # target indices | ||||
|                 tcls_tensor = labels[:, 0] | ||||
| 
 | ||||
|                 # target boxes | ||||
|                 tbox = xywh2xyxy(labels[:, 1:5]) * torch.Tensor([width, height, width, height]).to(device) | ||||
|                 tbox = xywh2xyxy(labels[:, 1:5]) * whwh | ||||
| 
 | ||||
|                 # Per target class | ||||
|                 for cls in torch.unique(tcls_tensor): | ||||
|  | @ -140,7 +141,7 @@ def test(cfg, | |||
|                     pi = (cls == pred[:, 5]).nonzero().view(-1)  # target indices | ||||
| 
 | ||||
|                     # Search for detections | ||||
|                     if len(pi): | ||||
|                     if pi.shape[0]: | ||||
|                         # Prediction to target ious | ||||
|                         ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices | ||||
| 
 | ||||
|  | @ -149,12 +150,12 @@ def test(cfg, | |||
|                             d = ti[i[j]]  # detected target | ||||
|                             if d not in detected: | ||||
|                                 detected.append(d) | ||||
|                                 correct[pi[j]] = (ious[j] > iouv).cpu()  # iou_thres is 1xn | ||||
|                                 correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn | ||||
|                                 if len(detected) == nl:  # all targets already located in image | ||||
|                                     break | ||||
| 
 | ||||
|             # Append statistics (correct, conf, pcls, tcls) | ||||
|             stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) | ||||
|             stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) | ||||
| 
 | ||||
|     # Compute statistics | ||||
|     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue