Remove deprecated --arc architecture options, implement --arc default for all cases
This commit is contained in:
parent
77c6c01970
commit
448c4a6e1f
|
@ -7,7 +7,7 @@ from utils.utils import *
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = False
|
||||||
|
|
||||||
|
|
||||||
def create_modules(module_defs, img_size, arc):
|
def create_modules(module_defs, img_size):
|
||||||
# Constructs module list of layer blocks from module configuration in module_defs
|
# Constructs module list of layer blocks from module configuration in module_defs
|
||||||
|
|
||||||
hyperparams = module_defs.pop(0)
|
hyperparams = module_defs.pop(0)
|
||||||
|
@ -250,11 +250,11 @@ class YOLOLayer(nn.Module):
|
||||||
class Darknet(nn.Module):
|
class Darknet(nn.Module):
|
||||||
# YOLOv3 object detection model
|
# YOLOv3 object detection model
|
||||||
|
|
||||||
def __init__(self, cfg, img_size=(416, 416), arc='default'):
|
def __init__(self, cfg, img_size=(416, 416)):
|
||||||
super(Darknet, self).__init__()
|
super(Darknet, self).__init__()
|
||||||
|
|
||||||
self.module_defs = parse_model_cfg(cfg)
|
self.module_defs = parse_model_cfg(cfg)
|
||||||
self.module_list, self.routs = create_modules(self.module_defs, img_size, arc)
|
self.module_list, self.routs = create_modules(self.module_defs, img_size)
|
||||||
self.yolo_layers = get_yolo_layers(self)
|
self.yolo_layers = get_yolo_layers(self)
|
||||||
|
|
||||||
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
||||||
|
|
6
train.py
6
train.py
|
@ -32,7 +32,7 @@ hyp = {'giou': 3.54, # giou loss gain
|
||||||
'lrf': -4., # final LambdaLR learning rate = lr0 * (10 ** lrf)
|
'lrf': -4., # final LambdaLR learning rate = lr0 * (10 ** lrf)
|
||||||
'momentum': 0.937, # SGD momentum
|
'momentum': 0.937, # SGD momentum
|
||||||
'weight_decay': 0.000484, # optimizer weight decay
|
'weight_decay': 0.000484, # optimizer weight decay
|
||||||
'fl_gamma': 1.5, # focal loss gamma
|
'fl_gamma': 0.0, # focal loss gamma (efficientDet default is gamma=1.5)
|
||||||
'hsv_h': 0.0138, # image HSV-Hue augmentation (fraction)
|
'hsv_h': 0.0138, # image HSV-Hue augmentation (fraction)
|
||||||
'hsv_s': 0.678, # image HSV-Saturation augmentation (fraction)
|
'hsv_s': 0.678, # image HSV-Saturation augmentation (fraction)
|
||||||
'hsv_v': 0.36, # image HSV-Value augmentation (fraction)
|
'hsv_v': 0.36, # image HSV-Value augmentation (fraction)
|
||||||
|
@ -77,7 +77,7 @@ def train():
|
||||||
os.remove(f)
|
os.remove(f)
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = Darknet(cfg, arc=opt.arc).to(device)
|
model = Darknet(cfg).to(device)
|
||||||
|
|
||||||
# Optimizer
|
# Optimizer
|
||||||
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
|
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
|
||||||
|
@ -192,7 +192,6 @@ def train():
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
model.nc = nc # attach number of classes to model
|
model.nc = nc # attach number of classes to model
|
||||||
model.arc = opt.arc # attach yolo architecture
|
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
model.gr = 0.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
model.gr = 0.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
|
@ -406,7 +405,6 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
|
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
|
||||||
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
|
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
|
||||||
parser.add_argument('--weights', type=str, default='weights/yolov3-spp-ultralytics.pt', help='initial weights path')
|
parser.add_argument('--weights', type=str, default='weights/yolov3-spp-ultralytics.pt', help='initial weights path')
|
||||||
parser.add_argument('--arc', type=str, default='default', help='yolo architecture') # default, uCE, uBCE
|
|
||||||
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
|
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
|
||||||
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
|
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)')
|
||||||
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
|
parser.add_argument('--adam', action='store_true', help='use adam optimizer')
|
||||||
|
|
|
@ -377,7 +377,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||||
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||||
h = model.hyp # hyperparameters
|
h = model.hyp # hyperparameters
|
||||||
arc = model.arc # architecture
|
|
||||||
red = 'mean' # Loss reduction (sum or mean)
|
red = 'mean' # Loss reduction (sum or mean)
|
||||||
|
|
||||||
# Define criteria
|
# Define criteria
|
||||||
|
@ -388,8 +387,9 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
cp, cn = smooth_BCE(eps=0.0)
|
cp, cn = smooth_BCE(eps=0.0)
|
||||||
|
|
||||||
# focal loss
|
# focal loss
|
||||||
if 'F' in arc:
|
g = h['fl_gamma'] # focal loss gamma
|
||||||
BCEcls, BCEobj = FocalLoss(BCEcls, h['fl_gamma']), FocalLoss(BCEobj, h['fl_gamma'])
|
if g > 0:
|
||||||
|
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
np, ng = 0, 0 # number grid points, targets
|
np, ng = 0, 0 # number grid points, targets
|
||||||
|
|
Loading…
Reference in New Issue