updates
This commit is contained in:
parent
b9d87be318
commit
154fae4430
41
models.py
41
models.py
|
@ -254,28 +254,31 @@ class Darknet(nn.Module):
|
||||||
output.append(x)
|
output.append(x)
|
||||||
layer_outputs.append(x)
|
layer_outputs.append(x)
|
||||||
|
|
||||||
self.losses['nT'] /= 3
|
if is_training:
|
||||||
self.losses['TP'] = 0
|
if batch_report:
|
||||||
self.losses['FP'] = 0
|
self.losses['TC'] /= 3 # target category
|
||||||
self.losses['FN'] = 0
|
metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN
|
||||||
if is_training and batch_report:
|
|
||||||
self.losses['TC'] /= 3 # target category
|
|
||||||
metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN
|
|
||||||
|
|
||||||
ui = np.unique(self.losses['TC'])[1:]
|
ui = np.unique(self.losses['TC'])[1:]
|
||||||
for i in ui:
|
for i in ui:
|
||||||
j = self.losses['TC'] == float(i)
|
j = self.losses['TC'] == float(i)
|
||||||
metrics[0, i] = (self.losses['TP'][j] > 0).sum().float() # TP
|
metrics[0, i] = (self.losses['TP'][j] > 0).sum().float() # TP
|
||||||
metrics[1, i] = (self.losses['FP'][j] > 0).sum().float() # FP
|
metrics[1, i] = (self.losses['FP'][j] > 0).sum().float() # FP
|
||||||
metrics[2, i] = (self.losses['FN'][j] == 3).sum().float() # FN
|
metrics[2, i] = (self.losses['FN'][j] == 3).sum().float() # FN
|
||||||
metrics[1] += self.losses['FPe']
|
metrics[1] += self.losses['FPe']
|
||||||
|
|
||||||
self.losses['TP'] = metrics[0].sum()
|
self.losses['TP'] = metrics[0].sum()
|
||||||
self.losses['FP'] = metrics[1].sum()
|
self.losses['FP'] = metrics[1].sum()
|
||||||
self.losses['FN'] = metrics[2].sum()
|
self.losses['FN'] = metrics[2].sum()
|
||||||
self.losses['metrics'] = metrics
|
self.losses['metrics'] = metrics
|
||||||
|
else:
|
||||||
|
self.losses['TP'] = 0
|
||||||
|
self.losses['FP'] = 0
|
||||||
|
self.losses['FN'] = 0
|
||||||
|
|
||||||
|
self.losses['nT'] /= 3
|
||||||
|
self.losses['TC'] = 0
|
||||||
|
|
||||||
self.losses['TC'] = 0
|
|
||||||
return sum(output) if is_training else torch.cat(output, 1)
|
return sum(output) if is_training else torch.cat(output, 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue