This commit is contained in:
2024-04-20 21:21:45 +02:00
parent 18e390d34b
commit f525cdf280
4 changed files with 46 additions and 28 deletions

View File

@@ -1,14 +1,18 @@
from abc import ABC, abstractmethod
from learning.data import Dataset
from plot import Plot
import numpy as np
class MLAlgorithm(ABC):
""" Classe generica per gli algoritmi di Machine Learning """
dataset: Dataset
testset: np.ndarray
learnset: np.ndarray
test_error: list[float]
train_error: list[float]
def _set_dataset(self, dataset:Dataset, split:float=0.2):
ndarray = dataset.shuffle().as_ndarray()
@@ -30,6 +34,9 @@ class MLAlgorithm(ABC):
for _ in range(0, max(1, times)):
train.append(self.learning_step())
test.append(self.test_error())
self.train_error = train
self.test_error = test
return (train, test)
@abstractmethod
@@ -39,3 +46,15 @@ class MLAlgorithm(ABC):
@abstractmethod
def test_error(self) -> float:
pass
@abstractmethod
def plot(self, skip:int=1000) -> None:
pass
class MLRegression(MLAlgorithm):
def plot(self, skip:int=1000) -> None:
plot = Plot("Error", "Time", "Mean Error")
plot.line("training", "red", data=self.train_error[skip:])
plot.line("test", "blue", data=self.test_error[skip:])
plot.wait()