GIoU to default
This commit is contained in:
		
							parent
							
								
									7246dd855c
								
							
						
					
					
						commit
						b649a95c9a
					
				
							
								
								
									
										3
									
								
								test.py
								
								
								
								
							
							
						
						
									
										3
									
								
								test.py
								
								
								
								
							|  | @ -71,8 +71,7 @@ def test( | |||
| 
 | ||||
|         # Compute loss | ||||
|         if hasattr(model, 'hyp'):  # if model has loss hyperparameters | ||||
|             loss_i, _ = compute_loss(train_out, targets, model) | ||||
|             loss += loss_i.item() | ||||
|             loss += compute_loss(train_out, targets, model)[0].item() | ||||
| 
 | ||||
|         # Run NMS | ||||
|         output = non_max_suppression(inf_out, conf_thres=conf_thres, nms_thres=nms_thres) | ||||
|  |  | |||
							
								
								
									
										4
									
								
								train.py
								
								
								
								
							
							
						
						
									
										4
									
								
								train.py
								
								
								
								
							|  | @ -218,7 +218,7 @@ def train( | |||
|             pred = model(imgs) | ||||
| 
 | ||||
|             # Compute loss | ||||
|             loss, loss_items = compute_loss(pred, targets, model, giou_loss=opt.giou) | ||||
|             loss, loss_items = compute_loss(pred, targets, model, giou_loss=not opt.xywh) | ||||
|             if torch.isnan(loss): | ||||
|                 print('WARNING: nan loss detected, ending training') | ||||
|                 return results | ||||
|  | @ -320,7 +320,7 @@ if __name__ == '__main__': | |||
|     parser.add_argument('--num-workers', type=int, default=4, help='number of Pytorch DataLoader workers') | ||||
|     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') | ||||
|     parser.add_argument('--notest', action='store_true', help='only test final epoch') | ||||
|     parser.add_argument('--giou', action='store_true', help='use GIoU loss instead of xy, wh loss') | ||||
|     parser.add_argument('--xywh', action='store_true', help='use xywh loss instead of GIoU loss') | ||||
|     parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') | ||||
|     parser.add_argument('--cloud-evolve', action='store_true', help='evolve hyperparameters from a cloud source') | ||||
|     parser.add_argument('--var', default=0, type=int, help='debug variable') | ||||
|  |  | |||
|  | @ -271,7 +271,7 @@ def wh_iou(box1, box2): | |||
|     return inter_area / union_area  # iou | ||||
| 
 | ||||
| 
 | ||||
| def compute_loss(p, targets, model, giou_loss=False):  # predictions, targets, model | ||||
| def compute_loss(p, targets, model, giou_loss=True):  # predictions, targets, model | ||||
|     ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor | ||||
|     lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0]) | ||||
|     txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets) | ||||
|  | @ -336,17 +336,17 @@ def build_targets(model, targets): | |||
|         if nt: | ||||
|             iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0) | ||||
| 
 | ||||
|             use_best = True | ||||
|             if use_best: | ||||
|             use_best_anchor = False | ||||
|             if use_best_anchor: | ||||
|                 iou, a = iou.max(0)  # best iou and anchor | ||||
|             else: | ||||
|             else:  # use all anchors | ||||
|                 na = len(layer.anchor_vec)  # number of anchors | ||||
|                 a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1) | ||||
|                 t = targets.repeat([na, 1]) | ||||
|                 gwh = gwh.repeat([na, 1]) | ||||
|                 iou = iou.view(-1)  # use all ious | ||||
| 
 | ||||
|             # reject below threshold ious (OPTIONAL, increases P, lowers R) | ||||
|             # reject anchors below iou_thres (OPTIONAL, increases P, lowers R) | ||||
|             reject = True | ||||
|             if reject: | ||||
|                 j = iou > iou_thres | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue