# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Overlap and surface-distance metrics comparing a test array against a reference.

Every metric accepts either the raw ``test``/``reference`` arrays or a precomputed
:class:`ConfusionMatrix` via ``confusion_matrix``, so several metrics can share one
confusion-matrix computation. Any non-zero element counts as foreground.
"""

import numpy as np

try:
    from medpy import metric
except ImportError:
    # medpy is only used by the surface-distance metrics; the overlap metrics below
    # only need numpy and remain usable without it.
    metric = None


def assert_shape(test, reference):
    """Assert that ``test`` and ``reference`` have identical shapes.

    Raises:
        AssertionError: if the shapes differ (note: ``assert`` is skipped under ``python -O``).
    """
    assert test.shape == reference.shape, "Shape mismatch: {} and {}".format(
        test.shape, reference.shape)


class ConfusionMatrix:
    """Lazily computed binary confusion matrix of ``test`` against ``reference``.

    Counts and existence flags are computed on first access through ``get_matrix``,
    ``get_size`` or ``get_existence`` and cached until ``set_test``/``set_reference``
    replaces an input.
    """

    def __init__(self, test=None, reference=None):
        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.reference_empty = None
        self.reference_full = None
        self.test_empty = None
        self.test_full = None
        self.set_reference(reference)
        self.set_test(test)

    def set_test(self, test):
        """Replace the test array and invalidate all cached results."""
        self.test = test
        self.reset()

    def set_reference(self, reference):
        """Replace the reference array and invalidate all cached results."""
        self.reference = reference
        self.reset()

    def reset(self):
        """Clear cached counts, size and existence flags."""
        self.tp = None
        self.fp = None
        self.tn = None
        self.fn = None
        self.size = None
        self.test_empty = None
        self.test_full = None
        self.reference_empty = None
        self.reference_full = None

    def compute(self):
        """Compute TP/FP/TN/FN counts, the element count and the existence flags.

        Raises:
            ValueError: if ``test`` or ``reference`` is not set.
            AssertionError: if their shapes differ.
        """
        if self.test is None or self.reference is None:
            raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.")

        assert_shape(self.test, self.reference)

        # Build the two foreground masks once and reuse them for all four counts.
        test_fg = self.test != 0
        reference_fg = self.reference != 0
        self.tp = int((test_fg & reference_fg).sum())
        self.fp = int((test_fg & ~reference_fg).sum())
        self.tn = int((~test_fg & ~reference_fg).sum())
        self.fn = int((~test_fg & reference_fg).sum())
        self.size = int(np.prod(self.reference.shape, dtype=np.int64))
        self.test_empty = not np.any(self.test)
        self.test_full = np.all(self.test)
        self.reference_empty = not np.any(self.reference)
        self.reference_full = np.all(self.reference)

    def get_matrix(self):
        """Return ``(tp, fp, tn, fn)``, computing them first if not cached."""
        if any(entry is None for entry in (self.tp, self.fp, self.tn, self.fn)):
            self.compute()
        return self.tp, self.fp, self.tn, self.fn

    def get_size(self):
        """Return the total number of elements, computing it first if not cached."""
        if self.size is None:
            self.compute()
        return self.size

    def get_existence(self):
        """Return ``(test_empty, test_full, reference_empty, reference_full)``."""
        flags = (self.test_empty, self.test_full, self.reference_empty, self.reference_full)
        if any(flag is None for flag in flags):
            self.compute()
        return self.test_empty, self.test_full, self.reference_empty, self.reference_full


def _get_confusion_matrix(test, reference, confusion_matrix):
    """Return ``confusion_matrix``, or build one from ``test``/``reference`` if it is None."""
    if confusion_matrix is None:
        confusion_matrix = ConfusionMatrix(test, reference)
    return confusion_matrix


def _nonexisting_value(nan_for_nonexisting):
    """Value reported for an undefined overlap metric: NaN, or 0. if NaNs are unwanted."""
    return float("NaN") if nan_for_nonexisting else 0.


def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """2TP / (2TP + FP + FN)

    Undefined when both test and reference are empty.
    """
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    tp, fp, _, fn = confusion_matrix.get_matrix()
    test_empty, _, reference_empty, _ = confusion_matrix.get_existence()

    if test_empty and reference_empty:
        return _nonexisting_value(nan_for_nonexisting)

    return float(2. * tp / (2 * tp + fp + fn))


def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP + FN)

    Undefined when both test and reference are empty.
    """
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    tp, fp, _, fn = confusion_matrix.get_matrix()
    test_empty, _, reference_empty, _ = confusion_matrix.get_existence()

    if test_empty and reference_empty:
        return _nonexisting_value(nan_for_nonexisting)

    return float(tp / (tp + fp + fn))


def precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FP)

    Undefined when test is empty.
    """
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    tp, fp, _, _ = confusion_matrix.get_matrix()
    test_empty, _, _, _ = confusion_matrix.get_existence()

    if test_empty:
        return _nonexisting_value(nan_for_nonexisting)

    return float(tp / (tp + fp))


def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN)

    Undefined when reference is empty.
    """
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    tp, _, _, fn = confusion_matrix.get_matrix()
    _, _, reference_empty, _ = confusion_matrix.get_existence()

    if reference_empty:
        return _nonexisting_value(nan_for_nonexisting)

    return float(tp / (tp + fn))


def recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TP / (TP + FN)  (alias of :func:`sensitivity`)"""
    return sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs)


def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)

    Undefined when reference is full (no background).
    """
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    _, fp, tn, _ = confusion_matrix.get_matrix()
    _, _, _, reference_full = confusion_matrix.get_existence()

    if reference_full:
        return _nonexisting_value(nan_for_nonexisting)

    return float(tn / (tn + fp))


def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
    """(TP + TN) / (TP + FP + FN + TN)"""
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    tp, fp, tn, fn = confusion_matrix.get_matrix()

    return float((tp + tn) / (tp + fp + tn + fn))


def fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs):
    """(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)

    Computed from precision and recall. Returns 0. when both are 0 (no overlap
    between test and reference) instead of dividing by zero, and propagates NaN
    when either is undefined.
    """
    # Build the confusion matrix once so precision and recall share it.
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    precision_ = precision(test, reference, confusion_matrix, nan_for_nonexisting)
    recall_ = recall(test, reference, confusion_matrix, nan_for_nonexisting)

    beta_sq = beta * beta
    denominator = (beta_sq * precision_) + recall_
    if denominator == 0:
        # precision == recall == 0, i.e. TP == 0: the F-score is 0, not a division error.
        return 0.
    return (1 + beta_sq) * precision_ * recall_ / denominator


# NOTE(review): the complement metrics below compute ``1 - base_metric``. With
# nan_for_nonexisting=False an undefined base metric yields 0., so the complement
# reports 1.0 rather than 0. Kept as-is for compatibility with existing results.

def false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (FP + TN)"""
    return 1 - specificity(test, reference, confusion_matrix, nan_for_nonexisting)


def false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TN + FN)

    Undefined when test is full (no predicted background).
    """
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    _, _, tn, fn = confusion_matrix.get_matrix()
    _, test_full, _, _ = confusion_matrix.get_existence()

    if test_full:
        return _nonexisting_value(nan_for_nonexisting)

    return float(fn / (fn + tn))


def false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FN / (TP + FN)"""
    return 1 - sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)


def true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FP)  (alias of :func:`specificity`)"""
    return specificity(test, reference, confusion_matrix, nan_for_nonexisting)


def false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """FP / (TP + FP)"""
    return 1 - precision(test, reference, confusion_matrix, nan_for_nonexisting)


def negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
    """TN / (TN + FN)"""
    return 1 - false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)


def total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FP"""
    tp, fp, _, _ = _get_confusion_matrix(test, reference, confusion_matrix).get_matrix()
    return tp + fp


def total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FN"""
    _, _, tn, fn = _get_confusion_matrix(test, reference, confusion_matrix).get_matrix()
    return tn + fn


def total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TP + FN"""
    tp, _, _, fn = _get_confusion_matrix(test, reference, confusion_matrix).get_matrix()
    return tp + fn


def total_negatives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
    """TN + FP"""
    _, fp, tn, _ = _get_confusion_matrix(test, reference, confusion_matrix).get_matrix()
    return tn + fp


def _surface_distance(distance_name, test, reference, confusion_matrix, nan_for_nonexisting,
                      voxel_spacing, connectivity):
    """Shared implementation of the medpy-backed surface-distance metrics.

    Returns NaN (or 0 if ``nan_for_nonexisting`` is False) when either array is
    empty or full; otherwise calls ``medpy.metric.<distance_name>``.

    Raises:
        ImportError: if medpy is not installed and a distance must actually be computed.
    """
    confusion_matrix = _get_confusion_matrix(test, reference, confusion_matrix)
    test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()

    if test_empty or test_full or reference_empty or reference_full:
        return float("NaN") if nan_for_nonexisting else 0

    if metric is None:
        raise ImportError("medpy is required to compute surface distance metrics")

    distance_fn = getattr(metric, distance_name)
    return distance_fn(confusion_matrix.test, confusion_matrix.reference, voxel_spacing, connectivity)


def hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """Hausdorff distance (``medpy.metric.hd``)."""
    return _surface_distance("hd", test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """95th-percentile Hausdorff distance (``medpy.metric.hd95``)."""
    return _surface_distance("hd95", test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


def avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """Average surface distance (``medpy.metric.asd``)."""
    return _surface_distance("asd", test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
    """Average symmetric surface distance (``medpy.metric.assd``)."""
    return _surface_distance("assd", test, reference, confusion_matrix, nan_for_nonexisting,
                             voxel_spacing, connectivity)


# Registry of metric display names to metric functions. Keys are used verbatim by
# callers; note the historical lower-case "total Negatives Reference" key.
ALL_METRICS = {
    "False Positive Rate": false_positive_rate,
    "Dice": dice,
    "Jaccard": jaccard,
    "Hausdorff Distance": hausdorff_distance,
    "Hausdorff Distance 95": hausdorff_distance_95,
    "Precision": precision,
    "Recall": recall,
    "Avg. Symmetric Surface Distance": avg_surface_distance_symmetric,
    "Avg. Surface Distance": avg_surface_distance,
    "Accuracy": accuracy,
    "False Omission Rate": false_omission_rate,
    "Negative Predictive Value": negative_predictive_value,
    "False Negative Rate": false_negative_rate,
    "True Negative Rate": true_negative_rate,
    "False Discovery Rate": false_discovery_rate,
    "Total Positives Test": total_positives_test,
    "Total Negatives Test": total_negatives_test,
    "Total Positives Reference": total_positives_reference,
    "total Negatives Reference": total_negatives_reference,
}

