updates
This commit is contained in:
parent
a024286ec1
commit
d047062074
2
train.py
2
train.py
|
@ -51,7 +51,7 @@ def train(
|
|||
start_epoch = 0
|
||||
best_loss = float('inf')
|
||||
if resume: # Load previously saved PyTorch model
|
||||
checkpoint = torch.load(latest, map_location=device) # load checkpoin
|
||||
checkpoint = torch.load(latest, map_location=device) # load checkpoint
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
start_epoch = checkpoint['epoch'] + 1
|
||||
if checkpoint['optimizer'] is not None:
|
||||
|
|
|
@ -108,7 +108,7 @@ class LoadImagesAndLabels: # for training
|
|||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
|
||||
#self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
|
||||
return self
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
@ -133,8 +133,8 @@ class LoadImagesAndLabels: # for training
|
|||
def load_images(self, ia, ib):
|
||||
img_all, labels_all, img_paths, img_shapes = [], [], [], []
|
||||
for index, files_index in enumerate(range(ia, ib)):
|
||||
img_path = self.img_files[self.shuffled_vector[files_index]]
|
||||
label_path = self.label_files[self.shuffled_vector[files_index]]
|
||||
img_path = self.img_files[files_index]
|
||||
label_path = self.label_files[files_index]
|
||||
|
||||
img = cv2.imread(img_path) # BGR
|
||||
assert img is not None, 'File Not Found ' + img_path
|
||||
|
|
Loading…
Reference in New Issue