Fix bugs
This commit is contained in:
@@ -1,14 +1,18 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from learning.data import Dataset
|
||||
from plot import Plot
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MLAlgorithm(ABC):
|
||||
""" Classe generica per gli algoritmi di Machine Learning """
|
||||
|
||||
dataset: Dataset
|
||||
testset: np.ndarray
|
||||
learnset: np.ndarray
|
||||
test_error: list[float]
|
||||
train_error: list[float]
|
||||
|
||||
def _set_dataset(self, dataset:Dataset, split:float=0.2):
|
||||
ndarray = dataset.shuffle().as_ndarray()
|
||||
@@ -30,6 +34,9 @@ class MLAlgorithm(ABC):
|
||||
for _ in range(0, max(1, times)):
|
||||
train.append(self.learning_step())
|
||||
test.append(self.test_error())
|
||||
|
||||
self.train_error = train
|
||||
self.test_error = test
|
||||
return (train, test)
|
||||
|
||||
@abstractmethod
|
||||
@@ -39,3 +46,15 @@ class MLAlgorithm(ABC):
|
||||
@abstractmethod
|
||||
def test_error(self) -> float:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def plot(self, skip:int=1000) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MLRegression(MLAlgorithm):
|
||||
def plot(self, skip:int=1000) -> None:
|
||||
plot = Plot("Error", "Time", "Mean Error")
|
||||
plot.line("training", "red", data=self.train_error[skip:])
|
||||
plot.line("test", "blue", data=self.test_error[skip:])
|
||||
plot.wait()
|
||||
|
||||
Reference in New Issue
Block a user