This commit is contained in:
Glenn Jocher 2019-08-17 14:08:10 +02:00
parent 321bd95764
commit a1200ef130
3 changed files with 3 additions and 4 deletions

View File

@ -96,7 +96,7 @@ class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, yolo_index):
super(YOLOLayer, self).__init__()
self.anchors = torch.Tensor(anchors)
self.anchors = torch.from_numpy(anchors)
self.na = len(anchors) # number of anchors (3)
self.nc = nc # number of classes (80)
self.nx = 0 # initialize number of x gridpoints

View File

@ -200,8 +200,7 @@ def train(cfg,
# Start training
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
if dataset.image_weights:
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model_info(model, report='summary') # 'full' or 'summary'
nb = len(dataloader)
maps = np.zeros(nc) # mAP per class

View File

@ -62,7 +62,7 @@ def labels_to_class_weights(labels, nc=80):
weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class
weights /= weights.sum() # normalize
return torch.Tensor(weights)
return torch.from_numpy(weights)
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):