Refactor
- renamed many variables - results better displayed - fixed log(0) error with 1e-15
This commit is contained in:
@@ -73,7 +73,7 @@ class Dataset:
|
||||
self.data = self.data.dropna()
|
||||
return self
|
||||
|
||||
def get_dataset(self, test_frac:float=0.15, valid_frac:float=0.15) -> tuple[Data, Data, Data]:
|
||||
def get_dataset(self, test_frac:float=0.2, valid_frac:float=0.2) -> tuple[Data, Data, Data]:
|
||||
data = self.data.to_numpy()
|
||||
data = np.insert(data, 1, 1, axis=1) # adding bias
|
||||
np.random.shuffle(data)
|
||||
@@ -97,6 +97,43 @@ class Dataset:
|
||||
l.append(Data(ds, target))
|
||||
return l
|
||||
|
||||
class ConfusionMatrix:
|
||||
matrix:np.ndarray
|
||||
|
||||
def __init__(self, dataset_y: np.ndarray, predictions_y:np.ndarray) -> None:
|
||||
classes = len(np.unique(dataset_y))
|
||||
conf_matrix = np.zeros((classes, classes), dtype=int)
|
||||
|
||||
for actual, prediction in zip(dataset_y, predictions_y):
|
||||
conf_matrix[int(actual), int(prediction)] += 1
|
||||
self.matrix = conf_matrix
|
||||
|
||||
def accuracy_per_class(self) -> np.ndarray:
|
||||
return np.diag(self.matrix) / np.sum(self.matrix, axis=1)
|
||||
|
||||
def precision_per_class(self) -> np.ndarray:
|
||||
tp = np.diagonal(self.matrix)
|
||||
fp = np.sum(self.matrix, axis=0) - tp
|
||||
return tp / (tp + fp)
|
||||
|
||||
def recall_per_class(self) -> np.ndarray:
|
||||
tp = np.diagonal(self.matrix)
|
||||
fn = np.sum(self.matrix, axis=1) - tp
|
||||
return tp / (tp + fn)
|
||||
|
||||
def f1_score_per_class(self) -> np.ndarray:
|
||||
prec = self.precision_per_class()
|
||||
rec = self.recall_per_class()
|
||||
return 2 * (prec * rec) / (prec + rec)
|
||||
|
||||
def specificity_per_class(self) -> np.ndarray:
|
||||
total = np.sum(self.matrix)
|
||||
tp = np.diagonal(self.matrix)
|
||||
fp = np.sum(self.matrix, axis=0) - tp
|
||||
fn = np.sum(self.matrix, axis=1) - tp
|
||||
tn = total - (tp + fp + fn)
|
||||
return tn / (tn + fp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds = Dataset("datasets\\classification\\frogs.csv", "Species", TargetType.MultiClassification)
|
||||
ds.remove(["Family", "Genus", "RecordID"])
|
||||
|
||||
Reference in New Issue
Block a user