updates
This commit is contained in:
parent
835f975228
commit
cbd5347cc3
|
@ -178,6 +178,7 @@ class Darknet(nn.Module):
|
|||
self.module_defs[0]['cfg'] = cfg_path
|
||||
self.module_defs[0]['height'] = img_size
|
||||
self.hyperparams, self.module_list = create_modules(self.module_defs)
|
||||
self.yolo_layers = get_yolo_layers(self)
|
||||
|
||||
def forward(self, x, var=None):
|
||||
img_size = x.shape[-1]
|
||||
|
|
3
train.py
3
train.py
|
@ -48,8 +48,7 @@ def train(
|
|||
cutoff = -1 # backbone reaches to cutoff layer
|
||||
start_epoch = 0
|
||||
best_loss = float('inf')
|
||||
yl = get_yolo_layers(model) # yolo layers
|
||||
nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||
|
||||
if resume: # Load previously saved model
|
||||
if transfer: # Transfer learning
|
||||
|
|
|
@ -288,8 +288,8 @@ def build_targets(model, targets):
|
|||
model = model.module
|
||||
|
||||
txy, twh, tcls, indices = [], [], [], []
|
||||
for i, layer in enumerate(get_yolo_layers(model)):
|
||||
layer = model.module_list[layer][0]
|
||||
for i in model.yolo_layers:
|
||||
layer = model.module_list[i][0]
|
||||
|
||||
# iou of targets-anchors
|
||||
gwh = targets[:, 4:6] * layer.nG
|
||||
|
@ -523,7 +523,7 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
|
|||
x = range(start, min(stop, n) if stop else n)
|
||||
for i in range(10):
|
||||
plt.subplot(2, 5, i + 1)
|
||||
plt.plot(x, results[i, x].clip(max=500), marker='.', label=f)
|
||||
plt.plot(x, results[i, x].clip(max=500), marker='.', label=f.replace('.txt',''))
|
||||
plt.title(s[i])
|
||||
if i == 0:
|
||||
plt.legend()
|
||||
|
|
Loading…
Reference in New Issue