updates
This commit is contained in:
parent
c4f9e3891e
commit
6260ac266f
26
models.py
26
models.py
|
@ -77,12 +77,20 @@ def create_modules(module_defs, img_size, arc):
|
||||||
|
|
||||||
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
|
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
|
||||||
try:
|
try:
|
||||||
if arc == 'default':
|
if arc == 'defaultpw': # default with positive weights
|
||||||
b = [-4, -3.6] # obj, cls
|
b = [-4, -3.6] # obj, cls
|
||||||
elif arc == 'uCE': # unified CE (1 background + 80 classes)
|
if arc == 'default': # default no pw (40 cls, 80 obj)
|
||||||
b = [10, -0.1] # obj, cls
|
b = [-5.5, -4.0]
|
||||||
elif arc == 'uBCE': # unified BCE (80 classes)
|
elif arc == 'uBCE': # unified BCE (80 classes)
|
||||||
b = [0, -8.5] # obj, cls
|
b = [0, -8.5]
|
||||||
|
elif arc == 'uCE': # unified CE (1 background + 80 classes)
|
||||||
|
b = [10, -0.1]
|
||||||
|
elif arc == 'Fdefault': # Focal default no pw (28 cls, 21 obj, no pw)
|
||||||
|
b = [-2.1, -1.8]
|
||||||
|
elif arc == 'uFBCE': # unified FocalBCE (5120 obj, 80 classes)
|
||||||
|
b = [0, -3.5]
|
||||||
|
elif arc == 'uFCE': # unified FocalCE (64 cls, 1 background + 80 classes)
|
||||||
|
b = [7, -0.1]
|
||||||
|
|
||||||
bias = module_list[-1][0].bias.view(len(mask), -1) # 255 to 3x85
|
bias = module_list[-1][0].bias.view(len(mask), -1) # 255 to 3x85
|
||||||
bias[:, 4] += b[0] # obj
|
bias[:, 4] += b[0] # obj
|
||||||
|
@ -175,14 +183,14 @@ class YOLOLayer(nn.Module):
|
||||||
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method
|
# io[..., 2:4] = ((torch.sigmoid(io[..., 2:4]) * 2) ** 3) * self.anchor_wh # wh power method
|
||||||
io[..., :4] *= self.stride
|
io[..., :4] *= self.stride
|
||||||
|
|
||||||
if self.arc == 'default':
|
if 'default' in self.arc: # seperate obj and cls
|
||||||
torch.sigmoid_(io[..., 4:])
|
torch.sigmoid_(io[..., 4:])
|
||||||
elif self.arc == 'uCE': # unified CE (1 background + 80 classes)
|
elif 'BCE' in self.arc: # unified BCE (80 classes)
|
||||||
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
|
|
||||||
io[..., 4] = 1
|
|
||||||
elif self.arc == 'uBCE': # unified BCE (80 classes)
|
|
||||||
torch.sigmoid_(io[..., 5:])
|
torch.sigmoid_(io[..., 5:])
|
||||||
io[..., 4] = 1
|
io[..., 4] = 1
|
||||||
|
elif 'CE' in self.arc: # unified CE (1 background + 80 classes)
|
||||||
|
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
|
||||||
|
io[..., 4] = 1
|
||||||
|
|
||||||
if self.nc == 1:
|
if self.nc == 1:
|
||||||
io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235
|
io[..., 5] = 1 # single-class model https://github.com/ultralytics/yolov3/issues/235
|
||||||
|
|
6
train.py
6
train.py
|
@ -44,6 +44,10 @@ def train():
|
||||||
accumulate = opt.accumulate # effective bs = batch_size * accumulate = 16 * 4 = 64
|
accumulate = opt.accumulate # effective bs = batch_size * accumulate = 16 * 4 = 64
|
||||||
weights = opt.weights # initial training weights
|
weights = opt.weights # initial training weights
|
||||||
|
|
||||||
|
if 'pw' not in opt.arc: # remove BCELoss positive weights
|
||||||
|
hyp['cls_pw'] = 0
|
||||||
|
hyp['obj_pw'] = 0
|
||||||
|
|
||||||
# Initialize
|
# Initialize
|
||||||
init_seeds()
|
init_seeds()
|
||||||
wdir = 'weights' + os.sep # weights dir
|
wdir = 'weights' + os.sep # weights dir
|
||||||
|
@ -359,7 +363,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--img-weights', action='store_true', help='select training images by weight')
|
parser.add_argument('--img-weights', action='store_true', help='select training images by weight')
|
||||||
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
|
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
|
||||||
parser.add_argument('--weights', type=str, default='', help='initial weights') # i.e. weights/darknet.53.conv.74
|
parser.add_argument('--weights', type=str, default='', help='initial weights') # i.e. weights/darknet.53.conv.74
|
||||||
parser.add_argument('--arc', type=str, default='default', help='yolo architecture') # default, uCE, uBCE
|
parser.add_argument('--arc', type=str, default='defaultpw', help='yolo architecture') # defaultpw, uCE, uBCE
|
||||||
parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training')
|
parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
opt.weights = 'weights/last.pt' if opt.resume else opt.weights
|
opt.weights = 'weights/last.pt' if opt.resume else opt.weights
|
||||||
|
|
|
@ -322,8 +322,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
# Define criteria
|
# Define criteria
|
||||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
||||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
|
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
|
||||||
FBCE = nn.BCEWithLogitsLoss()
|
BCE = nn.BCEWithLogitsLoss()
|
||||||
FCE = nn.CrossEntropyLoss() # weight=model.class_weights
|
CE = nn.CrossEntropyLoss() # weight=model.class_weights
|
||||||
|
|
||||||
|
if 'F' in arc: # add focal loss
|
||||||
|
BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls), FocalLoss(BCEobj), FocalLoss(BCE), FocalLoss(CE)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
for i, pi in enumerate(p): # layer index, layer predictions
|
for i, pi in enumerate(p): # layer index, layer predictions
|
||||||
|
@ -343,7 +346,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||||
lbox += (1.0 - giou).mean() # giou loss
|
lbox += (1.0 - giou).mean() # giou loss
|
||||||
|
|
||||||
if arc == 'default' and model.nc > 1: # cls loss (only if multiple classes)
|
if 'default' in arc and model.nc > 1: # cls loss (only if multiple classes)
|
||||||
t = torch.zeros_like(ps[:, 5:]) # targets
|
t = torch.zeros_like(ps[:, 5:]) # targets
|
||||||
t[range(nb), tcls[i]] = 1.0
|
t[range(nb), tcls[i]] = 1.0
|
||||||
lcls += BCEcls(ps[:, 5:], t) # BCE
|
lcls += BCEcls(ps[:, 5:], t) # BCE
|
||||||
|
@ -353,20 +356,20 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
# with open('targets.txt', 'a') as file:
|
# with open('targets.txt', 'a') as file:
|
||||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||||
|
|
||||||
if arc == 'default':
|
if 'default' in arc: # seperate obj and cls
|
||||||
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
||||||
|
|
||||||
elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20
|
elif 'BCE' in arc: # unified BCE (80 classes)
|
||||||
t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets
|
|
||||||
if nb:
|
|
||||||
t[b, a, gj, gi] = tcls[i] + 1
|
|
||||||
lcls += FCE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1))
|
|
||||||
|
|
||||||
elif arc == 'uBCE': # unified BCE (1 background + 80 classes), hyps 200-30
|
|
||||||
t = torch.zeros_like(pi[..., 5:]) # targets
|
t = torch.zeros_like(pi[..., 5:]) # targets
|
||||||
if nb:
|
if nb:
|
||||||
t[b, a, gj, gi, tcls[i]] = 1.0
|
t[b, a, gj, gi, tcls[i]] = 1.0
|
||||||
lobj += FBCE(pi[..., 5:], t)
|
lobj += BCE(pi[..., 5:], t)
|
||||||
|
|
||||||
|
elif 'CE' in arc: # unified CE (1 background + 80 classes)
|
||||||
|
t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets
|
||||||
|
if nb:
|
||||||
|
t[b, a, gj, gi] = tcls[i] + 1
|
||||||
|
lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1))
|
||||||
|
|
||||||
lbox *= h['giou']
|
lbox *= h['giou']
|
||||||
lobj *= h['obj']
|
lobj *= h['obj']
|
||||||
|
|
Loading…
Reference in New Issue