- init LikelyhoodWeighting
- added NetworkNode helper class
This commit is contained in:
2023-12-19 00:16:37 +01:00
parent 6392132560
commit 988328cd49
3 changed files with 143 additions and 0 deletions

View File

@@ -0,0 +1,45 @@
package net.berack.upo.ai.problem3;
import java.security.SecureRandom;
import java.util.Objects;
import smile.Network;
public class LikelyhoodWeighting {
public final Network net;
public LikelyhoodWeighting(Network net) {
this.net = Objects.requireNonNull(net);
}
public void calculate(int totalRuns) {
totalRuns = Math.max(1, totalRuns);
var nodes = NetworkNode.buildSetFrom(net, totalRuns);
var rand = new SecureRandom();
var prob = new double[totalRuns];
var sum = 0.0d;
for(var run = 0; run < totalRuns; run++) {
prob[run] = 1;
for(var node: nodes) {
if(!node.isEvidence()) node.setSample(rand.nextDouble(), run);
else prob[run] *= node.getProbSampleEvidence(run);
}
sum += prob[run];
}
for(var node : nodes) if(!node.isEvidence()) {
var values = new double[node.outcomes.length];
for(var run = 0; run < totalRuns; run++)
values[node.samples[run]] += prob[run];
for(var i = 0; i < values.length; i++)
values[i] /= sum;
net.setPointValues(node.handle, values);
}
}
}

View File

@@ -0,0 +1,92 @@
package net.berack.upo.ai.problem3;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Set;
import smile.Network;
public class NetworkNode {
public static Set<NetworkNode> buildSetFrom(Network net, int totRounds) {
var nodes = new HashMap<Integer, NetworkNode>();
for(var handle : net.getAllNodes()) nodes.put(handle, new NetworkNode(net, handle));
var retSet = Set.copyOf(nodes.values());
for(var node : retSet) {
var parentsHandle = net.getParents(node.handle);
node.parents = new NetworkNode[parentsHandle.length];
for(var i = 0; i < parentsHandle.length; i++)
node.parents[i] = nodes.get(parentsHandle[i]);
if(!node.isEvidence()) {
node.samples = new int[totRounds];
Arrays.fill(node.samples, -1);
}
}
return retSet;
}
public static NetworkNode[] topologicalSort(Set<NetworkNode> nodes) {
throw new UnsupportedOperationException("TODO implement this function");
}
final int handle;
final String[] outcomes;
final double[] definition;
final int evidence;
NetworkNode[] parents;
int[] samples;
private NetworkNode(Network net, int handle) {
this.handle = handle;
this.definition = net.getNodeDefinition(handle);
this.outcomes = net.getOutcomeIds(handle);
this.evidence = net.isEvidence(handle)? net.getEvidence(handle) : -1;
}
public boolean isEvidence() {
return this.evidence > 0 ;
}
public double getProbSampleEvidence(int round) {
if(!this.isEvidence()) throw new IllegalArgumentException("Evidence");
var init = getStartingIndex(round);
return this.definition[init + this.evidence];
}
public void setSample(double rand, int round) {
var init = getStartingIndex(round);
var end = init + this.outcomes.length;
var prob = 0;
for(var i = init; i < end; i++) {
prob += this.definition[i];
if(rand <= prob) {
this.samples[round] = i - init;
break;
}
}
}
private int getStartingIndex(int round) {
var init = 0;
var tot = this.definition.length;
for(var p : this.parents) {
var pIndex = p.isEvidence()? p.evidence : p.samples[round];
if(pIndex == -1) throw new IllegalArgumentException("Parent");
tot /= p.outcomes.length;
init += tot * pIndex;
}
return init;
}
}

View File

@@ -49,6 +49,12 @@ public class SmileLib {
var net = new Network();
net.readFile(RESOURCE_PATH + "VentureBN.xdsl");
var nodes = net.getAllNodes();
for (var i = 0; i < nodes.length; i++) {
System.out.println(nodes[i] + " -> " + net.getNodeId(nodes[i]));
}
net.setEvidence("Forecast", "Moderate");
net.updateBeliefs();