This commit is contained in:
Glenn Jocher 2019-08-23 17:18:59 +02:00
parent d2ef817b1f
commit 5f2b551818
3 changed files with 23 additions and 22 deletions

View File

@ -4,13 +4,11 @@ from utils.parse_config import *
from utils.utils import *
ONNX_EXPORT = False
arc = 'normal' # (normal, uCE, uBCE, uBCEs) detection architectures
def create_modules(module_defs, img_size):
"""
Constructs module list of layer blocks from module configuration in module_defs
"""
def create_modules(module_defs, img_size, arc):
# Constructs module list of layer blocks from module configuration in module_defs
hyperparams = module_defs.pop(0)
output_filters = [int(hyperparams['channels'])]
module_list = nn.ModuleList()
@ -74,11 +72,12 @@ def create_modules(module_defs, img_size):
modules = YOLOLayer(anchors=mdef['anchors'][mask], # anchor list
nc=int(mdef['classes']), # number of classes
img_size=img_size, # (416, 416)
yolo_index=yolo_index) # 0, 1 or 2
yolo_index=yolo_index, # 0, 1 or 2
arc=arc) # yolo architecture
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
try:
if arc == 'normal':
if arc == 'default':
b = [-5.0, -4.0] # obj, cls
elif arc == 'uCE': # unified CE (1 background + 80 classes)
b = [3.0, -4.0] # obj, cls
@ -113,7 +112,7 @@ class Swish(nn.Module):
class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, yolo_index):
def __init__(self, anchors, nc, img_size, yolo_index, arc):
super(YOLOLayer, self).__init__()
self.anchors = torch.Tensor(anchors)
@ -121,6 +120,7 @@ class YOLOLayer(nn.Module):
self.nc = nc # number of classes (80)
self.nx = 0 # initialize number of x gridpoints
self.ny = 0 # initialize number of y gridpoints
self.arc = arc
if ONNX_EXPORT: # grids must be computed in __init__
stride = [32, 16, 8][yolo_index] # stride of this layer
@ -175,12 +175,12 @@ class YOLOLayer(nn.Module):
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method
io[..., :4] *= self.stride
if arc == 'normal':
if self.arc == 'default':
torch.sigmoid_(io[..., 4:])
elif arc == 'uCE': # unified CE (1 background + 80 classes)
elif self.arc == 'uCE': # unified CE (1 background + 80 classes)
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
io[..., 4] = 1
elif arc == 'uBCE': # unified BCE (80 classes)
elif self.arc == 'uBCE': # unified BCE (80 classes)
torch.sigmoid_(io[..., 5:])
io[..., 4] = 1
@ -192,13 +192,13 @@ class YOLOLayer(nn.Module):
class Darknet(nn.Module):
"""YOLOv3 object detection model"""
# YOLOv3 object detection model
def __init__(self, cfg, img_size=(416, 416)):
def __init__(self, cfg, img_size=(416, 416), arc='default'):
super(Darknet, self).__init__()
self.module_defs = parse_model_cfg(cfg)
self.module_list, self.routs = create_modules(self.module_defs, img_size)
self.module_list, self.routs = create_modules(self.module_defs, img_size, arc)
self.yolo_layers = get_yolo_layers(self)
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346

View File

@ -84,7 +84,7 @@ def train():
nc = int(data_dict['classes']) # number of classes
# Initialize model
model = Darknet(cfg).to(device)
model = Darknet(cfg, arc=opt.arc).to(device)
# Optimizer
# optimizer = optim.Adam(model.parameters(), lr=hyp['lr0'], weight_decay=hyp['weight_decay'])
@ -259,7 +259,7 @@ def train():
pred = model(imgs)
# Compute loss
loss, loss_items = compute_loss(pred, targets, model)
loss, loss_items = compute_loss(pred, targets, model, arc=opt.arc)
if torch.isnan(loss):
print('WARNING: nan loss detected, ending training')
return results
@ -367,6 +367,7 @@ if __name__ == '__main__':
parser.add_argument('--img-weights', action='store_true', help='select training images by weight')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
parser.add_argument('--weights', type=str, default='', help='initial weights') # i.e. weights/darknet.53.conv.74
parser.add_argument('--arc', type=str, default='default', help='yolo architecture') # default, uCE, uBCE
opt = parser.parse_args()
opt.weights = 'weights/last.pt' if opt.resume else opt.weights
print(opt)

View File

@ -312,7 +312,7 @@ class FocalLoss(nn.Module):
return loss
def compute_loss(p, targets, model): # predictions, targets, model
def compute_loss(p, targets, model, arc='default'): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
@ -321,12 +321,12 @@ def compute_loss(p, targets, model): # predictions, targets, model
# Define criteria
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
# CE = nn.CrossEntropyLoss(weight=model.class_weights)
BCE = nn.BCEWithLogitsLoss()
CE = nn.CrossEntropyLoss() # weight=model.class_weights
# Compute losses
bs = p[0].shape[0] # batch size
k = bs / 64 # loss gain
arc = 'normal' # (normal, uCE, uBCE, uBCEs) detection architectures
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi[..., 0]) # target obj
@ -344,7 +344,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
lbox += (1.0 - giou).mean() # giou loss
if arc == 'normal' and model.nc > 1: # cls loss (only if multiple classes)
if arc == 'default' and model.nc > 1: # cls loss (only if multiple classes)
t = torch.zeros_like(ps[:, 5:]) # targets
t[range(nb), tcls[i]] = 1.0
lcls += BCEcls(ps[:, 5:], t) # BCE
@ -354,7 +354,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
# with open('targets.txt', 'a') as file:
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
if arc == 'normal':
if arc == 'default': # (default, uCE, uBCE) detection architectures
lobj += BCEobj(pi[..., 4], tobj) # obj loss
elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20
@ -367,7 +367,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
t = torch.zeros_like(pi[..., 5:]) # targets
if nb:
t[b, a, gj, gi, tcls[i]] = 1.0
lobj += BCEobj(pi[..., 5:], t)
lobj += BCE(pi[..., 5:], t)
lbox *= k * h['giou']
lobj *= k * h['obj']