- improved tests for LW
- fixed LW bugs
- refactored NetworkNode for better usability
- removed net not used in tests
This commit is contained in:
2024-01-02 19:08:41 +01:00
parent 4dc02fcf31
commit 42a54947a5
8 changed files with 147 additions and 616 deletions

View File

@@ -13,7 +13,7 @@ import smile.Network;
public class LikelyhoodWeighting {
public final Network net;
private Map<Integer, double[]> values = new HashMap<>();
private final Map<Integer, NetworkNode> nodes = new HashMap<>();
/**
* Inizializza un nuovo oggetto che calcolerà i valori per la rete inserita
@@ -31,8 +31,8 @@ public class LikelyhoodWeighting {
* @return l'array di valori da restituire
*/
public double[] getNodeValue(int node) {
if(values.size() == 0) throw new UnsupportedOperationException("You should run first updateNetwork method");
return values.get(node);
if(nodes.size() == 0) throw new UnsupportedOperationException("You should run first updateNetwork method");
return nodes.get(node).values;
}
/**
@@ -43,35 +43,35 @@ public class LikelyhoodWeighting {
public void updateNetwork(int totalRuns) {
totalRuns = Math.max(1, totalRuns);
var nodes = SmileLib.buildListFrom(net);
var list = SmileLib.buildListFrom(net);
var rand = new SecureRandom();
var prob = new double[totalRuns];
var sum = 0.0d;
for(var node : nodes)
node.samples = new int[totalRuns];
for(var run = 0; run < totalRuns; run++) {
var probRun = 1.0d;
for(var node: nodes) {
if(!node.isEvidence()) node.setSample(rand.nextDouble(), run);
else probRun *= node.getProbSampleEvidence(run);
for(var node: list) {
if(node.isEvidence()) probRun *= node.getProbSampleEvidence();
else node.setSample(rand.nextDouble());
}
for(var node: list) {
if(!node.isEvidence()) {
node.values[node.sample] += probRun;
node.sample = -1;
}
}
prob[run] = probRun;
sum += probRun;
}
for(var node : nodes) if(!node.isEvidence()) {
var values = new double[node.outcomes.length];
this.nodes.clear();
for(var node : list) {
this.nodes.put(node.handle, node);
for(var run = 0; run < totalRuns; run++)
values[node.samples[run]] += prob[run];
for(var i = 0; i < values.length; i++)
values[i] /= sum;
this.values.put(node.handle, values);
if(!node.isEvidence())
for(var i = 0; i < node.values.length; i++)
node.values[i] /= sum;
}
}
}

View File

@@ -1,6 +1,8 @@
package net.berack.upo.ai.problem3;
import java.util.Arrays;
import java.util.Map;
import smile.Network;
/**
@@ -13,109 +15,121 @@ import smile.Network;
*/
public class NetworkNode {
final int handle;
final String[] outcomes;
final double[] definition;
final int evidence;
final Network net;
final int handle;
final String[] outcomes;
final double[] definition;
final int evidence;
final Network net;
final int type;
final NetworkNode[] parents;
NetworkNode[] parents;
public int[] samples;
public int sample = -1;
public final double[] values;
/**
* Questo costruttore crea un nodo e gli assegna i valori essenziali.
* @apiNote Non usare questo costruttore se non si sa che si sa quello che si sta facendo
* @param net la rete
* @param handle l'handle del nodo
*/
NetworkNode(Network net, int handle) {
this.handle = handle;
this.net = net;
/**
* Questo costruttore crea un nodo e gli assegna i valori essenziali.
* @apiNote Non usare questo costruttore se non si sa che si sa quello che si sta facendo
* @param net la rete
* @param handle l'handle del nodo
* @param nodes i nodi creati prima di questo
*/
NetworkNode(Network net, int handle, Map<Integer, NetworkNode> nodes) {
this.handle = handle;
this.type = net.getNodeType(handle);
this.net = net;
this.definition = net.getNodeDefinition(handle);
this.outcomes = net.getOutcomeIds(handle);
this.evidence = net.isEvidence(handle)? net.getEvidence(handle) : -1;
}
this.outcomes = net.getOutcomeIds(handle);
this.values = new double[this.outcomes.length];
this.evidence = net.isEvidence(handle)? net.getEvidence(handle) : -1;
if(this.isEvidence()) this.values[this.evidence] = 1.0d;
/**
* Indica se il nodo è evidenza o meno
* @return vero se lo è
*/
public boolean isEvidence() {
return this.evidence > 0 ;
}
var parentsHandle = net.getParents(handle);
this.parents = new NetworkNode[parentsHandle.length];
for(var i = 0; i < parentsHandle.length; i++)
this.parents[i] = nodes.get(parentsHandle[i]);
/**
* Per utilizzare questo metodo il nodo deve essere una evidenza.
* Dato un roud di sample permette di ricevere il valore della
* probabilità che, dati i sample dei genitori, il nodo abbia
* il valore dell'evidenza impostata.
*
* @param round il numero del round che si stà controllando
* @return il valore della probabilità della evidenza
*/
public double getProbSampleEvidence(int round) {
if(!this.isEvidence()) throw new IllegalArgumentException("Evidence");
this.definition = switch(this.type) {
case Network.NodeType.CPT -> net.getNodeDefinition(handle);
case Network.NodeType.NOISY_MAX -> net.getNoisyExpandedDefinition(handle);
default -> throw new IllegalArgumentException("Network with node type not supporrted -> " + this.type);
};
}
var init = getStartingIndex(round);
return this.definition[init + this.evidence];
}
/**
* Indica se il nodo è evidenza o meno
* @return vero se lo è
*/
public boolean isEvidence() {
return this.evidence >= 0 ;
}
/**
* Mette un sample al nodo nel round selezionato.
* Il valore di rand deve essere un numero casuale tra 0 e 1 ed
* esso permetterà di impostare un valore in base ai valori dei genitori.
*
* @param rand un valore casuale tra 0 e 1
* @param round il numero del round
*/
public void setSample(double rand, int round) {
var init = getStartingIndex(round);
var end = init + this.outcomes.length;
var prob = 0.0d;
/**
* Per utilizzare questo metodo il nodo deve essere una evidenza.
* Permette di ricevere il valore della probabilità che, dati i sample dei genitori,
* il nodo abbia il valore dell'evidenza impostata.
*
* @return il valore della probabilità della evidenza
*/
public double getProbSampleEvidence() {
if(!this.isEvidence()) throw new IllegalArgumentException("Evidence");
for(var i = init; i < end; i++) {
prob += this.definition[i];
var init = getStartingIndex();
return this.definition[init + this.evidence];
}
if(prob >= rand) {
this.samples[round] = i - init;
break;
}
/**
* Mette un sample al nodo nel round selezionato.
* Il valore di rand deve essere un numero casuale tra 0 e 1 ed
* esso permetterà di impostare un valore in base ai valori dei sample dei genitori.
*
* @param rand un valore casuale tra 0 e 1
*/
public void setSample(double rand) {
var init = getStartingIndex();
var end = init + this.values.length;
var prob = 0.0d;
for(var i = init; i < end; i++) {
prob += this.definition[i];
if(prob >= rand) {
this.sample = i - init;
break;
}
}
}
/**
* Dato un round permette di ricavare l'indice di partenza della CPT.
* Questo metodo serve perchè i genitori del nodo nel sample hanno
* dei valori e io devo generarli in accordo con la CPT di questo nodo.
*
* @param round il roundo corrente
* @return l'indice iniziale per gli output del nodo in base ai valori dei genitori
*/
private int getStartingIndex(int round) {
var init = 0;
var tot = this.definition.length;
/**
* Permette di ricavare l'indice di partenza della CPT.
* Questo metodo serve perchè i genitori del nodo nel sample hanno
* dei valori e io devo generarli in accordo con la CPT di questo nodo.
*
* @return l'indice iniziale per gli output del nodo in base ai valori dei genitori
*/
private int getStartingIndex() {
var init = 0;
var tot = this.definition.length;
for(var p : this.parents) {
var pIndex = p.isEvidence()? p.evidence : p.samples[round];
if(pIndex < 0) throw new IllegalArgumentException("Parent"); // in theory impossible since Topological sorted
for(var p : this.parents) {
var pIndex = p.isEvidence()? p.evidence : p.sample;
if(pIndex < 0) throw new IllegalArgumentException("Parent"); // in theory impossible since Topological sorted
tot /= p.outcomes.length;
init += tot * pIndex;
}
return init;
tot /= p.outcomes.length;
init += tot * pIndex;
}
@Override
public boolean equals(Object obj) {
if(!obj.getClass().isInstance(this)) return false;
return init;
}
var other = (NetworkNode) obj;
if(this.handle != other.handle) return false;
if(this.evidence != other.evidence) return false;
if(!Arrays.equals(this.definition, other.definition)) return false;
@Override
public boolean equals(Object obj) {
if(!obj.getClass().isInstance(this)) return false;
return true;
}
}
var other = (NetworkNode) obj;
if(this.handle != other.handle) return false;
if(this.evidence != other.evidence) return false;
if(!Arrays.equals(this.definition, other.definition)) return false;
return true;
}
}

View File

@@ -82,19 +82,11 @@ public class SmileLib {
var list = new ArrayList<NetworkNode>();
for(var handle : net.getAllNodes()) {
var node = new NetworkNode(net, handle);
var node = new NetworkNode(net, handle, nodes);
list.add(node);
nodes.put(handle, node);
}
for(var node : nodes.values()) {
var parentsHandle = net.getParents(node.handle);
node.parents = new NetworkNode[parentsHandle.length];
for(var i = 0; i < parentsHandle.length; i++)
node.parents[i] = nodes.get(parentsHandle[i]);
}
return list;
}
}