GIoU to default
This commit is contained in:
		
							parent
							
								
									32a52dfb02
								
							
						
					
					
						commit
						70f6379601
					
				
							
								
								
									
										2
									
								
								test.py
								
								
								
								
							
							
						
						
									
										2
									
								
								test.py
								
								
								
								
							|  | @ -64,7 +64,7 @@ def test( | ||||||
| 
 | 
 | ||||||
|         # Plot images with bounding boxes |         # Plot images with bounding boxes | ||||||
|         if batch_i == 0 and not os.path.exists('test_batch0.jpg'): |         if batch_i == 0 and not os.path.exists('test_batch0.jpg'): | ||||||
|             plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg') |             plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg') | ||||||
| 
 | 
 | ||||||
|         # Run model |         # Run model | ||||||
|         inf_out, train_out = model(imgs)  # inference and training outputs |         inf_out, train_out = model(imgs)  # inference and training outputs | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								train.py
								
								
								
								
							
							
						
						
									
										4
									
								
								train.py
								
								
								
								
							|  | @ -178,7 +178,7 @@ def train( | ||||||
| 
 | 
 | ||||||
|         mloss = torch.zeros(5).to(device)  # mean losses |         mloss = torch.zeros(5).to(device)  # mean losses | ||||||
|         pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar |         pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar | ||||||
|         for i, (imgs, targets, _, _) in pbar: |         for i, (imgs, targets, paths, _) in pbar: | ||||||
|             imgs = imgs.to(device) |             imgs = imgs.to(device) | ||||||
|             targets = targets.to(device) |             targets = targets.to(device) | ||||||
| 
 | 
 | ||||||
|  | @ -192,7 +192,7 @@ def train( | ||||||
| 
 | 
 | ||||||
|             # Plot images with bounding boxes |             # Plot images with bounding boxes | ||||||
|             if epoch == 0 and i == 0: |             if epoch == 0 and i == 0: | ||||||
|                 plot_images(imgs=imgs, targets=targets, fname='train_batch%g.jpg' % i) |                 plot_images(imgs=imgs, targets=targets, paths=paths, fname='train_batch%g.jpg' % i) | ||||||
| 
 | 
 | ||||||
|             # SGD burn-in |             # SGD burn-in | ||||||
|             if epoch == 0 and i <= n_burnin: |             if epoch == 0 and i <= n_burnin: | ||||||
|  |  | ||||||
|  | @ -9,6 +9,7 @@ import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from PIL import Image | from PIL import Image | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
|  | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from . import torch_utils | from . import torch_utils | ||||||
| from . import google_utils | from . import google_utils | ||||||
|  | @ -611,7 +612,7 @@ def plot_wh_methods():  # from utils.utils import *; plot_wh_methods() | ||||||
|     fig.savefig('comparison.png', dpi=300) |     fig.savefig('comparison.png', dpi=300) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def plot_images(imgs, targets, fname='images.jpg'): | def plot_images(imgs, targets, paths=None, fname='images.jpg'): | ||||||
|     # Plots training images overlaid with targets |     # Plots training images overlaid with targets | ||||||
|     imgs = imgs.cpu().numpy() |     imgs = imgs.cpu().numpy() | ||||||
|     targets = targets.cpu().numpy() |     targets = targets.cpu().numpy() | ||||||
|  | @ -627,6 +628,8 @@ def plot_images(imgs, targets, fname='images.jpg'): | ||||||
|         plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0)) |         plt.subplot(ns, ns, i + 1).imshow(imgs[i].transpose(1, 2, 0)) | ||||||
|         plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') |         plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') | ||||||
|         plt.axis('off') |         plt.axis('off') | ||||||
|  |         if paths is not None: | ||||||
|  |             plt.title(Path(paths[i]).name, fontdict={'size': 8}) | ||||||
|     fig.tight_layout() |     fig.tight_layout() | ||||||
|     fig.savefig(fname, dpi=300) |     fig.savefig(fname, dpi=300) | ||||||
|     plt.close() |     plt.close() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue