From 988328cd4918faa0108d6b7f52d77e9067c9fa60 Mon Sep 17 00:00:00 2001 From: Berack96 Date: Tue, 19 Dec 2023 00:16:37 +0100 Subject: [PATCH] LW - init LikelyhoodWeighting - added NetworkNode helper class --- .../upo/ai/problem3/LikelyhoodWeighting.java | 45 +++++++++ .../berack/upo/ai/problem3/NetworkNode.java | 92 +++++++++++++++++++ .../net/berack/upo/ai/problem3/SmileLib.java | 6 ++ 3 files changed, 143 insertions(+) create mode 100644 src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java create mode 100644 src/main/java/net/berack/upo/ai/problem3/NetworkNode.java diff --git a/src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java b/src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java new file mode 100644 index 0000000..817cb75 --- /dev/null +++ b/src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java @@ -0,0 +1,45 @@ +package net.berack.upo.ai.problem3; + +import java.security.SecureRandom; +import java.util.Objects; +import smile.Network; + +public class LikelyhoodWeighting { + + public final Network net; + + public LikelyhoodWeighting(Network net) { + this.net = Objects.requireNonNull(net); + } + + public void calculate(int totalRuns) { + totalRuns = Math.max(1, totalRuns); + + var nodes = NetworkNode.buildSetFrom(net, totalRuns); + var rand = new SecureRandom(); + var prob = new double[totalRuns]; + var sum = 0.0d; + + for(var run = 0; run < totalRuns; run++) { + prob[run] = 1; + + for(var node: nodes) { + if(!node.isEvidence()) node.setSample(rand.nextDouble(), run); + else prob[run] *= node.getProbSampleEvidence(run); + } + + sum += prob[run]; + } + + for(var node : nodes) if(!node.isEvidence()) { + var values = new double[node.outcomes.length]; + + for(var run = 0; run < totalRuns; run++) + values[node.samples[run]] += prob[run]; + for(var i = 0; i < values.length; i++) + values[i] /= sum; + + net.setPointValues(node.handle, values); + } + } +} diff --git a/src/main/java/net/berack/upo/ai/problem3/NetworkNode.java b/src/main/java/net/berack/upo/ai/problem3/NetworkNode.java new file mode 100644 index 0000000..5130cfb --- /dev/null +++ b/src/main/java/net/berack/upo/ai/problem3/NetworkNode.java @@ -0,0 +1,92 @@ +package net.berack.upo.ai.problem3; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Set; + +import smile.Network; + +public class NetworkNode { + + public static Set buildSetFrom(Network net, int totRounds) { + var nodes = new HashMap(); + + for(var handle : net.getAllNodes()) nodes.put(handle, new NetworkNode(net, handle)); + + var retSet = Set.copyOf(nodes.values()); + for(var node : retSet) { + var parentsHandle = net.getParents(node.handle); + node.parents = new NetworkNode[parentsHandle.length]; + + for(var i = 0; i < parentsHandle.length; i++) + node.parents[i] = nodes.get(parentsHandle[i]); + + if(!node.isEvidence()) { + node.samples = new int[totRounds]; + Arrays.fill(node.samples, -1); + } + } + + return retSet; + } + + public static NetworkNode[] topologicalSort(Set nodes) { + throw new UnsupportedOperationException("TODO implement this function"); + } + + final int handle; + final String[] outcomes; + final double[] definition; + final int evidence; + + NetworkNode[] parents; + int[] samples; + + private NetworkNode(Network net, int handle) { + this.handle = handle; + this.definition = net.getNodeDefinition(handle); + this.outcomes = net.getOutcomeIds(handle); + this.evidence = net.isEvidence(handle)? net.getEvidence(handle) : -1; + } + + public boolean isEvidence() { + return this.evidence > 0 ; + } + + public double getProbSampleEvidence(int round) { + if(!this.isEvidence()) throw new IllegalArgumentException("Evidence"); + + var init = getStartingIndex(round); + return this.definition[init + this.evidence]; + } + + public void setSample(double rand, int round) { + var init = getStartingIndex(round); + var end = init + this.outcomes.length; + var prob = 0; + + for(var i = init; i < end; i++) { + prob += this.definition[i]; + + if(rand <= prob) { + this.samples[round] = i - init; + break; + } + } + } + + private int getStartingIndex(int round) { + var init = 0; + var tot = this.definition.length; + + for(var p : this.parents) { + var pIndex = p.isEvidence()? p.evidence : p.samples[round]; + if(pIndex == -1) throw new IllegalArgumentException("Parent"); + + tot /= p.outcomes.length; + init += tot * pIndex; + } + + return init; + } + } \ No newline at end of file diff --git a/src/main/java/net/berack/upo/ai/problem3/SmileLib.java b/src/main/java/net/berack/upo/ai/problem3/SmileLib.java index 30134d6..af8b279 100644 --- a/src/main/java/net/berack/upo/ai/problem3/SmileLib.java +++ b/src/main/java/net/berack/upo/ai/problem3/SmileLib.java @@ -49,6 +49,12 @@ public class SmileLib { var net = new Network(); net.readFile(RESOURCE_PATH + "VentureBN.xdsl"); + + var nodes = net.getAllNodes(); + for (var i = 0; i < nodes.length; i++) { + System.out.println(nodes[i] + " -> " + net.getNodeId(nodes[i])); + } + net.setEvidence("Forecast", "Moderate"); net.updateBeliefs();