This commit is contained in:
Glenn Jocher 2019-07-25 13:19:26 +02:00
parent 169d117870
commit 7b6cba86ef
5 changed files with 10 additions and 17 deletions

View File

@ -1,7 +1,3 @@
import os
import torch.nn.functional as F
from utils.parse_config import *
from utils.utils import *

View File

@ -1,15 +1,15 @@
import argparse
import time
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.distributed as dist
import test # import test.py to get mAP after each epoch
from models import *
from utils.adabound import *
from utils.datasets import *
from utils.utils import *
from utils.adabound import *
mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex

View File

@ -1,4 +1,5 @@
import math
import torch
from torch.optim import Optimizer

View File

@ -8,9 +8,9 @@ from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image, ExifTags
from utils.utils import xyxy2xywh, xywh2xyxy
@ -154,8 +154,7 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False):
with open(path, 'r') as f:
img_files = f.read().splitlines()
self.img_files = [x for x in img_files if os.path.splitext(x)[-1].lower() in img_formats]
self.img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats]
n = len(self.img_files)
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index

View File

@ -1,5 +1,7 @@
import glob
import os
import random
from pathlib import Path
import cv2
import matplotlib
@ -9,7 +11,6 @@ import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from . import torch_utils # , google_utils
@ -543,18 +544,14 @@ def select_best_evolve(path='evolve*.txt'): # from utils.utils import *; select
def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *; kmeans_targets()
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif']
with open(path, 'r') as f:
img_files = f.read().splitlines()
img_files = list(filter(lambda x: len(x) > 0, img_files))
img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats]
# Read shapes
n = len(img_files)
assert n > 0, 'No images found in %s' % path
label_files = [x.replace('images', 'labels').
replace('.jpeg', '.txt').
replace('.jpg', '.txt').
replace('.bmp', '.txt').
replace('.png', '.txt') for x in img_files]
label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in img_files]
s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height)
# Read targets