From 42a54947a5ea476dd634f2af3e1420fa087a03f2 Mon Sep 17 00:00:00 2001 From: Berack96 Date: Tue, 2 Jan 2024 19:08:41 +0100 Subject: [PATCH] LWTest - improved tests for LW - fixed LW bugs - refactored NetworkNode for better usability - removed net not used in tests --- .../upo/ai/problem3/LikelyhoodWeighting.java | 40 ++-- .../berack/upo/ai/problem3/NetworkNode.java | 190 ++++++++++-------- .../net/berack/upo/ai/problem3/SmileLib.java | 10 +- .../net/berack/upo/ai/problem3/LWTest.java | 36 ++-- src/test/resources/Airport.xdsl | 146 -------------- src/test/resources/AppleTree.xdsl | 92 --------- src/test/resources/Micromorti.xdsl | 89 -------- src/test/resources/lucas96simp.xdsl | 160 --------------- 8 files changed, 147 insertions(+), 616 deletions(-) delete mode 100644 src/test/resources/Airport.xdsl delete mode 100644 src/test/resources/AppleTree.xdsl delete mode 100644 src/test/resources/Micromorti.xdsl delete mode 100644 src/test/resources/lucas96simp.xdsl diff --git a/src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java b/src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java index 53f76e4..46eb113 100644 --- a/src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java +++ b/src/main/java/net/berack/upo/ai/problem3/LikelyhoodWeighting.java @@ -13,7 +13,7 @@ import smile.Network; public class LikelyhoodWeighting { public final Network net; - private Map values = new HashMap<>(); + private final Map nodes = new HashMap<>(); /** * Inizializza un nuovo oggetto che calcolerà i valori per la rete inserita @@ -31,8 +31,8 @@ public class LikelyhoodWeighting { * @return l'array di valori da restituire */ public double[] getNodeValue(int node) { - if(values.size() == 0) throw new UnsupportedOperationException("You should run first updateNetwork method"); - return values.get(node); + if(nodes.size() == 0) throw new UnsupportedOperationException("You should run first updateNetwork method"); + return nodes.get(node).values; } /** @@ -43,35 +43,35 @@ public class LikelyhoodWeighting { public void updateNetwork(int totalRuns) { totalRuns = Math.max(1, totalRuns); - var nodes = SmileLib.buildListFrom(net); + var list = SmileLib.buildListFrom(net); var rand = new SecureRandom(); - var prob = new double[totalRuns]; var sum = 0.0d; - for(var node : nodes) - node.samples = new int[totalRuns]; - for(var run = 0; run < totalRuns; run++) { var probRun = 1.0d; - for(var node: nodes) { - if(!node.isEvidence()) node.setSample(rand.nextDouble(), run); - else probRun *= node.getProbSampleEvidence(run); + for(var node: list) { + if(node.isEvidence()) probRun *= node.getProbSampleEvidence(); + else node.setSample(rand.nextDouble()); + } + + for(var node: list) { + if(!node.isEvidence()) { + node.values[node.sample] += probRun; + node.sample = -1; + } } - prob[run] = probRun; sum += probRun; } - for(var node : nodes) if(!node.isEvidence()) { - var values = new double[node.outcomes.length]; + this.nodes.clear(); + for(var node : list) { + this.nodes.put(node.handle, node); - for(var run = 0; run < totalRuns; run++) - values[node.samples[run]] += prob[run]; - for(var i = 0; i < values.length; i++) - values[i] /= sum; - - this.values.put(node.handle, values); + if(!node.isEvidence()) + for(var i = 0; i < node.values.length; i++) + node.values[i] /= sum; } } } diff --git a/src/main/java/net/berack/upo/ai/problem3/NetworkNode.java b/src/main/java/net/berack/upo/ai/problem3/NetworkNode.java index 1af7814..59599a8 100644 --- a/src/main/java/net/berack/upo/ai/problem3/NetworkNode.java +++ b/src/main/java/net/berack/upo/ai/problem3/NetworkNode.java @@ -1,6 +1,8 @@ package net.berack.upo.ai.problem3; import java.util.Arrays; +import java.util.Map; + import smile.Network; /** @@ -13,109 +15,121 @@ import smile.Network; */ public class NetworkNode { - final int handle; - final String[] outcomes; - final double[] definition; - final int evidence; - final Network net; + final int handle; + final String[] outcomes; + final double[] definition; + final int evidence; + final Network net; + final int type; + final NetworkNode[] parents; - NetworkNode[] parents; - public int[] samples; + public int sample = -1; + public final double[] values; - /** - * Questo costruttore crea un nodo e gli assegna i valori essenziali. - * @apiNote Non usare questo costruttore se non si sa che si sa quello che si sta facendo - * @param net la rete - * @param handle l'handle del nodo - */ - NetworkNode(Network net, int handle) { - this.handle = handle; - this.net = net; + /** + * Questo costruttore crea un nodo e gli assegna i valori essenziali. + * @apiNote Non usare questo costruttore se non si sa che si sa quello che si sta facendo + * @param net la rete + * @param handle l'handle del nodo + * @param nodes i nodi creati prima di questo + */ + NetworkNode(Network net, int handle, Map nodes) { + this.handle = handle; + this.type = net.getNodeType(handle); + this.net = net; - this.definition = net.getNodeDefinition(handle); - this.outcomes = net.getOutcomeIds(handle); - this.evidence = net.isEvidence(handle)? net.getEvidence(handle) : -1; - } + this.outcomes = net.getOutcomeIds(handle); + this.values = new double[this.outcomes.length]; + this.evidence = net.isEvidence(handle)? net.getEvidence(handle) : -1; + if(this.isEvidence()) this.values[this.evidence] = 1.0d; - /** - * Indica se il nodo è evidenza o meno - * @return vero se lo è - */ - public boolean isEvidence() { - return this.evidence > 0 ; - } + var parentsHandle = net.getParents(handle); + this.parents = new NetworkNode[parentsHandle.length]; + for(var i = 0; i < parentsHandle.length; i++) + this.parents[i] = nodes.get(parentsHandle[i]); - /** - * Per utilizzare questo metodo il nodo deve essere una evidenza. - * Dato un roud di sample permette di ricevere il valore della - * probabilità che, dati i sample dei genitori, il nodo abbia - * il valore dell'evidenza impostata. - * - * @param round il numero del round che si stà controllando - * @return il valore della probabilità della evidenza - */ - public double getProbSampleEvidence(int round) { - if(!this.isEvidence()) throw new IllegalArgumentException("Evidence"); + this.definition = switch(this.type) { + case Network.NodeType.CPT -> net.getNodeDefinition(handle); + case Network.NodeType.NOISY_MAX -> net.getNoisyExpandedDefinition(handle); + default -> throw new IllegalArgumentException("Network with node type not supporrted -> " + this.type); + }; + } - var init = getStartingIndex(round); - return this.definition[init + this.evidence]; - } + /** + * Indica se il nodo è evidenza o meno + * @return vero se lo è + */ + public boolean isEvidence() { + return this.evidence >= 0 ; + } - /** - * Mette un sample al nodo nel round selezionato. - * Il valore di rand deve essere un numero casuale tra 0 e 1 ed - * esso permetterà di impostare un valore in base ai valori dei genitori. - * - * @param rand un valore casuale tra 0 e 1 - * @param round il numero del round - */ - public void setSample(double rand, int round) { - var init = getStartingIndex(round); - var end = init + this.outcomes.length; - var prob = 0.0d; + /** + * Per utilizzare questo metodo il nodo deve essere una evidenza. + * Permette di ricevere il valore della probabilità che, dati i sample dei genitori, + * il nodo abbia il valore dell'evidenza impostata. + * + * @return il valore della probabilità della evidenza + */ + public double getProbSampleEvidence() { + if(!this.isEvidence()) throw new IllegalArgumentException("Evidence"); - for(var i = init; i < end; i++) { - prob += this.definition[i]; + var init = getStartingIndex(); + return this.definition[init + this.evidence]; + } - if(prob >= rand) { - this.samples[round] = i - init; - break; - } + /** + * Mette un sample al nodo nel round selezionato. + * Il valore di rand deve essere un numero casuale tra 0 e 1 ed + * esso permetterà di impostare un valore in base ai valori dei sample dei genitori. + * + * @param rand un valore casuale tra 0 e 1 + */ + public void setSample(double rand) { + var init = getStartingIndex(); + var end = init + this.values.length; + var prob = 0.0d; + + for(var i = init; i < end; i++) { + prob += this.definition[i]; + + if(prob >= rand) { + this.sample = i - init; + break; } } + } - /** - * Dato un round permette di ricavare l'indice di partenza della CPT. - * Questo metodo serve perchè i genitori del nodo nel sample hanno - * dei valori e io devo generarli in accordo con la CPT di questo nodo. - * - * @param round il roundo corrente - * @return l'indice iniziale per gli output del nodo in base ai valori dei genitori - */ - private int getStartingIndex(int round) { - var init = 0; - var tot = this.definition.length; + /** + * Permette di ricavare l'indice di partenza della CPT. + * Questo metodo serve perchè i genitori del nodo nel sample hanno + * dei valori e io devo generarli in accordo con la CPT di questo nodo. + * + * @return l'indice iniziale per gli output del nodo in base ai valori dei genitori + */ + private int getStartingIndex() { + var init = 0; + var tot = this.definition.length; - for(var p : this.parents) { - var pIndex = p.isEvidence()? p.evidence : p.samples[round]; - if(pIndex < 0) throw new IllegalArgumentException("Parent"); // in theory impossible since Topological sorted + for(var p : this.parents) { + var pIndex = p.isEvidence()? p.evidence : p.sample; + if(pIndex < 0) throw new IllegalArgumentException("Parent"); // in theory impossible since Topological sorted - tot /= p.outcomes.length; - init += tot * pIndex; - } - - return init; + tot /= p.outcomes.length; + init += tot * pIndex; } - @Override - public boolean equals(Object obj) { - if(!obj.getClass().isInstance(this)) return false; + return init; + } - var other = (NetworkNode) obj; - if(this.handle != other.handle) return false; - if(this.evidence != other.evidence) return false; - if(!Arrays.equals(this.definition, other.definition)) return false; + @Override + public boolean equals(Object obj) { + if(!obj.getClass().isInstance(this)) return false; - return true; - } - } \ No newline at end of file + var other = (NetworkNode) obj; + if(this.handle != other.handle) return false; + if(this.evidence != other.evidence) return false; + if(!Arrays.equals(this.definition, other.definition)) return false; + + return true; + } +} \ No newline at end of file diff --git a/src/main/java/net/berack/upo/ai/problem3/SmileLib.java b/src/main/java/net/berack/upo/ai/problem3/SmileLib.java index a07caab..aa6f4f8 100644 --- a/src/main/java/net/berack/upo/ai/problem3/SmileLib.java +++ b/src/main/java/net/berack/upo/ai/problem3/SmileLib.java @@ -82,19 +82,11 @@ public class SmileLib { var list = new ArrayList(); for(var handle : net.getAllNodes()) { - var node = new NetworkNode(net, handle); + var node = new NetworkNode(net, handle, nodes); list.add(node); nodes.put(handle, node); } - for(var node : nodes.values()) { - var parentsHandle = net.getParents(node.handle); - node.parents = new NetworkNode[parentsHandle.length]; - - for(var i = 0; i < parentsHandle.length; i++) - node.parents[i] = nodes.get(parentsHandle[i]); - } - return list; } } diff --git a/src/test/java/net/berack/upo/ai/problem3/LWTest.java b/src/test/java/net/berack/upo/ai/problem3/LWTest.java index a3fc876..7192dea 100644 --- a/src/test/java/net/berack/upo/ai/problem3/LWTest.java +++ b/src/test/java/net/berack/upo/ai/problem3/LWTest.java @@ -2,21 +2,19 @@ package net.berack.upo.ai.problem3; import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import java.util.Arrays; - import org.junit.jupiter.api.Test; +import smile.Network; + public class LWTest { + // 3% difference max (it is a lot but is fair) + public static final float DELTA = 0.05f; + @Test public void testSmile() { var net = SmileLib.getNetworkFrom("VentureBN.xdsl"); - var nodes = net.getAllNodes(); - for (var i = 0; i < nodes.length; i++) { - System.out.println(nodes[i] + " -> " + net.getNodeId(nodes[i])); - } - net.setEvidence("Forecast", "Moderate"); net.updateBeliefs(); @@ -29,18 +27,32 @@ public class LWTest { @Test public void testSimpleNetwork() { - var net = SmileLib.getNetworkFrom("VentureBN.xdsl"); - net.updateBeliefs(); + checkNodesValues(SmileLib.getNetworkFrom("Malaria.xdsl")); + checkNodesValues(SmileLib.getNetworkFrom("VentureBN.xdsl")); + } + @Test + public void testEvidence() { + var net = SmileLib.getNetworkFrom("WetGrass.xdsl"); + checkNodesValues(net); + + net.setEvidence("Sprinkler", "On"); + checkNodesValues(net); + + net.setEvidence("Wet_grass", "Bagnata"); + checkNodesValues(net); + } + + private void checkNodesValues(Network net) { var lw = new LikelyhoodWeighting(net); + + net.updateBeliefs(); lw.updateNetwork(1000); for(var node : net.getAllNodes()) { var arr1 = net.getNodeValue(node); var arr2 = lw.getNodeValue(node); - - System.out.println(Arrays.toString(arr1) + " " + Arrays.toString(arr2)); - assertArrayEquals(arr1, arr2, 0.05); // 5% difference max (it is a lot but is fair) + assertArrayEquals(arr1, arr2, DELTA); } } } diff --git a/src/test/resources/Airport.xdsl b/src/test/resources/Airport.xdsl deleted file mode 100644 index 2c33554..0000000 --- a/src/test/resources/Airport.xdsl +++ /dev/null @@ -1,146 +0,0 @@ - - - - - - - - - - - - - 0.3 0.7 - - - - - 0.2 0.8 - - - - - 0.4 0.6 - - - - - AirTraffic AirportSite - 0.9 0.09999999999999998 0.7 0.3 0.7 0.3 0.8 0.2 0.6 0.4 0.4 0.6 - - - - - AirTraffic AirportSite - 0.9 0.09999999999999998 0.8 0.2 0.7 0.3 0.7 0.3 0.4 0.6 0.2 0.8 - - - - - Litigation Construction AirportSite - 0.9 0.09999999999999998 0.8 0.2 0.75 0.25 0.3 0.7 0.25 0.75 0.2 0.8 0.4 0.6 0.5 0.5 0.45 0.55 0.3 0.7 0.2 0.8 0.09999999999999998 0.9 - - - Deaths - 1 0 - - - Noise - 1 0 - - - Cost - 1 0 - - - UC UD UN - 2 10 1 - - - - - - AirportSite - - - - 440 289 503 315 - - - AirTraffic - - - - 375 62 436 100 - - - Litigation - - - - 375 138 432 173 - - - Construction - - - - 365 211 440 257 - - - Deaths - - - - 558 62 609 93 - - - Noise - - - - 564 138 609 166 - - - Cost - - - - 567 221 609 247 - - - Per l'utilita' data:\nsi suppone di considerare le seguenti\npriorita' o preferenze:\nSafety, Tollerabilita' rumore, Costo\nCio' implica che tutte le utilita' con Deaths=Safe devono essere maggiori di quelle con Deaths=Unsafe, che, specificato Deaths, quelle con Noise=tolerable devono essere maggiori di quelle con Noise=untolerable e in subordine quelle con Cost=low maggiori di quelle con Cost=high - - 28 32 221 200 - - - UD - - - - 716 72 750 104 - - - UN - - - - 723 153 757 185 - - - UC - - - - 741 231 775 263 - - - U - - - - 886 168 913 200 - - - - diff --git a/src/test/resources/AppleTree.xdsl b/src/test/resources/AppleTree.xdsl deleted file mode 100644 index 2416252..0000000 --- a/src/test/resources/AppleTree.xdsl +++ /dev/null @@ -1,92 +0,0 @@ - - - - - - - - 0.1 0.9 - - - - - 0.1 0.9 - - - - - Malato Secco - 0.95 0.05000000000000004 0.9 0.09999999999999998 0.85 0.15 0.02 0.98 - - - - - PerditaFoglie - - - Trattamento - -8000 0 - - - Malato - 3000 20000 - - - - - Malato Trattamento - 0.2 0.8 0.99 0.01 0.01000000000000001 0.99 0.02000000000000002 0.98 - - - Secco - 0.6 0.4 0.05000000000000004 0.95 - - - - - 20 30 819 509 - - Malato - - - - 228 205 302 251 - - - Secco - - - - 226 87 300 133 - - - PerditaFoglie - - - - 382 143 508 221 - - - Trattamento - - - - 311 304 424 338 - - - CostoT - - - - 327 382 418 472 - - - U - - - - 584 295 625 343 - - - - diff --git a/src/test/resources/Micromorti.xdsl b/src/test/resources/Micromorti.xdsl deleted file mode 100644 index 2bb8e7f..0000000 --- a/src/test/resources/Micromorti.xdsl +++ /dev/null @@ -1,89 +0,0 @@ - - - - - - - - 0.5 0.5 - - - - - Malattia - 0.99 0.01000000000000001 0.15 0.85 - - - - - Sintomo - - - - - - Test Malattia - 0.95 0.05000000000000004 0 0.01000000000000001 0.99 0 0 0 1 0 0 1 - - - - - RisultatoTest - - - Malattia Intervento - 0.9871992685296302 0 0.9992856734670552 1 - - - - - - Malattia - - - - 283 79 335 111 - - - Sintomo - - - - 285 185 339 218 - - - Test - - - - 416 189 449 215 - - - RisultatoTest - - - - 443 60 518 106 - - - Intervento - - - - 571 188 631 214 - - - U - - - - 448 306 475 338 - - - (MM) Micromorte=1 probabilita' su un millione di morire entro l'anno\nStima: 1 MM=$20\n\nUna probabilita' p e' quindi = p*10^6 MM\n(facendo una proporzione p:x=10^-6:1)\n\nAvere la malattia e non intervenire comporta un rischio pari ad una probabilita' del 35% di morire entro l'anno.\nCio' e' pari a 350.000 MM ossia $7Millioni\n\nIntervenire sulla malattia abbassa tale probabilita' per il nostro paziente a 0.0002=200MM=$4000\n\nSe il ns paziente non e' malato la sua aspettativa di vita e' tale che la prob. che muoia entro l'anno e' 1/50000=20MM=$400\n\nL'intervento non e' rischioso (quindi si tralascia di stimare quante micromorti vale l'intervento in assenza di malattia)\n\nSupponiamo che l'intervento costi $5000\n\nLa funzione di utilita' si puo' calcolare usando questi costi\nU(no M, no I)=-400\nU(no M, I)=-400-5000\nU(M, no I)=-7M\nU(M, I)=-4000-5000\n\nRinormalizzata a 1 e' come appare nell'ID\n - - 3 21 230 497 - - - - diff --git a/src/test/resources/lucas96simp.xdsl b/src/test/resources/lucas96simp.xdsl deleted file mode 100644 index b1cdccc..0000000 --- a/src/test/resources/lucas96simp.xdsl +++ /dev/null @@ -1,160 +0,0 @@ - - - - - - - - - 0.3333333333333333 0.3333333333333333 0.3333333333333334 - - - - - Age - 0.05 0.95 0.5 0.5 0.8 0.2 - -1 - - - - - Heart_failiure - 0.95 0.05000000000000004 0.1 0.9 - - - - - dyspnea - 0.98 0.02000000000000002 0.2 0.8 - - - - - dyspnea - 0.95 0.05000000000000004 0.1 0.9 - - - - - Heart_failiure - 0.7 0.3 0.15 0.85 - - - - - Age tachycpnea tachicardia PulmonaryCrepitations - -1 - - - - - treatment Heart_failiure Age - 0.6 0.4 0.45 0.55 0.65 0.35 0.1 0.9 0.15 0.85 0.18 0.8200000000000001 0.6 0.4 0.8 0.2 0.95 0.05000000000000004 0.1 0.9 0.15 0.85 0.18 0.8200000000000001 - -1 - - - - - intermdiate_result treatment - 0.5 0.5 0.95 0.05000000000000004 0.01 0.99 0.1 0.9 - - - treatment intermdiate_result Late_complications - 0 0.6 0.4 0.8 0.3 0.7 0.5 1 - - - - - - Age - - - - 599 419 640 444 - - - Heart_failiure - - - - 386 329 462 376 - - - dyspnea - - - - 387 197 480 221 - - - tachycpnea - - - - 385 102 451 143 - - - tachiycardia - - - - 543 203 609 244 - - - pulmonary crepitations - - - - 532 278 610 326 - - - aumento velocita' respirazione - - 333 86 479 100 - - - aumento velocita' cardiaca - - 513 187 641 201 - - - difficolta' a respirare - - 285 198 384 212 - - - treatment - - - - 713 177 770 203 - - - intermdiate result - - - - 619 643 712 701 - - - Late complications - - - - 863 432 962 493 - - - U - - - - 775 429 802 461 - - - Utilita' graduata considerando LateComplications piu' importante di IntermideiaResult piu' importante di costo trattamento - - 85 74 279 130 - - - -