From 5d028fe22b789fe8958330d03edbd89b9ab990d0 Mon Sep 17 00:00:00 2001 From: Berack96 Date: Thu, 14 Dec 2023 01:30:27 +0100 Subject: [PATCH] InitMiniMax - added TrisAi - added TrisAi tests for heuristic - added MiniMax class (BETA) --- .../net/berack/upo/ai/problem2/MiniMax.java | 99 +++++++++++++++- .../net/berack/upo/ai/problem2/TrisAi.java | 82 ++++++++++++++ .../berack/upo/ai/problem2/TestTrisAi.java | 106 ++++++++++++++++++ 3 files changed, 282 insertions(+), 5 deletions(-) create mode 100644 src/main/java/net/berack/upo/ai/problem2/TrisAi.java create mode 100644 src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java diff --git a/src/main/java/net/berack/upo/ai/problem2/MiniMax.java b/src/main/java/net/berack/upo/ai/problem2/MiniMax.java index b780b3c..59150c8 100644 --- a/src/main/java/net/berack/upo/ai/problem2/MiniMax.java +++ b/src/main/java/net/berack/upo/ai/problem2/MiniMax.java @@ -1,14 +1,103 @@ package net.berack.upo.ai.problem2; +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** + * Algoritmo MiniMax per i giochi a due concorrenti. + * + * @author Berack96 + * @param State la classe degli stati in cui si trova il problema da risolvere + * @param Action la classe di azioni che si possono compiere da uno stato e l'altro + */ public class MiniMax { + private BiFunction transition; + private Function actions; + private Function maxGain; - - public Action nextMax(int depth) { - return null; + /** + * Crea una nuova istanza dell'algoritmo MiniMax per i giochi a due concorrenti. + * Questo costruttore richiede delle particolari funzioni in input: + * - transizione per passare da uno stato all'altro + * - actions per avere una lista di azioni da poter svolgere dato uno stato + * - maxGain per avere un gain dato uno stato (questa funzione DEVE guardare il gain dato solo da max, dato che min sarà -maxGain()) + * + * @param transition la funzione di transizione che permette al gioco di avanzare. + * @param actions le possibili azioni disponibili da uno stato. + * @param maxGain il gain che max ottiene ad essere in quello stato. + */ + public MiniMax(BiFunction transition, Function actions, Function maxGain) { + this.transition = Objects.requireNonNull(transition); + this.actions = Objects.requireNonNull(actions); + this.maxGain = Objects.requireNonNull(maxGain); } - public Action nextMin(int depth) { - return null; + /** + * Restituisce la migliore azione da fare dato lo stato corrente. + * Questo metodo ha un lookahead di mosse che guarderà e viene richiesto come parametro. + * Con valori bassi si potranno ottenere mosse non ottimali globalmente, + * mentre con valori alti si avranno mosse migliori ma il tempo di computazione + * sarà esponenzialmente peggiore. + * Si consigliano valori inferiori a 10. + * + * @param state lo stato corrente + * @param lookahead quante mosse guardare nel futuro + * @return la migliore mossa localmente + */ + public Action next(State state, int lookahead) { + return nextImpl(state, null, lookahead, true).action; + } + + /** + * Implementazione ricorsiva dell'algoritmo minimax + * @param state lo stato corrente + * @param action l'azione fatta per arrivarci + * @param depth la profondità a cui arrivare per controllare le mosse + * @param max se sta giocando max o min + * @return la migliore mossa locale con anche il valore guadagnato + */ + private BestAction nextImpl(State state, Action action, int depth, boolean max) { + var availableMoves = this.actions.apply(state); + BestAction best = null; + + if(depth == 0 || availableMoves.length == 0) { + var gain = this.maxGain.apply(state); + return new BestAction(action, max? gain : -gain); + } + + for(var move: availableMoves) { + var nextState = this.transition.apply(state, move); + var localBest = this.nextImpl(nextState, move, depth-1, !max); + + best = BestAction.getBest(best, localBest, max); + } + + if(action != null) best.action = action; + return best; + } + + + /** + * Classe di appoggio per restituire il gain e l'azione migliore localmente + * Ha anche un metodo statico utile per confrontare due azioni. + */ + private class BestAction { + Action action; + int value; + + BestAction(Action action, int value) { + this.action = action; + this.value = value; + } + + static BestAction getBest(BestAction n1, BestAction n2, boolean max) { + if(n1 == null) return n2; + if(n2 == null) return n1; + + var condition = (max? n1.value >= n2.value : n1.value <= n2.value); + return condition? n1:n2; + } } } diff --git a/src/main/java/net/berack/upo/ai/problem2/TrisAi.java b/src/main/java/net/berack/upo/ai/problem2/TrisAi.java new file mode 100644 index 0000000..48e23c6 --- /dev/null +++ b/src/main/java/net/berack/upo/ai/problem2/TrisAi.java @@ -0,0 +1,82 @@ +package net.berack.upo.ai.problem2; + +import static net.berack.upo.ai.problem2.Tris.Symbol.*; + +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** + * + * @author Berack96 + */ +public class TrisAi { + + public static final Function ACTIONS = tris -> tris.availablePlays(); + public static final BiFunction TRANSITION = (tris, coord) -> new Tris(tris, coord); + public static final Function GAIN = tris -> { + var symbol = tris.getNextPlaySymbol(); + var count = 0; + + // top left + count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,0), tris.get(2,0)); + count += TrisAi.value(symbol, tris.get(0,0), tris.get(0,1), tris.get(0,2)); + + // bottom right + count += TrisAi.value(symbol, tris.get(2,2), tris.get(1,2), tris.get(0,2)); + count += TrisAi.value(symbol, tris.get(2,2), tris.get(2,1), tris.get(2,0)); + + // center diagonals + count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,1), tris.get(2,2)); + count += TrisAi.value(symbol, tris.get(0,2), tris.get(1,1), tris.get(2,0)); + + // center horizontal & vertical + count += TrisAi.value(symbol, tris.get(0,1), tris.get(1,1), tris.get(2,1)); + count += TrisAi.value(symbol, tris.get(1,0), tris.get(1,1), tris.get(1,2)); + + // all calculation are done on the NEXT so for the current invert the value + return -count; + }; + + static int value(Tris.Symbol symbol, Tris.Symbol...values) { + var totE = 0; + var totO = 0; + var totX = 0; + + for(var val : values) switch(val) { + case VALUE_O: totO += 1; break; + case VALUE_X: totX += 1; break; + case EMPTY: totE += 1; break; + } + + var value = 0; + if(totO == 3 || totX == 3) value = 100; + else if(totE + totO == 3 || totE + totX == 3) value = switch(totE) { + case 1 -> 10; + case 2 -> 1; + default -> 0; + }; + + if((totO > totX && symbol == VALUE_X) || totX > totO && symbol == VALUE_O) + return -value; + return value; + } + + + private Tris tris; + private MiniMax minimax; + + public TrisAi(Tris tris) { + this.minimax = new MiniMax<>(TRANSITION, ACTIONS, GAIN); + this.tris = Objects.requireNonNull(tris); + } + + public void playNext() { + this.playNext(2); + } + + public void playNext(int lookahead) { + var action = minimax.next(this.tris, lookahead); + tris.play(action.x, action.y); + } +} diff --git a/src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java b/src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java new file mode 100644 index 0000000..9b1b85f --- /dev/null +++ b/src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java @@ -0,0 +1,106 @@ +package net.berack.upo.ai.problem2; + +import static net.berack.upo.ai.problem2.Tris.Symbol.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class TestTrisAi { + + @Test + public void testValue() { + + assertEquals(0, TrisAi.value(VALUE_O, VALUE_O, EMPTY, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_O, EMPTY, VALUE_O, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_O, EMPTY, VALUE_X, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_X, EMPTY, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_X, VALUE_O, EMPTY)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_O, VALUE_X, EMPTY)); + + assertEquals(0, TrisAi.value(VALUE_X, VALUE_O, EMPTY, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_X, EMPTY, VALUE_O, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_X, EMPTY, VALUE_X, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_X, EMPTY, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_X, VALUE_O, EMPTY)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_O, VALUE_X, EMPTY)); + + assertEquals(0, TrisAi.value(VALUE_O, VALUE_X, VALUE_X, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_X, VALUE_O, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_O, VALUE_X, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_O, VALUE_O, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_O, VALUE_X, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_O, VALUE_X, VALUE_O, VALUE_O)); + + assertEquals(0, TrisAi.value(VALUE_X, VALUE_X, VALUE_X, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_X, VALUE_O, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_O, VALUE_X, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_O, VALUE_O, VALUE_X)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_O, VALUE_X, VALUE_O)); + assertEquals(0, TrisAi.value(VALUE_X, VALUE_X, VALUE_O, VALUE_O)); + + assertEquals(0, TrisAi.value(VALUE_O, EMPTY, EMPTY, EMPTY)); + assertEquals(0, TrisAi.value(VALUE_X, EMPTY, EMPTY, EMPTY)); + + assertEquals(1, TrisAi.value(VALUE_O, VALUE_O, EMPTY, EMPTY)); + assertEquals(1, TrisAi.value(VALUE_O, EMPTY, VALUE_O, EMPTY)); + assertEquals(1, TrisAi.value(VALUE_O, EMPTY, EMPTY, VALUE_O)); + + assertEquals(1, TrisAi.value(VALUE_X, VALUE_X, EMPTY, EMPTY)); + assertEquals(1, TrisAi.value(VALUE_X, EMPTY, VALUE_X, EMPTY)); + assertEquals(1, TrisAi.value(VALUE_X, EMPTY, EMPTY, VALUE_X)); + + assertEquals(10, TrisAi.value(VALUE_O, EMPTY, VALUE_O, VALUE_O)); + assertEquals(10, TrisAi.value(VALUE_O, VALUE_O, EMPTY, VALUE_O)); + assertEquals(10, TrisAi.value(VALUE_O, VALUE_O, VALUE_O, EMPTY)); + + assertEquals(10, TrisAi.value(VALUE_X, EMPTY, VALUE_X, VALUE_X)); + assertEquals(10, TrisAi.value(VALUE_X, VALUE_X, EMPTY, VALUE_X)); + assertEquals(10, TrisAi.value(VALUE_X, VALUE_X, VALUE_X, EMPTY)); + + assertEquals(100, TrisAi.value(VALUE_O, VALUE_O, VALUE_O, VALUE_O)); + assertEquals(100, TrisAi.value(VALUE_X, VALUE_X, VALUE_X, VALUE_X)); + + + assertEquals(-1, TrisAi.value(VALUE_X, VALUE_O, EMPTY, EMPTY)); + assertEquals(-1, TrisAi.value(VALUE_X, EMPTY, VALUE_O, EMPTY)); + assertEquals(-1, TrisAi.value(VALUE_X, EMPTY, EMPTY, VALUE_O)); + + assertEquals(-1, TrisAi.value(VALUE_O, VALUE_X, EMPTY, EMPTY)); + assertEquals(-1, TrisAi.value(VALUE_O, EMPTY, VALUE_X, EMPTY)); + assertEquals(-1, TrisAi.value(VALUE_O, EMPTY, EMPTY, VALUE_X)); + + assertEquals(-10, TrisAi.value(VALUE_X, EMPTY, VALUE_O, VALUE_O)); + assertEquals(-10, TrisAi.value(VALUE_X, VALUE_O, EMPTY, VALUE_O)); + assertEquals(-10, TrisAi.value(VALUE_X, VALUE_O, VALUE_O, EMPTY)); + + assertEquals(-10, TrisAi.value(VALUE_O, EMPTY, VALUE_X, VALUE_X)); + assertEquals(-10, TrisAi.value(VALUE_O, VALUE_X, EMPTY, VALUE_X)); + assertEquals(-10, TrisAi.value(VALUE_O, VALUE_X, VALUE_X, EMPTY)); + + assertEquals(-100, TrisAi.value(VALUE_X, VALUE_O, VALUE_O, VALUE_O)); + assertEquals(-100, TrisAi.value(VALUE_O, VALUE_X, VALUE_X, VALUE_X)); + + } + + + @Test + public void testGain() { + var tris = new Tris(); + tris.play(0, 0); + assertEquals(3, TrisAi.GAIN.apply(tris)); + + tris.play(2, 2); + assertEquals(0, TrisAi.GAIN.apply(tris)); + + tris.play(1, 0); + assertEquals(10, TrisAi.GAIN.apply(tris)); + + tris.play(1, 2); + assertEquals(0, TrisAi.GAIN.apply(tris)); + + tris.play(2, 0); + assertEquals(92, TrisAi.GAIN.apply(tris)); + } + + +}