diff --git a/src/main/java/net/berack/upo/ai/problem2/MiniMax.java b/src/main/java/net/berack/upo/ai/problem2/MiniMax.java index 59150c8..278ba53 100644 --- a/src/main/java/net/berack/upo/ai/problem2/MiniMax.java +++ b/src/main/java/net/berack/upo/ai/problem2/MiniMax.java @@ -8,30 +8,31 @@ import java.util.function.Function; * Algoritmo MiniMax per i giochi a due concorrenti. * * @author Berack96 - * @param State la classe degli stati in cui si trova il problema da risolvere - * @param Action la classe di azioni che si possono compiere da uno stato e l'altro + * @param la classe degli stati in cui si trova il problema da risolvere + * @param la classe di azioni che si possono compiere da uno stato e l'altro + * @param la classe che indica il giocatore, essa serve da dare in input alla funzione di gain */ -public class MiniMax { +public class MiniMax { private BiFunction transition; + private BiFunction playerGain; private Function actions; - private Function maxGain; /** * Crea una nuova istanza dell'algoritmo MiniMax per i giochi a due concorrenti. * Questo costruttore richiede delle particolari funzioni in input: * - transizione per passare da uno stato all'altro * - actions per avere una lista di azioni da poter svolgere dato uno stato - * - maxGain per avere un gain dato uno stato (questa funzione DEVE guardare il gain dato solo da max, dato che min sarà -maxGain()) + * - playerGain per avere un gain dato uno stato (questa funzione DEVE guardare il gain dato solo dal giocatore in input, dato che per l'altro sarà -playerGain()) * * @param transition la funzione di transizione che permette al gioco di avanzare. * @param actions le possibili azioni disponibili da uno stato. - * @param maxGain il gain che max ottiene ad essere in quello stato. + * @param playerGain il gain che il giocatore ottiene ad essere in quello stato. */ - public MiniMax(BiFunction transition, Function actions, Function maxGain) { + public MiniMax(BiFunction transition, Function actions, BiFunction playerGain) { this.transition = Objects.requireNonNull(transition); this.actions = Objects.requireNonNull(actions); - this.maxGain = Objects.requireNonNull(maxGain); + this.playerGain = Objects.requireNonNull(playerGain); } /** @@ -43,61 +44,56 @@ public class MiniMax { * Si consigliano valori inferiori a 10. * * @param state lo stato corrente + * @param player il giocatore che deve fare la mossa * @param lookahead quante mosse guardare nel futuro * @return la migliore mossa localmente */ - public Action next(State state, int lookahead) { - return nextImpl(state, null, lookahead, true).action; + public Action next(State state, int lookahead, Player player) { + if(lookahead < 1) throw new IllegalArgumentException("Lookahead must be at least 1, otherwise it is useless"); + + Action best = null; + var bestGain = Integer.MIN_VALUE; + + for(var action : this.actions.apply(state)) { + var nextState = this.transition.apply(state, action); + var nextValue = this.expectedGain(nextState, player, lookahead-1, false, bestGain); + + if(nextValue > bestGain) { + bestGain = nextValue; + best = action; + } + } + + return best; } /** * Implementazione ricorsiva dell'algoritmo minimax * @param state lo stato corrente - * @param action l'azione fatta per arrivarci * @param depth la profondità a cui arrivare per controllare le mosse - * @param max se sta giocando max o min - * @return la migliore mossa locale con anche il valore guadagnato + * @param player il giocatore che dovrà scegliere la mossa migliore + * @param max se a questa profondità sta giocando max o min + * @return il guadagno maggiore locale */ - private BestAction nextImpl(State state, Action action, int depth, boolean max) { - var availableMoves = this.actions.apply(state); - BestAction best = null; + private int expectedGain(State state, Player player, int depth, boolean max, int currentBest) { + var value = this.playerGain.apply(state, player); + if(depth == 0) + return value; - if(depth == 0 || availableMoves.length == 0) { - var gain = this.maxGain.apply(state); - return new BestAction(action, max? gain : -gain); + var actions = this.actions.apply(state); + if(actions.length == 0) + return value; + + value = max? Integer.MIN_VALUE:Integer.MAX_VALUE; + + for(var action : actions) { + var nextState = this.transition.apply(state, action); + var nextValue = this.expectedGain(nextState, player, depth-1, !max, currentBest); + + var condition = (max? value > nextValue : value < nextValue); + value = condition? value:nextValue; } - for(var move: availableMoves) { - var nextState = this.transition.apply(state, move); - var localBest = this.nextImpl(nextState, move, depth-1, !max); - - best = BestAction.getBest(best, localBest, max); - } - - if(action != null) best.action = action; - return best; - } - - - /** - * Classe di appoggio per restituire il gain e l'azione migliore localmente - * Ha anche un metodo statico utile per confrontare due azioni. - */ - private class BestAction { - Action action; - int value; - - BestAction(Action action, int value) { - this.action = action; - this.value = value; - } - - static BestAction getBest(BestAction n1, BestAction n2, boolean max) { - if(n1 == null) return n2; - if(n2 == null) return n1; - - var condition = (max? n1.value >= n2.value : n1.value <= n2.value); - return condition? n1:n2; - } + return value; } } diff --git a/src/main/java/net/berack/upo/ai/problem2/Tris.java b/src/main/java/net/berack/upo/ai/problem2/Tris.java index d602611..d1922ae 100644 --- a/src/main/java/net/berack/upo/ai/problem2/Tris.java +++ b/src/main/java/net/berack/upo/ai/problem2/Tris.java @@ -21,7 +21,7 @@ public class Tris implements Iterable { public static class Coordinate { public final int x; public final int y; - private Coordinate(int x, int y) { + Coordinate(int x, int y) { this.x = x; this.y = y; } @@ -144,6 +144,21 @@ public class Tris implements Iterable { return res; } + /** + * Indica se il gioco è finito. + * Il gioco finisce se si ha un vincitore o se non ci sono più caselle vuote. + * + * @return vero se iol gioco è finito + */ + public boolean isFinished() { + if(haveWinner() != EMPTY) return true; + + for(var symbol : this.tris) + if(symbol == EMPTY) + return false; + return true; + } + /** * Indica se si ha un vincitore e restituisce chi ha vinto. * @return EMPTY se non c'è ancora un vincitore, altrimenti restituisci il vincitore @@ -208,6 +223,12 @@ public class Tris implements Iterable { return builder.toString(); } + @Override + public boolean equals(Object obj) { + if(!obj.getClass().isInstance(this)) return false; + return Arrays.equals(this.tris, ((Tris) obj).tris); + } + @Override public Iterator iterator() { return new Iterator() { diff --git a/src/main/java/net/berack/upo/ai/problem2/TrisAi.java b/src/main/java/net/berack/upo/ai/problem2/TrisAi.java index 48e23c6..0b91b3f 100644 --- a/src/main/java/net/berack/upo/ai/problem2/TrisAi.java +++ b/src/main/java/net/berack/upo/ai/problem2/TrisAi.java @@ -14,28 +14,26 @@ public class TrisAi { public static final Function ACTIONS = tris -> tris.availablePlays(); public static final BiFunction TRANSITION = (tris, coord) -> new Tris(tris, coord); - public static final Function GAIN = tris -> { - var symbol = tris.getNextPlaySymbol(); + public static final BiFunction GAIN = (tris, player) -> { var count = 0; // top left - count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,0), tris.get(2,0)); - count += TrisAi.value(symbol, tris.get(0,0), tris.get(0,1), tris.get(0,2)); + count += TrisAi.value(player, tris.get(0,0), tris.get(1,0), tris.get(2,0)); + count += TrisAi.value(player, tris.get(0,0), tris.get(0,1), tris.get(0,2)); // bottom right - count += TrisAi.value(symbol, tris.get(2,2), tris.get(1,2), tris.get(0,2)); - count += TrisAi.value(symbol, tris.get(2,2), tris.get(2,1), tris.get(2,0)); + count += TrisAi.value(player, tris.get(2,2), tris.get(1,2), tris.get(0,2)); + count += TrisAi.value(player, tris.get(2,2), tris.get(2,1), tris.get(2,0)); // center diagonals - count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,1), tris.get(2,2)); - count += TrisAi.value(symbol, tris.get(0,2), tris.get(1,1), tris.get(2,0)); + count += TrisAi.value(player, tris.get(0,0), tris.get(1,1), tris.get(2,2)); + count += TrisAi.value(player, tris.get(0,2), tris.get(1,1), tris.get(2,0)); // center horizontal & vertical - count += TrisAi.value(symbol, tris.get(0,1), tris.get(1,1), tris.get(2,1)); - count += TrisAi.value(symbol, tris.get(1,0), tris.get(1,1), tris.get(1,2)); + count += TrisAi.value(player, tris.get(0,1), tris.get(1,1), tris.get(2,1)); + count += TrisAi.value(player, tris.get(1,0), tris.get(1,1), tris.get(1,2)); - // all calculation are done on the NEXT so for the current invert the value - return -count; + return count; }; static int value(Tris.Symbol symbol, Tris.Symbol...values) { @@ -64,7 +62,7 @@ public class TrisAi { private Tris tris; - private MiniMax minimax; + private MiniMax minimax; public TrisAi(Tris tris) { this.minimax = new MiniMax<>(TRANSITION, ACTIONS, GAIN); @@ -76,7 +74,8 @@ public class TrisAi { } public void playNext(int lookahead) { - var action = minimax.next(this.tris, lookahead); + var myself = tris.getNextPlaySymbol(); + var action = minimax.next(this.tris, lookahead, myself); tris.play(action.x, action.y); } } diff --git a/src/test/java/net/berack/upo/ai/problem2/TestTris.java b/src/test/java/net/berack/upo/ai/problem2/TestTris.java index affaf59..c4c7c33 100644 --- a/src/test/java/net/berack/upo/ai/problem2/TestTris.java +++ b/src/test/java/net/berack/upo/ai/problem2/TestTris.java @@ -128,47 +128,98 @@ public class TestTris { // horizontal 1 line X var tris = new Tris(); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,0); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,1); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(0,0); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,2); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(2,0); - assertTrue(tris.haveWinner() == VALUE_X); + assertEquals(VALUE_X, tris.haveWinner()); + assertTrue(tris.isFinished()); // diagonal \ O tris = new Tris(); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(2,1); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,1); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(2,0); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(2,2); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,2); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(0,0); - assertTrue(tris.haveWinner() == VALUE_O); + assertEquals(VALUE_O, tris.haveWinner()); + assertTrue(tris.isFinished()); // vertical 2 column X tris = new Tris(); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,0); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(0,2); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,1); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(0,1); - assertTrue(tris.haveWinner() == EMPTY); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); tris.play(1,2); - assertTrue(tris.haveWinner() == VALUE_X); + assertEquals(VALUE_X, tris.haveWinner()); + assertTrue(tris.isFinished()); + + // No winner + tris = new Tris(); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(0,0); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(1,0); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(2,0); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(1,1); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(0,1); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(0,2); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(2,1); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(2,2); + assertEquals(EMPTY, tris.haveWinner()); + assertFalse(tris.isFinished()); + tris.play(1,2); + assertEquals(EMPTY, tris.haveWinner()); + assertTrue(tris.isFinished()); } } diff --git a/src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java b/src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java index 9b1b85f..473c623 100644 --- a/src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java +++ b/src/test/java/net/berack/upo/ai/problem2/TestTrisAi.java @@ -87,20 +87,48 @@ public class TestTrisAi { public void testGain() { var tris = new Tris(); tris.play(0, 0); - assertEquals(3, TrisAi.GAIN.apply(tris)); + assertEquals(3, TrisAi.GAIN.apply(tris, VALUE_X)); tris.play(2, 2); - assertEquals(0, TrisAi.GAIN.apply(tris)); + assertEquals(0, TrisAi.GAIN.apply(tris, VALUE_O)); tris.play(1, 0); - assertEquals(10, TrisAi.GAIN.apply(tris)); + assertEquals(10, TrisAi.GAIN.apply(tris, VALUE_X)); tris.play(1, 2); - assertEquals(0, TrisAi.GAIN.apply(tris)); + assertEquals(0, TrisAi.GAIN.apply(tris, VALUE_O)); tris.play(2, 0); - assertEquals(92, TrisAi.GAIN.apply(tris)); + assertEquals(92, TrisAi.GAIN.apply(tris, VALUE_X)); } + @Test + public void testNextEz() { + var tris = new Tris(); + var ai = new TrisAi(tris); + + tris.play(0,0); + tris.play(2,2); + tris.play(1,0); + + // block + var nx = new Tris(tris, new Tris.Coordinate(2,0)); + ai.playNext(); + assertEquals(nx, tris); + + // block 2 + nx = new Tris(tris, new Tris.Coordinate(2,1)); + ai.playNext(); + assertEquals(nx, tris); + + tris.play(1,1); + tris.play(0,1); + + // win + nx = new Tris(tris, new Tris.Coordinate(0,2)); + ai.playNext(); + assertEquals(nx, tris); + } + }