- added backprop
- fixed data for multiclass
- fixed confusion matrix
This commit is contained in:
2024-08-12 16:59:17 +02:00
parent c0f48e412e
commit a992539116
3 changed files with 93 additions and 51 deletions

View File

@@ -22,6 +22,9 @@ class MLAlgorithm(ABC):
self._validset = valid
self._testset = test
def with_bias(self, x:np.ndarray) -> np.ndarray:
return np.hstack([x, np.ones(shape=(x.shape[0], 1))])
def learn(self, epochs:int, early_stop:float=0.0000001, max_patience:int=10, verbose:bool=False) -> tuple[int, list, list]:
learn = []
valid = []
@@ -89,8 +92,14 @@ class MLAlgorithm(ABC):
and self._target_type != TargetType.MultiClassification:
return None
h0 = np.where(self._h0(self._testset.x) > 0.5, 1, 0)
return ConfusionMatrix(self._testset.y, h0)
h0 = self._h0(self._testset.x)
y = self._testset.y
if h0.ndim == 1:
h0 = np.where(h0 > 0.5, 1, 0)
else:
h0 = np.argmax(h0, axis=1)
y = np.argmax(y, axis=1)
return ConfusionMatrix(y, h0)
def test_r_squared(self) -> float:
if self._target_type != TargetType.Regression: