This commit is contained in:
Glenn Jocher 2019-03-22 15:08:03 +02:00
parent 75d8cbdd5f
commit b31f8fb017
2 changed files with 7 additions and 8 deletions

View File

@ -64,6 +64,7 @@ def train(
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15') cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
print('WARNING: MultiGPU Issue: https://github.com/ultralytics/yolov3/issues/146')
model = nn.DataParallel(model) model = nn.DataParallel(model)
# Transfer learning (train only YOLO layers) # Transfer learning (train only YOLO layers)
@ -88,10 +89,7 @@ def train(
# scheduler.step() # scheduler.step()
# Update scheduler (manual) # Update scheduler (manual)
if epoch > 250: lr = lr0 / 10 if epoch > 250 else lr0
lr = lr0 / 10
else:
lr = lr0
for x in optimizer.param_groups: for x in optimizer.param_groups:
x['lr'] = lr x['lr'] = lr
@ -119,7 +117,7 @@ def train(
plt.figure(figsize=(10, 10)) plt.figure(figsize=(10, 10))
for ip in range(batch_size): for ip in range(batch_size):
labels = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy() * img_size labels = xywh2xyxy(targets[targets[:, 0] == ip, 2:6]).numpy() * img_size
plt.subplot(3, 3, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0)) plt.subplot(4, 4, ip + 1).imshow(imgs[ip].numpy().transpose(1, 2, 0))
plt.plot(labels[:, [0, 2, 2, 0, 0]].T, labels[:, [1, 1, 3, 3, 1]].T, '.-') plt.plot(labels[:, [0, 2, 2, 0, 0]].T, labels[:, [1, 1, 3, 3, 1]].T, '.-')
plt.axis('off') plt.axis('off')

View File

@ -9,14 +9,15 @@ sudo shutdown
# Start # Start
sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3 sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
cp -r weights yolov3 cp -r weights yolov3
cd yolov3 && python3 train.py --batch-size 16 --num-workers 4 cd yolov3 && python3 train.py --batch-size 16 --epochs 1
sudo shutdown
# Resume # Resume
python3 train.py --resume python3 train.py --resume
# Detect # Detect
gsutil cp gs://ultralytics/yolov3.pt yolov3/weights sudo rm -rf yolov3 && git clone https://github.com/ultralytics/yolov3
python3 detect.py cd yolov3 && python3 detect.py
# Clone branch # Clone branch
sudo rm -rf yolov3 && git clone -b multi_gpu --depth 1 https://github.com/ultralytics/yolov3 sudo rm -rf yolov3 && git clone -b multi_gpu --depth 1 https://github.com/ultralytics/yolov3