updates
This commit is contained in:
		
							parent
							
								
									bf66656b4e
								
							
						
					
					
						commit
						7b13af707d
					
				
							
								
								
									
										34
									
								
								test.py
								
								
								
								
							
							
						
						
									
										34
									
								
								test.py
								
								
								
								
							| 
						 | 
				
			
			@ -5,7 +5,7 @@ from utils.datasets import *
 | 
			
		|||
from utils.utils import *
 | 
			
		||||
 | 
			
		||||
parser = argparse.ArgumentParser(prog='test.py')
 | 
			
		||||
parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch')
 | 
			
		||||
parser.add_argument('-batch_size', type=int, default=64, help='size of each image batch')
 | 
			
		||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file')
 | 
			
		||||
parser.add_argument('-data_config_path', type=str, default='cfg/coco.data', help='path to data config file')
 | 
			
		||||
parser.add_argument('-weights_path', type=str, default='weights/yolov3.pt', help='path to weights file')
 | 
			
		||||
| 
						 | 
				
			
			@ -53,41 +53,35 @@ def main(opt):
 | 
			
		|||
    outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
 | 
			
		||||
    AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
 | 
			
		||||
    for batch_i, (imgs, targets) in enumerate(dataloader):
 | 
			
		||||
        imgs = imgs.to(device)
 | 
			
		||||
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            output = model(imgs)
 | 
			
		||||
            output = model(imgs.to(device))
 | 
			
		||||
            output = non_max_suppression(output, conf_thres=opt.conf_thres, nms_thres=opt.nms_thres)
 | 
			
		||||
 | 
			
		||||
        # Compute average precision for each sample
 | 
			
		||||
        for sample_i in range(len(targets)):
 | 
			
		||||
        for sample_i, (labels, detections) in enumerate(zip(targets, output)):
 | 
			
		||||
            correct = []
 | 
			
		||||
 | 
			
		||||
            # Get labels for sample where width is not zero (dummies)
 | 
			
		||||
            annotations = targets[sample_i]
 | 
			
		||||
            # Extract detections
 | 
			
		||||
            detections = output[sample_i]
 | 
			
		||||
 | 
			
		||||
            if detections is None:
 | 
			
		||||
                # If there are no detections but there are annotations mask as zero AP
 | 
			
		||||
                if annotations.size(0) != 0:
 | 
			
		||||
                    mAPs.append(0)
 | 
			
		||||
                # If there are no detections but there are labels mask as zero AP
 | 
			
		||||
                if labels.size(0) != 0:
 | 
			
		||||
                    mAPs.append(0), mR.append(0), mP.append(0)
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            # Get detections sorted by decreasing confidence scores
 | 
			
		||||
            detections = detections.cpu().numpy()
 | 
			
		||||
            detections = detections[np.argsort(-detections[:, 4])]
 | 
			
		||||
 | 
			
		||||
            # If no annotations add number of detections as incorrect
 | 
			
		||||
            if annotations.size(0) == 0:
 | 
			
		||||
            # If no labels add number of detections as incorrect
 | 
			
		||||
            if labels.size(0) == 0:
 | 
			
		||||
                # correct.extend([0 for _ in range(len(detections))])
 | 
			
		||||
                mAPs.append(0)
 | 
			
		||||
                mAPs.append(0), mR.append(0), mP.append(0)
 | 
			
		||||
                continue
 | 
			
		||||
            else:
 | 
			
		||||
                target_cls = annotations[:, 0]
 | 
			
		||||
                target_cls = labels[:, 0]
 | 
			
		||||
 | 
			
		||||
                # Extract target boxes as (x1, y1, x2, y2)
 | 
			
		||||
                target_boxes = xywh2xyxy(annotations[:, 1:5])
 | 
			
		||||
                target_boxes *= opt.img_size
 | 
			
		||||
                target_boxes = xywh2xyxy(labels[:, 1:5]) * opt.img_size
 | 
			
		||||
 | 
			
		||||
                detected = []
 | 
			
		||||
                for *pred_bbox, conf, obj_conf, obj_pred in detections:
 | 
			
		||||
| 
						 | 
				
			
			@ -98,7 +92,7 @@ def main(opt):
 | 
			
		|||
                    # Extract index of largest overlap
 | 
			
		||||
                    best_i = np.argmax(iou)
 | 
			
		||||
                    # If overlap exceeds threshold and classification is correct mark as correct
 | 
			
		||||
                    if iou[best_i] > opt.iou_thres and obj_pred == annotations[best_i, 0] and best_i not in detected:
 | 
			
		||||
                    if iou[best_i] > opt.iou_thres and obj_pred == labels[best_i, 0] and best_i not in detected:
 | 
			
		||||
                        correct.append(1)
 | 
			
		||||
                        detected.append(best_i)
 | 
			
		||||
                    else:
 | 
			
		||||
| 
						 | 
				
			
			@ -123,7 +117,7 @@ def main(opt):
 | 
			
		|||
            mean_P = np.mean(mP)
 | 
			
		||||
 | 
			
		||||
            # Print image mAP and running mean mAP
 | 
			
		||||
            print(('%11s%11s' + '%11.3g' * 3) % (len(mAPs), len(dataloader) * opt.batch_size, mean_P, mean_R, mean_mAP))
 | 
			
		||||
            print(('%11s%11s' + '%11.3g' * 3) % (len(mAPs), dataloader.nF, mean_P, mean_R, mean_mAP))
 | 
			
		||||
 | 
			
		||||
    # Print mAP per class
 | 
			
		||||
    print('%11s' * 5 % ('Image', 'Total', 'P', 'R', 'mAP') + '\n\nmAP Per Class:')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue