updates
This commit is contained in:
		
							parent
							
								
									6cb3c61320
								
							
						
					
					
						commit
						26b115c306
					
				
							
								
								
									
										4
									
								
								test.py
								
								
								
								
							
							
						
						
									
										4
									
								
								test.py
								
								
								
								
							|  | @ -61,6 +61,10 @@ def test( | |||
|         targets = targets.to(device) | ||||
|         imgs = imgs.to(device) | ||||
| 
 | ||||
|         # Plot images with bounding boxes | ||||
|         if batch_i == 0 and not os.path.exists('test_batch0.jpg'): | ||||
|             plot_images(imgs=imgs, targets=targets, fname='test_batch0.jpg') | ||||
| 
 | ||||
|         # Run model | ||||
|         inf_out, train_out = model(imgs)  # inference and training outputs | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										14
									
								
								train.py
								
								
								
								
							
							
						
						
									
										14
									
								
								train.py
								
								
								
								
							|  | @ -104,6 +104,8 @@ def train( | |||
|     model_info(model) | ||||
|     nB = len(dataloader) | ||||
|     n_burnin = min(round(nB / 5 + 1), 1000)  # burn-in batches | ||||
|     os.remove('train_batch0.jpg') if os.path.exists('train_batch0.jpg') else None | ||||
|     os.remove('test_batch0.jpg') if os.path.exists('test_batch0.jpg') else None | ||||
|     for epoch in range(start_epoch, epochs): | ||||
|         model.train() | ||||
|         print(('\n%8s%12s' + '%10s' * 7) % ('Epoch', 'Batch', 'xy', 'wh', 'conf', 'cls', 'total', 'nTargets', 'time')) | ||||
|  | @ -127,16 +129,8 @@ def train( | |||
|                 continue | ||||
| 
 | ||||
|             # Plot images with bounding boxes | ||||
|             plot_images = False | ||||
|             if plot_images: | ||||
|                 fig = plt.figure(figsize=(10, 10)) | ||||
|                 for ip in range(len(imgs)): | ||||
|                     boxes = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy().T * img_size | ||||
|                     plt.subplot(4, 4, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0)) | ||||
|                     plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') | ||||
|                     plt.axis('off') | ||||
|                 fig.tight_layout() | ||||
|                 fig.savefig('batch_%g.jpg' % i, dpi=fig.dpi) | ||||
|             if epoch == 0 and i == 0: | ||||
|                 plot_images(imgs=imgs, targets=targets, fname='train_batch0.jpg') | ||||
| 
 | ||||
|             # SGD burn-in | ||||
|             if epoch == 0 and i <= n_burnin: | ||||
|  |  | |||
|  | @ -487,6 +487,21 @@ def plot_wh_methods():  # from utils.utils import *; plot_wh_methods() | |||
|     fig.savefig('comparison.jpg', dpi=fig.dpi) | ||||
| 
 | ||||
| 
 | ||||
| def plot_images(imgs, targets, fname='images.jpg'): | ||||
|     fig = plt.figure(figsize=(10, 10)) | ||||
|     img_size = imgs.shape[3] | ||||
|     bs = imgs.shape[0]  # batch size | ||||
|     sp = np.ceil(bs ** 0.5)  # subplots | ||||
| 
 | ||||
|     for i in range(bs): | ||||
|         boxes = xywh2xyxy(targets[targets[:, 0] == i, 2:6]).numpy().T * img_size | ||||
|         plt.subplot(sp, sp, i + 1).imshow(imgs[i].numpy().transpose(1, 2, 0)) | ||||
|         plt.plot(boxes[[0, 2, 2, 0, 0]], boxes[[1, 1, 3, 3, 1]], '.-') | ||||
|         plt.axis('off') | ||||
|     fig.tight_layout() | ||||
|     fig.savefig(fname, dpi=fig.dpi) | ||||
| 
 | ||||
| 
 | ||||
| def plot_results(start=0, stop=0):  # from utils.utils import *; plot_results() | ||||
|     # Plot training results files 'results*.txt' | ||||
|     # import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt') | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue