updates
This commit is contained in:
parent
8e327e3bd0
commit
bb209111c4
3
train.py
3
train.py
|
@ -204,8 +204,7 @@ def train():
|
||||||
model.nc = nc # attach number of classes to model
|
model.nc = nc # attach number of classes to model
|
||||||
model.arc = opt.arc # attach yolo architecture
|
model.arc = opt.arc # attach yolo architecture
|
||||||
model.hyp = hyp # attach hyperparameters to model
|
model.hyp = hyp # attach hyperparameters to model
|
||||||
if hasattr(dataset, 'labels'):
|
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
|
||||||
torch_utils.model_info(model, report='summary') # 'full' or 'summary'
|
torch_utils.model_info(model, report='summary') # 'full' or 'summary'
|
||||||
nb = len(dataloader)
|
nb = len(dataloader)
|
||||||
maps = np.zeros(nc) # mAP per class
|
maps = np.zeros(nc) # mAP per class
|
||||||
|
|
|
@ -43,14 +43,16 @@ def load_classes(path):
|
||||||
|
|
||||||
def labels_to_class_weights(labels, nc=80):
|
def labels_to_class_weights(labels, nc=80):
|
||||||
# Get class weights (inverse frequency) from training labels
|
# Get class weights (inverse frequency) from training labels
|
||||||
ni = len(labels) # number of images
|
if labels[0] is None: # no labels loaded
|
||||||
|
return None
|
||||||
|
|
||||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||||
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
|
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
|
||||||
weights = np.bincount(classes, minlength=nc) # occurences per class
|
weights = np.bincount(classes, minlength=nc) # occurences per class
|
||||||
|
|
||||||
# Prepend gridpoint count (for uCE trianing)
|
# Prepend gridpoint count (for uCE trianing)
|
||||||
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
|
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
|
||||||
# weights = np.hstack([gpi * ni - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
|
# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
|
||||||
|
|
||||||
weights[weights == 0] = 1 # replace empty bins with 1
|
weights[weights == 0] = 1 # replace empty bins with 1
|
||||||
weights = 1 / weights # number of targets per class
|
weights = 1 / weights # number of targets per class
|
||||||
|
|
Loading…
Reference in New Issue