Fixed display stats
This commit is contained in:
@@ -55,7 +55,7 @@ def heart() -> tuple[Dataset, MLAlgorithm, Any]:
|
|||||||
attributes_to_modify = ["Disease", "Sex", "ChestPainType"]
|
attributes_to_modify = ["Disease", "Sex", "ChestPainType"]
|
||||||
ds.factorize(attributes_to_modify)
|
ds.factorize(attributes_to_modify)
|
||||||
ds.normalize(excepts=attributes_to_modify)
|
ds.normalize(excepts=attributes_to_modify)
|
||||||
return (ds, LogisticRegression(ds, learning_rate=0.001), sklearn.linear_model.LogisticRegression())
|
return (ds, LogisticRegression(ds, learning_rate=0.01), sklearn.linear_model.LogisticRegression())
|
||||||
|
|
||||||
# ********************
|
# ********************
|
||||||
# MultiLayerPerceptron
|
# MultiLayerPerceptron
|
||||||
|
|||||||
@@ -47,7 +47,9 @@ class Dataset:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def normalize(self, excepts:list[str]=[]) -> Self:
|
def normalize(self, excepts:list[str]=[]) -> Self:
|
||||||
excepts.append(self.target)
|
if excepts is None: excepts = []
|
||||||
|
else: excepts.append(self.target)
|
||||||
|
|
||||||
for col in self.data:
|
for col in self.data:
|
||||||
if col not in excepts:
|
if col not in excepts:
|
||||||
index = self.data.columns.get_loc(col)
|
index = self.data.columns.get_loc(col)
|
||||||
@@ -134,6 +136,32 @@ class ConfusionMatrix:
|
|||||||
tn = total - (tp + fp + fn)
|
tn = total - (tp + fp + fn)
|
||||||
return tn / (tn + fp)
|
return tn / (tn + fp)
|
||||||
|
|
||||||
|
def accuracy(self) -> float:
|
||||||
|
tp = np.diag(self.matrix).sum()
|
||||||
|
total = self.matrix.sum()
|
||||||
|
return tp / total
|
||||||
|
|
||||||
|
def precision(self) -> float:
|
||||||
|
precision_per_class = self.precision_per_class()
|
||||||
|
support = np.sum(self.matrix, axis=1)
|
||||||
|
return np.average(precision_per_class, weights=support)
|
||||||
|
|
||||||
|
def recall(self) -> float:
|
||||||
|
recall_per_class = self.recall_per_class()
|
||||||
|
support = np.sum(self.matrix, axis=1)
|
||||||
|
return np.average(recall_per_class, weights=support)
|
||||||
|
|
||||||
|
def f1_score(self) -> float:
|
||||||
|
f1_per_class = self.f1_score_per_class()
|
||||||
|
support = np.sum(self.matrix, axis=1)
|
||||||
|
return np.average(f1_per_class, weights=support)
|
||||||
|
|
||||||
|
def specificity(self) -> float:
|
||||||
|
specificity_per_class = self.specificity_per_class()
|
||||||
|
support = np.sum(self.matrix, axis=1)
|
||||||
|
return np.average(specificity_per_class, weights=support)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ds = Dataset("datasets\\classification\\frogs.csv", "Species", TargetType.MultiClassification)
|
ds = Dataset("datasets\\classification\\frogs.csv", "Species", TargetType.MultiClassification)
|
||||||
ds.remove(["Family", "Genus", "RecordID"])
|
ds.remove(["Family", "Genus", "RecordID"])
|
||||||
|
|||||||
@@ -78,11 +78,11 @@ class MLAlgorithm(ABC):
|
|||||||
print(f"R^2 : {self.test_r_squared():0.5f}")
|
print(f"R^2 : {self.test_r_squared():0.5f}")
|
||||||
else:
|
else:
|
||||||
conf = self.test_confusion_matrix()
|
conf = self.test_confusion_matrix()
|
||||||
print(f"Accuracy : {conf.accuracy_per_class()}")
|
print(f"Accuracy : {conf.accuracy():0.5f} - classes {conf.accuracy_per_class()}")
|
||||||
print(f"Precision : {conf.precision_per_class()}")
|
print(f"Precision : {conf.precision():0.5f} - classes {conf.precision_per_class()}")
|
||||||
print(f"Recall : {conf.recall_per_class()}")
|
print(f"Recall : {conf.recall():0.5f} - classes {conf.recall_per_class()}")
|
||||||
print(f"F1 score : {conf.f1_score_per_class()}")
|
print(f"F1 score : {conf.f1_score():0.5f} - classes {conf.f1_score_per_class()}")
|
||||||
print(f"Specificity: {conf.specificity_per_class()}")
|
print(f"Specificity: {conf.specificity():0.5f} - classes {conf.specificity_per_class()}")
|
||||||
|
|
||||||
def test_confusion_matrix(self) -> ConfusionMatrix:
|
def test_confusion_matrix(self) -> ConfusionMatrix:
|
||||||
if self._target_type != TargetType.Classification\
|
if self._target_type != TargetType.Classification\
|
||||||
|
|||||||
Reference in New Issue
Block a user