This commit is contained in:
Glenn Jocher 2019-04-11 12:41:07 +02:00
parent 835f975228
commit cbd5347cc3
3 changed files with 5 additions and 5 deletions

View File

@ -178,6 +178,7 @@ class Darknet(nn.Module):
self.module_defs[0]['cfg'] = cfg_path
self.module_defs[0]['height'] = img_size
self.hyperparams, self.module_list = create_modules(self.module_defs)
self.yolo_layers = get_yolo_layers(self)
def forward(self, x, var=None):
img_size = x.shape[-1]

View File

@ -48,8 +48,7 @@ def train(
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
best_loss = float('inf')
yl = get_yolo_layers(model) # yolo layers
nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255)
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
if resume: # Load previously saved model
if transfer: # Transfer learning

View File

@ -288,8 +288,8 @@ def build_targets(model, targets):
model = model.module
txy, twh, tcls, indices = [], [], [], []
for i, layer in enumerate(get_yolo_layers(model)):
layer = model.module_list[layer][0]
for i in model.yolo_layers:
layer = model.module_list[i][0]
# iou of targets-anchors
gwh = targets[:, 4:6] * layer.nG
@ -523,7 +523,7 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
x = range(start, min(stop, n) if stop else n)
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.plot(x, results[i, x].clip(max=500), marker='.', label=f)
plt.plot(x, results[i, x].clip(max=500), marker='.', label=f.replace('.txt',''))
plt.title(s[i])
if i == 0:
plt.legend()