KMeans
- implemented KMeans - fixed non seeded rng - fixed display exception with NoTargets - added basic test cases to app
This commit is contained in:
32
src/app.py
32
src/app.py
@@ -2,12 +2,14 @@ import random
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import sklearn
|
||||
import sklearn.cluster
|
||||
import sklearn.linear_model
|
||||
import sklearn.model_selection
|
||||
import sklearn.neural_network
|
||||
from learning.data import Dataset, TargetType
|
||||
from learning.supervised import LinearRegression, LogisticRegression, MultiLayerPerceptron
|
||||
from learning.ml import MLAlgorithm
|
||||
from learning.unsupervised import KMeans
|
||||
|
||||
DATASET = "datasets/"
|
||||
REGRESSION = DATASET + "regression/"
|
||||
@@ -75,6 +77,23 @@ def iris() -> tuple[Dataset, MLAlgorithm, Any]:
|
||||
size = [4, 3]
|
||||
return (ds, MultiLayerPerceptron(ds, size), sklearn.neural_network.MLPClassifier(size, 'relu'))
|
||||
|
||||
# ********************
|
||||
# MultiLayerPerceptron
|
||||
# ********************
|
||||
|
||||
def frogs_no_target() -> tuple[Dataset, MLAlgorithm, Any]:
|
||||
ds = Dataset(CLASSIFICATION + "frogs.csv", "Species", TargetType.NoTarget)
|
||||
ds.remove(["Family", "Genus", "RecordID", "Species"])
|
||||
clusters = 10
|
||||
return (ds, KMeans(ds, clusters), sklearn.cluster.KMeans(clusters))
|
||||
|
||||
def iris_no_target() -> tuple[Dataset, MLAlgorithm, Any]:
|
||||
ds = Dataset(CLASSIFICATION + "iris.csv", "Class", TargetType.NoTarget)
|
||||
ds.remove(["Class"])
|
||||
ds.normalize()
|
||||
clusters = 3
|
||||
return (ds, KMeans(ds, clusters), sklearn.cluster.KMeans(clusters))
|
||||
|
||||
# ********************
|
||||
# Main & random
|
||||
# ********************
|
||||
@@ -82,17 +101,24 @@ def iris() -> tuple[Dataset, MLAlgorithm, Any]:
|
||||
if __name__ == "__main__":
|
||||
np.set_printoptions(linewidth=np.inf, formatter={'float': '{:>10.5f}'.format})
|
||||
rand = random.randint(0, 4294967295)
|
||||
#rand = 1997847910 # LiR for power_plant
|
||||
#rand = 347617386 # LoR for electrical_grid
|
||||
#rand = 1793295160 # MLP for iris
|
||||
#rand = 885416001 # KMe for frogs_no_target
|
||||
|
||||
np.random.seed(rand)
|
||||
print(f"Using seed: {rand}")
|
||||
|
||||
ds, ml, sk = electrical_grid()
|
||||
ml.learn(10000, verbose=True)
|
||||
ds, ml, sk = iris()
|
||||
|
||||
epochs, _, _ = ml.learn(1000, verbose=True)
|
||||
ml.display_results()
|
||||
|
||||
np.random.seed(rand)
|
||||
learn, test, valid = ds.get_dataset()
|
||||
sk.set_params(max_iter=epochs)
|
||||
sk.fit(learn.x, learn.y)
|
||||
print(f"Sklearn : {sk.score(test.x, test.y):0.5f}")
|
||||
print(f"Sklearn : {abs(sk.score(test.x, test.y)):0.5f}")
|
||||
print("========================")
|
||||
|
||||
ml.plot()
|
||||
|
||||
Reference in New Issue
Block a user