updates
This commit is contained in:
parent
835f975228
commit
cbd5347cc3
|
@ -178,6 +178,7 @@ class Darknet(nn.Module):
|
||||||
self.module_defs[0]['cfg'] = cfg_path
|
self.module_defs[0]['cfg'] = cfg_path
|
||||||
self.module_defs[0]['height'] = img_size
|
self.module_defs[0]['height'] = img_size
|
||||||
self.hyperparams, self.module_list = create_modules(self.module_defs)
|
self.hyperparams, self.module_list = create_modules(self.module_defs)
|
||||||
|
self.yolo_layers = get_yolo_layers(self)
|
||||||
|
|
||||||
def forward(self, x, var=None):
|
def forward(self, x, var=None):
|
||||||
img_size = x.shape[-1]
|
img_size = x.shape[-1]
|
||||||
|
|
3
train.py
3
train.py
|
@ -48,8 +48,7 @@ def train(
|
||||||
cutoff = -1 # backbone reaches to cutoff layer
|
cutoff = -1 # backbone reaches to cutoff layer
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
yl = get_yolo_layers(model) # yolo layers
|
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||||
nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
|
||||||
|
|
||||||
if resume: # Load previously saved model
|
if resume: # Load previously saved model
|
||||||
if transfer: # Transfer learning
|
if transfer: # Transfer learning
|
||||||
|
|
|
@ -288,8 +288,8 @@ def build_targets(model, targets):
|
||||||
model = model.module
|
model = model.module
|
||||||
|
|
||||||
txy, twh, tcls, indices = [], [], [], []
|
txy, twh, tcls, indices = [], [], [], []
|
||||||
for i, layer in enumerate(get_yolo_layers(model)):
|
for i in model.yolo_layers:
|
||||||
layer = model.module_list[layer][0]
|
layer = model.module_list[i][0]
|
||||||
|
|
||||||
# iou of targets-anchors
|
# iou of targets-anchors
|
||||||
gwh = targets[:, 4:6] * layer.nG
|
gwh = targets[:, 4:6] * layer.nG
|
||||||
|
@ -523,7 +523,7 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
|
||||||
x = range(start, min(stop, n) if stop else n)
|
x = range(start, min(stop, n) if stop else n)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
plt.subplot(2, 5, i + 1)
|
plt.subplot(2, 5, i + 1)
|
||||||
plt.plot(x, results[i, x].clip(max=500), marker='.', label=f)
|
plt.plot(x, results[i, x].clip(max=500), marker='.', label=f.replace('.txt',''))
|
||||||
plt.title(s[i])
|
plt.title(s[i])
|
||||||
if i == 0:
|
if i == 0:
|
||||||
plt.legend()
|
plt.legend()
|
||||||
|
|
Loading…
Reference in New Issue