This commit is contained in:
Glenn Jocher 2018-11-22 15:04:02 +01:00
parent b9d87be318
commit 154fae4430
1 changed files with 22 additions and 19 deletions

View File

@ -254,28 +254,31 @@ class Darknet(nn.Module):
output.append(x) output.append(x)
layer_outputs.append(x) layer_outputs.append(x)
self.losses['nT'] /= 3 if is_training:
self.losses['TP'] = 0 if batch_report:
self.losses['FP'] = 0 self.losses['TC'] /= 3 # target category
self.losses['FN'] = 0 metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN
if is_training and batch_report:
self.losses['TC'] /= 3 # target category
metrics = torch.zeros(3, len(self.losses['FPe'])) # TP, FP, FN
ui = np.unique(self.losses['TC'])[1:] ui = np.unique(self.losses['TC'])[1:]
for i in ui: for i in ui:
j = self.losses['TC'] == float(i) j = self.losses['TC'] == float(i)
metrics[0, i] = (self.losses['TP'][j] > 0).sum().float() # TP metrics[0, i] = (self.losses['TP'][j] > 0).sum().float() # TP
metrics[1, i] = (self.losses['FP'][j] > 0).sum().float() # FP metrics[1, i] = (self.losses['FP'][j] > 0).sum().float() # FP
metrics[2, i] = (self.losses['FN'][j] == 3).sum().float() # FN metrics[2, i] = (self.losses['FN'][j] == 3).sum().float() # FN
metrics[1] += self.losses['FPe'] metrics[1] += self.losses['FPe']
self.losses['TP'] = metrics[0].sum() self.losses['TP'] = metrics[0].sum()
self.losses['FP'] = metrics[1].sum() self.losses['FP'] = metrics[1].sum()
self.losses['FN'] = metrics[2].sum() self.losses['FN'] = metrics[2].sum()
self.losses['metrics'] = metrics self.losses['metrics'] = metrics
else:
self.losses['TP'] = 0
self.losses['FP'] = 0
self.losses['FN'] = 0
self.losses['nT'] /= 3
self.losses['TC'] = 0
self.losses['TC'] = 0
return sum(output) if is_training else torch.cat(output, 1) return sum(output) if is_training else torch.cat(output, 1)