test/inference time augmentation
This commit is contained in:
		
							parent
							
								
									adba66c3a6
								
							
						
					
					
						commit
						07d2f0ad03
					
				
							
								
								
									
										15
									
								
								test.py
								
								
								
								
							
							
						
						
									
										15
									
								
								test.py
								
								
								
								
							| 
						 | 
					@ -73,7 +73,7 @@ def test(cfg,
 | 
				
			||||||
    for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
 | 
					    for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
 | 
				
			||||||
        imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
 | 
					        imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
 | 
				
			||||||
        targets = targets.to(device)
 | 
					        targets = targets.to(device)
 | 
				
			||||||
        _, _, height, width = imgs.shape  # batch size, channels, height, width
 | 
					        nb, _, height, width = imgs.shape  # batch size, channels, height, width
 | 
				
			||||||
        whwh = torch.Tensor([width, height, width, height]).to(device)
 | 
					        whwh = torch.Tensor([width, height, width, height]).to(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Plot images with bounding boxes
 | 
					        # Plot images with bounding boxes
 | 
				
			||||||
| 
						 | 
					@ -83,11 +83,24 @@ def test(cfg,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Disable gradients
 | 
					        # Disable gradients
 | 
				
			||||||
        with torch.no_grad():
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            aug = False  # augment https://github.com/ultralytics/yolov3/issues/931
 | 
				
			||||||
 | 
					            if aug:
 | 
				
			||||||
 | 
					                imgs = torch.cat((imgs,
 | 
				
			||||||
 | 
					                                  imgs.flip(3),  # flip-lr
 | 
				
			||||||
 | 
					                                  torch_utils.scale_img(imgs, 0.7),  # scale
 | 
				
			||||||
 | 
					                                  ), 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Run model
 | 
					            # Run model
 | 
				
			||||||
            t = torch_utils.time_synchronized()
 | 
					            t = torch_utils.time_synchronized()
 | 
				
			||||||
            inf_out, train_out = model(imgs)  # inference and training outputs
 | 
					            inf_out, train_out = model(imgs)  # inference and training outputs
 | 
				
			||||||
            t0 += torch_utils.time_synchronized() - t
 | 
					            t0 += torch_utils.time_synchronized() - t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if aug:
 | 
				
			||||||
 | 
					                x = torch.split(inf_out, nb, dim=0)
 | 
				
			||||||
 | 
					                x[1][..., 0] = width - x[1][..., 0]  # flip lr
 | 
				
			||||||
 | 
					                x[2][..., :4] /= 0.7  # scale
 | 
				
			||||||
 | 
					                inf_out = torch.cat(x, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Compute loss
 | 
					            # Compute loss
 | 
				
			||||||
            if hasattr(model, 'hyp'):  # if model has loss hyperparameters
 | 
					            if hasattr(model, 'hyp'):  # if model has loss hyperparameters
 | 
				
			||||||
                loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls
 | 
					                loss += compute_loss(train_out, targets, model)[1][:3]  # GIoU, obj, cls
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,6 +5,7 @@ from copy import deepcopy
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torch.backends.cudnn as cudnn
 | 
					import torch.backends.cudnn as cudnn
 | 
				
			||||||
import torch.nn as nn
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def init_seeds(seed=0):
 | 
					def init_seeds(seed=0):
 | 
				
			||||||
| 
						 | 
					@ -105,6 +106,16 @@ def load_classifier(name='resnet101', n=2):
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def scale_img(img, r=1.0):  # img(16,3,256,416), r=ratio
 | 
				
			||||||
 | 
					    # scales a batch of pytorch images while retaining same input shape (cropped or grey-padded)
 | 
				
			||||||
 | 
					    h, w = img.shape[2:]
 | 
				
			||||||
 | 
					    s = (int(h * r), int(w * r))  # new size
 | 
				
			||||||
 | 
					    p = h - s[0], w - s[1]  # pad/crop pixels
 | 
				
			||||||
 | 
					    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
 | 
				
			||||||
 | 
					    return F.pad(img, [0, p[1], 0, p[0]], value=0.5) if r < 1.0 else img[:, :, :p[0], :p[1]]  # pad/crop
 | 
				
			||||||
 | 
					    # cv2.imwrite('scaled.jpg', np.array(img[0].permute((1, 2, 0)) * 255.0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelEMA:
 | 
					class ModelEMA:
 | 
				
			||||||
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
 | 
					    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
 | 
				
			||||||
    Keep a moving average of everything in the model state_dict (parameters and buffers).
 | 
					    Keep a moving average of everything in the model state_dict (parameters and buffers).
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue