updates
This commit is contained in:
		
							parent
							
								
									0e54731bb8
								
							
						
					
					
						commit
						0e17fb5905
					
				
							
								
								
									
										36
									
								
								test.py
								
								
								
								
							
							
						
						
									
										36
									
								
								test.py
								
								
								
								
							|  | @ -126,7 +126,7 @@ def test(cfg, | ||||||
|             # Assign all predictions as incorrect |             # Assign all predictions as incorrect | ||||||
|             correct = torch.zeros(len(pred), niou) |             correct = torch.zeros(len(pred), niou) | ||||||
|             if nl: |             if nl: | ||||||
|                 detected = [] |                 detected = []  # target indices | ||||||
|                 tcls_tensor = labels[:, 0] |                 tcls_tensor = labels[:, 0] | ||||||
| 
 | 
 | ||||||
|                 # target boxes |                 # target boxes | ||||||
|  | @ -134,27 +134,25 @@ def test(cfg, | ||||||
|                 tbox[:, [0, 2]] *= width |                 tbox[:, [0, 2]] *= width | ||||||
|                 tbox[:, [1, 3]] *= height |                 tbox[:, [1, 3]] *= height | ||||||
| 
 | 
 | ||||||
|                 # Search for correct predictions |                 # Per target class | ||||||
|                 for i, (*pbox, _, pcls) in enumerate(pred): |                 for cls in torch.unique(tcls_tensor): | ||||||
|  |                     ti = (cls == tcls_tensor).nonzero().view(-1)  # prediction indices | ||||||
|  |                     pi = (cls == pred[:, 5]).nonzero().view(-1)  # target indices | ||||||
| 
 | 
 | ||||||
|                     # Break if all targets already located in image |                     # Search for detections | ||||||
|                     if len(detected) == nl: |                     if len(pi): | ||||||
|  |                         # Prediction to target ious | ||||||
|  |                         ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1)  # best ious, indices | ||||||
|  | 
 | ||||||
|  |                         # Append detections | ||||||
|  |                         for j in (ious > iou_thres[0]).nonzero(): | ||||||
|  |                             d = ti[i[j]]  # detected target | ||||||
|  |                             if d not in detected: | ||||||
|  |                                 detected.append(d) | ||||||
|  |                                 correct[pi[j]] = (ious[j] > iou_thres).float()  # iou_thres is 1xn | ||||||
|  |                                 if len(detected) == nl:  # all targets already located in image | ||||||
|                                     break |                                     break | ||||||
| 
 | 
 | ||||||
|                     # Continue if predicted class not among image classes |  | ||||||
|                     if pcls.item() not in tcls: |  | ||||||
|                         continue |  | ||||||
| 
 |  | ||||||
|                     # Best iou, index between pred and targets |  | ||||||
|                     m = (pcls == tcls_tensor).nonzero().view(-1) |  | ||||||
|                     iou, j = bbox_iou(pbox, tbox[m]).max(0) |  | ||||||
|                     m = m[j] |  | ||||||
| 
 |  | ||||||
|                     # Per iou_thres 'correct' vector |  | ||||||
|                     if iou > iou_thres[0] and m not in detected: |  | ||||||
|                         detected.append(m) |  | ||||||
|                         correct[i] = iou > iou_thres |  | ||||||
| 
 |  | ||||||
|             # Append statistics (correct, conf, pcls, tcls) |             # Append statistics (correct, conf, pcls, tcls) | ||||||
|             stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) |             stats.append((correct, pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue