updates

commit 55077b2770
parent 76c45f4ed9

1 changed file: train.py
@@ -145,6 +145,7 @@ def train(
     # Start training
     t, t0 = time.time(), time.time()
     model.hyp = hyp  # attach hyperparameters to model
+    model.class_weights = labels_to_class_weights(dataset.labels).to(device)  # attach class weights
     model_info(model)
     nb = len(dataloader)
     results = (0, 0, 0, 0, 0)  # P, R, mAP, F1, test_loss
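Note: the added line relies on two other pieces of this change set: labels_to_class_weights (defined in the model_info hunk below) and the per-image label cache dataset.labels (built in the LoadImagesAndLabels hunk below). A minimal sketch of the attach-to-model pattern itself, using a stand-in module and made-up weights:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 3)  # stand-in for the detection model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Plain attributes can be attached to an nn.Module; compute_loss reads them back later
    model.hyp = {'lr0': 1e-3}                                           # hypothetical hyperparameter dict
    model.class_weights = torch.tensor([0.25, 0.25, 0.50]).to(device)   # hypothetical per-class weights

    print(model.class_weights)  # travels with the model object into compute_loss(p, targets, model)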
@@ -174,9 +174,20 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32
         self.batch = bi  # batch index of image
 
+        # Preload images
         # if n < 200:  # preload all images into memory if possible
         #     self.imgs = [cv2.imread(img_files[i]) for i in range(n)]
 
+        # Preload labels (required for weighted CE training)
+        self.labels = [np.array([])] * n
+        iter = tqdm(self.label_files, desc='Reading labels') if n > 5000 else self.label_files
+        for i, file in enumerate(iter):
+            try:
+                with open(file, 'r') as f:
+                    self.labels[i] = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
+            except:  # missing label file
+                pass
+
     def __len__(self):
         return len(self.img_files)
 
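Note: a small sketch of what the preloading loop parses, using an in-memory string in place of a real label file (the values are invented; each row of a YOLO-format label file is "class x_center y_center width height", normalized to 0-1):

    import numpy as np

    label_txt = "0 0.50 0.50 0.20 0.30\n16 0.25 0.40 0.10 0.10"  # two hypothetical objects

    # Same parsing expression as the loop above, minus the file handling
    labels = np.array([x.split() for x in label_txt.splitlines()], dtype=np.float32)
    print(labels.shape)  # (2, 5)
    print(labels[:, 0])  # [ 0. 16.] -> class indices, later consumed by labels_to_class_weights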
@@ -217,10 +228,9 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         # Load labels
         labels = []
         if os.path.isfile(label_path):
-            with open(label_path, 'r') as file:
-                lines = file.read().splitlines()
-
-            x = np.array([x.split() for x in lines], dtype=np.float32)
+            # with open(label_path, 'r') as f:
+            #     x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
+            x = self.labels[index]
             if x.size > 0:
                 # Normalized xywh to pixel xyxy format
                 labels = x.copy()
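Note: images whose label file could not be read keep the np.array([]) placeholder from the preloading step, so the existing x.size guard still skips them; a quick check of that behaviour:

    import numpy as np

    x = np.array([])  # placeholder left for a missing label file
    print(x.size > 0)  # False -> the conversion branch is skipped, labels stays []

    x = np.array([[0, 0.5, 0.5, 0.2, 0.3]], dtype=np.float32)  # one parsed label row
    print(x.size > 0)  # True -> labels = x.copy() runs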
@@ -49,6 +49,15 @@ def model_info(model):
     print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g))
 
 
+def labels_to_class_weights(labels):
+    # Get class weights (inverse frequency) from training labels
+    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
+    classes = labels[:, 0].astype(np.int)
+    weights = 1 / (np.bincount(classes) + 1e-6)  # number of targets per class
+    weights /= weights.sum()
+    return torch.Tensor(weights)
+
+
 def coco_class_weights():  # frequency of each class in coco train2014
     weights = 1 / torch.FloatTensor(
         [187437, 4955, 30920, 6033, 3838, 4332, 3160, 7051, 7677, 9167, 1316, 1372, 833, 6757, 7355, 3302, 3776, 4671,
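Note: a worked toy example of the inverse-frequency weighting defined above, with made-up counts for a three-class dataset:

    import numpy as np
    import torch

    # Hypothetical class indices gathered from all label rows: class 0 common, class 2 rare
    classes = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])

    counts = np.bincount(classes)    # [6 3 1] targets per class
    weights = 1 / (counts + 1e-6)    # ~[0.167 0.333 1.0] -> rarer classes get larger weights
    weights /= weights.sum()         # normalize to sum to 1: ~[0.111 0.222 0.667]
    print(torch.Tensor(weights))     # tensor([0.1111, 0.2222, 0.6667])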
@@ -247,7 +256,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
 
     # Define criteria
     MSE = nn.MSELoss()
-    CE = nn.CrossEntropyLoss()
+    CE = nn.CrossEntropyLoss(weight=model.class_weights)
     BCE = nn.BCEWithLogitsLoss()
 
     # Compute losses
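Note: a minimal sketch of what the weight argument changes in nn.CrossEntropyLoss (shapes and weights below are invented; in the diff the tensor comes from model.class_weights and must live on the same device as the class logits):

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 3)            # 4 predictions over 3 classes
    targets = torch.tensor([0, 0, 1, 2])  # ground-truth class indices

    unweighted = nn.CrossEntropyLoss()
    weighted = nn.CrossEntropyLoss(weight=torch.tensor([0.11, 0.22, 0.67]))  # e.g. inverse-frequency weights

    print(unweighted(logits, targets))  # plain mean over the 4 samples
    print(weighted(logits, targets))    # per-sample losses scaled by their target-class weight, then divided by the total weight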