Refactor Dataset

- better finalize function
- support for one-hot-encoding
This commit is contained in:
2024-05-02 14:19:23 +02:00
parent 969338196b
commit 3a4e07afc8
4 changed files with 118 additions and 98 deletions

View File

@@ -1,33 +1,53 @@
import pandas as pd
import numpy as np
from enum import Enum
from typing_extensions import Self
class TargetType(Enum):
Regression = 1
Classification = 2
MultiClassification = 3
NoTarget = 4
class Data:
x: np.ndarray
y: np.ndarray
size: int
param: int
def __init__(self, x:np.ndarray, y:np.ndarray) -> None:
self.x = x
self.y = y
self.size = x.shape[0]
self.param = x.shape[1]
def __str__(self) -> str:
return "X: " + str(self.x) + "\nY: " + str(self.y)
def as_tuple(self) -> tuple[np.ndarray, np.ndarray, int, int]:
return (self.x, self.y, self.size, self.param)
class Dataset:
def __init__(self, csv:str, target:str, classification:bool=None) -> None:
data = pd.read_csv(csv)
data: pd.DataFrame
target: str
target_type: TargetType
def __init__(self, csv:str, target:str, target_type:TargetType) -> None:
self.original = pd.read_csv(csv)
self.data = self.original
self.target = target
self.target_type = target_type
# move target to the start
col_target = data.pop(target)
data.insert(0, target, col_target)
data.insert(1, "Bias", 1.0)
if classification == None:
classification = (data[target].dtype == object)
self.original = data
self.data = data
self.target = target
self.classification = classification
col_target = self.data.pop(target)
self.data.insert(0, target, col_target)
def remove(self, columns:list[str]) -> Self:
for col in columns:
self.data.pop(col)
return self
def regularize(self, excepts:list[str]=[]) -> Self:
def normalize(self, excepts:list[str]=[]) -> Self:
excepts.append(self.target)
excepts.append("Bias")
for col in self.data:
if col not in excepts:
index = self.data.columns.get_loc(col)
@@ -42,7 +62,7 @@ class Dataset:
data[col] = pd.factorize(data[col])[0]
return self
def to_numbers(self, columns:list[str]=[]) -> Self:
def numbers(self, columns:list[str]=[]) -> Self:
data = self.data
for col in columns:
if data[col].dtype == object:
@@ -53,34 +73,38 @@ class Dataset:
self.data = self.data.dropna()
return self
def shuffle(self) -> Self:
self.data = self.data.sample(frac=1)
return self
def get_dataset(self, test_frac:float=0.15, valid_frac:float=0.15) -> tuple[Data, Data, Data]:
data = self.data.to_numpy()
data = np.insert(data, 1, 1, axis=1) # adding bias
np.random.shuffle(data)
def as_ndarray(self) -> np.ndarray:
return self.data.to_numpy()
def get_index(self, column:str) -> int:
return self.data.columns.get_loc(column)
class PrincipalComponentAnalisys:
def __init__(self, data:np.ndarray) -> None:
self.data = data
def reduce(self, total:int=0, threshold:float=1) -> Self:
columns = self.data.shape[1]
if total > columns or total <= 0:
total = columns
if threshold <= 0 or threshold > 1:
threshold = 1
total = data.shape[0]
valid_cutoff = int(total * valid_frac)
test_cutoff = int(total * test_frac) + valid_cutoff
valid = data[:valid_cutoff]
test = data[valid_cutoff:test_cutoff]
learn = data[test_cutoff:]
l = []
for ds in [learn, test, valid]:
target = ds[:, 0] if self.target_type != TargetType.NoTarget else None
ds = ds[:, 1:]
if self.target_type == TargetType.MultiClassification:
target = target.astype(int)
uniques = np.unique(target).shape[0]
target = np.eye(uniques)[target]
l.append(Data(ds, target))
return l
if __name__ == "__main__":
df = Dataset("datasets\\regression\\automobile.csv", "symboling")
attributes_to_modify = ["fuel-system", "engine-type", "drive-wheels", "body-style", "make", "engine-location", "aspiration", "fuel-type", "num-of-cylinders", "num-of-doors"]
df.factorize(attributes_to_modify)
df.to_numbers(["normalized-losses", "bore", "stroke", "horsepower", "peak-rpm", "price"])
df.handle_na()
df.regularize(excepts=attributes_to_modify)
print(df.data.dtypes)
ds = Dataset("datasets\\classification\\frogs.csv", "Species", TargetType.MultiClassification)
ds.remove(["Family", "Genus", "RecordID"])
ds.factorize(["Species"])
np.random.seed(0)
learn, test, valid = ds.get_dataset()
print(learn)
print(test)
print(valid)