LikelyHoodWeighting

- added docs
- added simple tests
- added more net for testing
- fixed lw errors from int to float casting
This commit is contained in:
2023-12-30 00:19:47 +01:00
parent e2fa69e2d5
commit 4dc02fcf31
12 changed files with 872 additions and 60 deletions

View File

@@ -1,34 +1,66 @@
package net.berack.upo.ai.problem3;
import java.security.SecureRandom;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import smile.Network;
/**
* Calcolo dei valori tramite l'algoritmo del Likelyhood Weighting
* @author Berack
*/
public class LikelyhoodWeighting {
public final Network net;
private Map<Integer, double[]> values = new HashMap<>();
/**
* Inizializza un nuovo oggetto che calcolerà i valori per la rete inserita
* @param net la rete a cui calcolare i valori
*/
public LikelyhoodWeighting(Network net) {
this.net = Objects.requireNonNull(net);
}
public void calculate(int totalRuns) {
/**
* Recupera i valori del nodo dopo averli calcolati
* Nel caso in cui non si abbia ancora fatto {@link #updateNetwork(int)} allora restituirà
* una eccezione di tiop UnsupportedOperationException
* @param node il nodo da vedere
* @return l'array di valori da restituire
*/
public double[] getNodeValue(int node) {
if(values.size() == 0) throw new UnsupportedOperationException("You should run first updateNetwork method");
return values.get(node);
}
/**
* Calcola i valori possibili per la rete.
* Per poterli vedere utilizzare il metodo {@link #getNodeValue(int)}
* @param totalRuns
*/
public void updateNetwork(int totalRuns) {
totalRuns = Math.max(1, totalRuns);
var nodes = NetworkNode.buildSetFrom(net, totalRuns);
var nodes = SmileLib.buildListFrom(net);
var rand = new SecureRandom();
var prob = new double[totalRuns];
var sum = 0.0d;
for(var node : nodes)
node.samples = new int[totalRuns];
for(var run = 0; run < totalRuns; run++) {
prob[run] = 1;
var probRun = 1.0d;
for(var node: nodes) {
if(!node.isEvidence()) node.setSample(rand.nextDouble(), run);
else prob[run] *= node.getProbSampleEvidence(run);
else probRun *= node.getProbSampleEvidence(run);
}
sum += prob[run];
prob[run] = probRun;
sum += probRun;
}
for(var node : nodes) if(!node.isEvidence()) {
@@ -39,7 +71,7 @@ public class LikelyhoodWeighting {
for(var i = 0; i < values.length; i++)
values[i] /= sum;
net.setPointValues(node.handle, values);
this.values.put(node.handle, values);
}
}
}

View File

@@ -1,58 +1,59 @@
package net.berack.upo.ai.problem3;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Set;
import smile.Network;
/**
* Una classe di appoggio per i nodi di un network.
* Questa classe contiene anche un array di samples in modo da facilitare il
* calcolo di valori.
*
* @see SmileLib#buildListFrom(NetWork)
* @author Berack
*/
public class NetworkNode {
public static Set<NetworkNode> buildSetFrom(Network net, int totRounds) {
var nodes = new HashMap<Integer, NetworkNode>();
for(var handle : net.getAllNodes()) nodes.put(handle, new NetworkNode(net, handle));
var retSet = Set.copyOf(nodes.values());
for(var node : retSet) {
var parentsHandle = net.getParents(node.handle);
node.parents = new NetworkNode[parentsHandle.length];
for(var i = 0; i < parentsHandle.length; i++)
node.parents[i] = nodes.get(parentsHandle[i]);
if(!node.isEvidence()) {
node.samples = new int[totRounds];
Arrays.fill(node.samples, -1);
}
}
return retSet;
}
public static NetworkNode[] topologicalSort(Set<NetworkNode> nodes) {
throw new UnsupportedOperationException("TODO implement this function");
}
final int handle;
final String[] outcomes;
final double[] definition;
final int evidence;
final Network net;
NetworkNode[] parents;
int[] samples;
public int[] samples;
private NetworkNode(Network net, int handle) {
/**
* Questo costruttore crea un nodo e gli assegna i valori essenziali.
* @apiNote Non usare questo costruttore se non si sa che si sa quello che si sta facendo
* @param net la rete
* @param handle l'handle del nodo
*/
NetworkNode(Network net, int handle) {
this.handle = handle;
this.net = net;
this.definition = net.getNodeDefinition(handle);
this.outcomes = net.getOutcomeIds(handle);
this.evidence = net.isEvidence(handle)? net.getEvidence(handle) : -1;
}
/**
* Indica se il nodo è evidenza o meno
* @return vero se lo è
*/
public boolean isEvidence() {
return this.evidence > 0 ;
}
/**
* Per utilizzare questo metodo il nodo deve essere una evidenza.
* Dato un roud di sample permette di ricevere il valore della
* probabilità che, dati i sample dei genitori, il nodo abbia
* il valore dell'evidenza impostata.
*
* @param round il numero del round che si stà controllando
* @return il valore della probabilità della evidenza
*/
public double getProbSampleEvidence(int round) {
if(!this.isEvidence()) throw new IllegalArgumentException("Evidence");
@@ -60,28 +61,44 @@ public class NetworkNode {
return this.definition[init + this.evidence];
}
/**
* Mette un sample al nodo nel round selezionato.
* Il valore di rand deve essere un numero casuale tra 0 e 1 ed
* esso permetterà di impostare un valore in base ai valori dei genitori.
*
* @param rand un valore casuale tra 0 e 1
* @param round il numero del round
*/
public void setSample(double rand, int round) {
var init = getStartingIndex(round);
var end = init + this.outcomes.length;
var prob = 0;
var prob = 0.0d;
for(var i = init; i < end; i++) {
prob += this.definition[i];
if(rand <= prob) {
if(prob >= rand) {
this.samples[round] = i - init;
break;
}
}
}
/**
* Dato un round permette di ricavare l'indice di partenza della CPT.
* Questo metodo serve perchè i genitori del nodo nel sample hanno
* dei valori e io devo generarli in accordo con la CPT di questo nodo.
*
* @param round il roundo corrente
* @return l'indice iniziale per gli output del nodo in base ai valori dei genitori
*/
private int getStartingIndex(int round) {
var init = 0;
var tot = this.definition.length;
for(var p : this.parents) {
var pIndex = p.isEvidence()? p.evidence : p.samples[round];
if(pIndex == -1) throw new IllegalArgumentException("Parent");
if(pIndex < 0) throw new IllegalArgumentException("Parent"); // in theory impossible since Topological sorted
tot /= p.outcomes.length;
init += tot * pIndex;
@@ -89,4 +106,16 @@ public class NetworkNode {
return init;
}
@Override
public boolean equals(Object obj) {
if(!obj.getClass().isInstance(this)) return false;
var other = (NetworkNode) obj;
if(this.handle != other.handle) return false;
if(this.evidence != other.evidence) return false;
if(!Arrays.equals(this.definition, other.definition)) return false;
return true;
}
}

View File

@@ -1,13 +1,24 @@
package net.berack.upo.ai.problem3;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import smile.Network;
/**
* Classe che permette l'utilizzo della libreria SMILE di BAYESFUSION.
* La classe carica staticamente la libreria.dll creando la proprietà di
* sistema jsmile.native.library includendo la resource path di jsmile.
* In questo modo per utilizzare SMILE basta chiamare un metodo di questa classe
* per far si che la chiave di attivazione venga correttamente controllata.
*
* @apiNote Scadenza chiave 2024-06-16
* @author Berack
*/
public class SmileLib {
public static final String RESOURCE_PATH;
static {
var loader = SmileLib.class.getClassLoader();
var wrongPath = loader.getResource("").getFile();
@@ -39,30 +50,51 @@ public class SmileLib {
);
}
/**
* Crea un Network dal file indicato
* Il file deve essere una risorsa del jar o un file esterno
*
* @param file il file da cercare
* @return il network creato
*/
public static Network getNetworkFrom(String file) {
var net = new Network();
net.readFile(RESOURCE_PATH + file);
try {
net.readFile(RESOURCE_PATH + file);
} catch (smile.SMILEException e) {
net.readFile(file);
}
return net;
}
public static void main(String[] args) throws Exception {
var net = new Network();
/**
* Crea una lista di nodi dal network indicato.
* I nodi usati sono un po' più comodi rispetto al network.
* La lista è ordinata in modo che il nodo 'k' sia un discendente
* dei nodi '0...k-1' e non di 'k+1...n'
*
* @param net il network da cui prendere i dati
* @return una lista ordinata di nodi
*/
public static List<NetworkNode> buildListFrom(Network net) {
var nodes = new HashMap<Integer, NetworkNode>();
var list = new ArrayList<NetworkNode>();
net.readFile(RESOURCE_PATH + "VentureBN.xdsl");
for(var handle : net.getAllNodes()) {
var node = new NetworkNode(net, handle);
list.add(node);
nodes.put(handle, node);
}
var nodes = net.getAllNodes();
for (var i = 0; i < nodes.length; i++) {
System.out.println(nodes[i] + " -> " + net.getNodeId(nodes[i]));
for(var node : nodes.values()) {
var parentsHandle = net.getParents(node.handle);
node.parents = new NetworkNode[parentsHandle.length];
for(var i = 0; i < parentsHandle.length; i++)
node.parents[i] = nodes.get(parentsHandle[i]);
}
return list;
}
net.setEvidence("Forecast", "Moderate");
net.updateBeliefs();
var beliefs = net.getNodeValue("Success");
for (var i = 0; i < beliefs.length; i++) {
System.out.println(net.getOutcomeId("Success", i) + " = " + beliefs[i]);
}
net.close();
}
}

View File

@@ -1,41 +0,0 @@
<?xml version="1.0" encoding="ISO-8859-1"?>
<smile version="1.0" id="VentureBN" numsamples="1000">
<nodes>
<cpt id="Success">
<state id="Success" />
<state id="Failure" />
<probabilities>0.2 0.8</probabilities>
</cpt>
<cpt id="Forecast">
<state id="Good" />
<state id="Moderate" />
<state id="Poor" />
<parents>Success</parents>
<probabilities>
0.4 0.4 0.2 0.1 0.3 0.6
</probabilities>
</cpt>
</nodes>
<extensions>
<genie version="1.0" app="GeNIe 2.1.1104.2"
name="VentureBN"
faultnameformat="nodestate">
<node id="Success">
<name>Success of the venture</name>
<interior color="e5f6f7" />
<outline color="0000bb" />
<font color="000000" name="Arial" size="8" />
<position>54 11 138 62</position>
</node>
<node id="Forecast">
<name>Expert forecast</name>
<interior color="e5f6f7" />
<outline color="0000bb" />
<font color="000000" name="Arial" size="8" />
<position>63 105 130 155</position>
</node>
</genie>
</extensions>
</smile>