updates
This commit is contained in:
parent
321bd95764
commit
a1200ef130
|
@ -96,7 +96,7 @@ class YOLOLayer(nn.Module):
|
||||||
def __init__(self, anchors, nc, img_size, yolo_index):
|
def __init__(self, anchors, nc, img_size, yolo_index):
|
||||||
super(YOLOLayer, self).__init__()
|
super(YOLOLayer, self).__init__()
|
||||||
|
|
||||||
self.anchors = torch.Tensor(anchors)
|
self.anchors = torch.from_numpy(anchors)
|
||||||
self.na = len(anchors) # number of anchors (3)
|
self.na = len(anchors) # number of anchors (3)
|
||||||
self.nc = nc # number of classes (80)
|
self.nc = nc # number of classes (80)
|
||||||
self.nx = 0 # initialize number of x gridpoints
|
self.nx = 0 # initialize number of x gridpoints
|
||||||
|
|
3
train.py
3
train.py
|
@ -200,8 +200,7 @@ def train(cfg,
|
||||||
# Start training
|
# Start training
|
||||||
model.nc = nc # attach number of classes to model
|
model.nc = nc # attach number of classes to model
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
if dataset.image_weights:
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
|
||||||
model_info(model, report='summary') # 'full' or 'summary'
|
model_info(model, report='summary') # 'full' or 'summary'
|
||||||
nb = len(dataloader)
|
nb = len(dataloader)
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
|
|
|
@ -62,7 +62,7 @@ def labels_to_class_weights(labels, nc=80):
|
||||||
weights[weights == 0] = 1 # replace empty bins with 1
|
weights[weights == 0] = 1 # replace empty bins with 1
|
||||||
weights = 1 / weights # number of targets per class
|
weights = 1 / weights # number of targets per class
|
||||||
weights /= weights.sum() # normalize
|
weights /= weights.sum() # normalize
|
||||||
return torch.Tensor(weights)
|
return torch.from_numpy(weights)
|
||||||
|
|
||||||
|
|
||||||
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
||||||
|
|
Loading…
Reference in New Issue