updates
This commit is contained in:
parent
169d117870
commit
7b6cba86ef
|
@ -1,7 +1,3 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
|
|
4
train.py
4
train.py
|
@ -1,15 +1,15 @@
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
import test # import test.py to get mAP after each epoch
|
import test # import test.py to get mAP after each epoch
|
||||||
from models import *
|
from models import *
|
||||||
|
from utils.adabound import *
|
||||||
from utils.datasets import *
|
from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
from utils.adabound import *
|
|
||||||
|
|
||||||
mixed_precision = True
|
mixed_precision = True
|
||||||
try: # Mixed precision training https://github.com/NVIDIA/apex
|
try: # Mixed precision training https://github.com/NVIDIA/apex
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
|
@ -8,9 +8,9 @@ from pathlib import Path
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image, ExifTags
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image, ExifTags
|
|
||||||
|
|
||||||
from utils.utils import xyxy2xywh, xywh2xyxy
|
from utils.utils import xyxy2xywh, xywh2xyxy
|
||||||
|
|
||||||
|
@ -154,8 +154,7 @@ class LoadWebcam: # for inference
|
||||||
class LoadImagesAndLabels(Dataset): # for training/testing
|
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||||
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False):
|
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False):
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
img_files = f.read().splitlines()
|
self.img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats]
|
||||||
self.img_files = [x for x in img_files if os.path.splitext(x)[-1].lower() in img_formats]
|
|
||||||
|
|
||||||
n = len(self.img_files)
|
n = len(self.img_files)
|
||||||
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
|
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import glob
|
import glob
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
@ -9,7 +11,6 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from . import torch_utils # , google_utils
|
from . import torch_utils # , google_utils
|
||||||
|
|
||||||
|
@ -543,18 +544,14 @@ def select_best_evolve(path='evolve*.txt'): # from utils.utils import *; select
|
||||||
|
|
||||||
|
|
||||||
def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *; kmeans_targets()
|
def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *; kmeans_targets()
|
||||||
|
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif']
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
img_files = f.read().splitlines()
|
img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats]
|
||||||
img_files = list(filter(lambda x: len(x) > 0, img_files))
|
|
||||||
|
|
||||||
# Read shapes
|
# Read shapes
|
||||||
n = len(img_files)
|
n = len(img_files)
|
||||||
assert n > 0, 'No images found in %s' % path
|
assert n > 0, 'No images found in %s' % path
|
||||||
label_files = [x.replace('images', 'labels').
|
label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in img_files]
|
||||||
replace('.jpeg', '.txt').
|
|
||||||
replace('.jpg', '.txt').
|
|
||||||
replace('.bmp', '.txt').
|
|
||||||
replace('.png', '.txt') for x in img_files]
|
|
||||||
s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height)
|
s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height)
|
||||||
|
|
||||||
# Read targets
|
# Read targets
|
||||||
|
|
Loading…
Reference in New Issue