TrisAi now works

- fixed MiniMax algorithm
- refactored MiniMax class
- added Tris check if finished & equals
- added edge cases to trisTest
- added test cases to trisTestAI
This commit is contained in:
2023-12-14 15:54:42 +01:00
parent 5d028fe22b
commit a310a013c4
5 changed files with 185 additions and 90 deletions

View File

@@ -8,30 +8,31 @@ import java.util.function.Function;
* Algoritmo MiniMax per i giochi a due concorrenti. * Algoritmo MiniMax per i giochi a due concorrenti.
* *
* @author Berack96 * @author Berack96
* @param State la classe degli stati in cui si trova il problema da risolvere * @param <State> la classe degli stati in cui si trova il problema da risolvere
* @param Action la classe di azioni che si possono compiere da uno stato e l'altro * @param <Action> la classe di azioni che si possono compiere da uno stato e l'altro
* @param <Player> la classe che indica il giocatore, essa serve da dare in input alla funzione di gain
*/ */
public class MiniMax<State, Action> { public class MiniMax<State, Action, Player> {
private BiFunction<State, Action, State> transition; private BiFunction<State, Action, State> transition;
private BiFunction<State, Player, Integer> playerGain;
private Function<State, Action[]> actions; private Function<State, Action[]> actions;
private Function<State, Integer> maxGain;
/** /**
* Crea una nuova istanza dell'algoritmo MiniMax per i giochi a due concorrenti. * Crea una nuova istanza dell'algoritmo MiniMax per i giochi a due concorrenti.
* Questo costruttore richiede delle particolari funzioni in input: * Questo costruttore richiede delle particolari funzioni in input:
* - transizione per passare da uno stato all'altro * - transizione per passare da uno stato all'altro
* - actions per avere una lista di azioni da poter svolgere dato uno stato * - actions per avere una lista di azioni da poter svolgere dato uno stato
* - maxGain per avere un gain dato uno stato (questa funzione DEVE guardare il gain dato solo da max, dato che min sarà -maxGain()) * - playerGain per avere un gain dato uno stato (questa funzione DEVE guardare il gain dato solo dal giocatore in input, dato che per l'altro sarà -playerGain())
* *
* @param transition la funzione di transizione che permette al gioco di avanzare. * @param transition la funzione di transizione che permette al gioco di avanzare.
* @param actions le possibili azioni disponibili da uno stato. * @param actions le possibili azioni disponibili da uno stato.
* @param maxGain il gain che max ottiene ad essere in quello stato. * @param playerGain il gain che il giocatore ottiene ad essere in quello stato.
*/ */
public MiniMax(BiFunction<State, Action, State> transition, Function<State, Action[]> actions, Function<State, Integer> maxGain) { public MiniMax(BiFunction<State, Action, State> transition, Function<State, Action[]> actions, BiFunction<State, Player, Integer> playerGain) {
this.transition = Objects.requireNonNull(transition); this.transition = Objects.requireNonNull(transition);
this.actions = Objects.requireNonNull(actions); this.actions = Objects.requireNonNull(actions);
this.maxGain = Objects.requireNonNull(maxGain); this.playerGain = Objects.requireNonNull(playerGain);
} }
/** /**
@@ -43,61 +44,56 @@ public class MiniMax<State, Action> {
* Si consigliano valori inferiori a 10. * Si consigliano valori inferiori a 10.
* *
* @param state lo stato corrente * @param state lo stato corrente
* @param player il giocatore che deve fare la mossa
* @param lookahead quante mosse guardare nel futuro * @param lookahead quante mosse guardare nel futuro
* @return la migliore mossa localmente * @return la migliore mossa localmente
*/ */
public Action next(State state, int lookahead) { public Action next(State state, int lookahead, Player player) {
return nextImpl(state, null, lookahead, true).action; if(lookahead < 1) throw new IllegalArgumentException("Lookahead must be at least 1, otherwise it is useless");
Action best = null;
var bestGain = Integer.MIN_VALUE;
for(var action : this.actions.apply(state)) {
var nextState = this.transition.apply(state, action);
var nextValue = this.expectedGain(nextState, player, lookahead-1, false, bestGain);
if(nextValue > bestGain) {
bestGain = nextValue;
best = action;
}
}
return best;
} }
/** /**
* Implementazione ricorsiva dell'algoritmo minimax * Implementazione ricorsiva dell'algoritmo minimax
* @param state lo stato corrente * @param state lo stato corrente
* @param action l'azione fatta per arrivarci
* @param depth la profondità a cui arrivare per controllare le mosse * @param depth la profondità a cui arrivare per controllare le mosse
* @param max se sta giocando max o min * @param player il giocatore che dovrà scegliere la mossa migliore
* @return la migliore mossa locale con anche il valore guadagnato * @param max se a questa profondità sta giocando max o min
* @return il guadagno maggiore locale
*/ */
private BestAction nextImpl(State state, Action action, int depth, boolean max) { private int expectedGain(State state, Player player, int depth, boolean max, int currentBest) {
var availableMoves = this.actions.apply(state); var value = this.playerGain.apply(state, player);
BestAction best = null; if(depth == 0)
return value;
if(depth == 0 || availableMoves.length == 0) { var actions = this.actions.apply(state);
var gain = this.maxGain.apply(state); if(actions.length == 0)
return new BestAction(action, max? gain : -gain); return value;
value = max? Integer.MIN_VALUE:Integer.MAX_VALUE;
for(var action : actions) {
var nextState = this.transition.apply(state, action);
var nextValue = this.expectedGain(nextState, player, depth-1, !max, currentBest);
var condition = (max? value > nextValue : value < nextValue);
value = condition? value:nextValue;
} }
for(var move: availableMoves) { return value;
var nextState = this.transition.apply(state, move);
var localBest = this.nextImpl(nextState, move, depth-1, !max);
best = BestAction.getBest(best, localBest, max);
}
if(action != null) best.action = action;
return best;
}
/**
* Classe di appoggio per restituire il gain e l'azione migliore localmente
* Ha anche un metodo statico utile per confrontare due azioni.
*/
private class BestAction {
Action action;
int value;
BestAction(Action action, int value) {
this.action = action;
this.value = value;
}
static BestAction getBest(BestAction n1, BestAction n2, boolean max) {
if(n1 == null) return n2;
if(n2 == null) return n1;
var condition = (max? n1.value >= n2.value : n1.value <= n2.value);
return condition? n1:n2;
}
} }
} }

View File

@@ -21,7 +21,7 @@ public class Tris implements Iterable<Tris.Symbol> {
public static class Coordinate { public static class Coordinate {
public final int x; public final int x;
public final int y; public final int y;
private Coordinate(int x, int y) { Coordinate(int x, int y) {
this.x = x; this.x = x;
this.y = y; this.y = y;
} }
@@ -144,6 +144,21 @@ public class Tris implements Iterable<Tris.Symbol> {
return res; return res;
} }
/**
* Indica se il gioco è finito.
* Il gioco finisce se si ha un vincitore o se non ci sono più caselle vuote.
*
* @return vero se iol gioco è finito
*/
public boolean isFinished() {
if(haveWinner() != EMPTY) return true;
for(var symbol : this.tris)
if(symbol == EMPTY)
return false;
return true;
}
/** /**
* Indica se si ha un vincitore e restituisce chi ha vinto. * Indica se si ha un vincitore e restituisce chi ha vinto.
* @return EMPTY se non c'è ancora un vincitore, altrimenti restituisci il vincitore * @return EMPTY se non c'è ancora un vincitore, altrimenti restituisci il vincitore
@@ -208,6 +223,12 @@ public class Tris implements Iterable<Tris.Symbol> {
return builder.toString(); return builder.toString();
} }
@Override
public boolean equals(Object obj) {
if(!obj.getClass().isInstance(this)) return false;
return Arrays.equals(this.tris, ((Tris) obj).tris);
}
@Override @Override
public Iterator<Symbol> iterator() { public Iterator<Symbol> iterator() {
return new Iterator<Symbol>() { return new Iterator<Symbol>() {

View File

@@ -14,28 +14,26 @@ public class TrisAi {
public static final Function<Tris, Tris.Coordinate[]> ACTIONS = tris -> tris.availablePlays(); public static final Function<Tris, Tris.Coordinate[]> ACTIONS = tris -> tris.availablePlays();
public static final BiFunction<Tris, Tris.Coordinate, Tris> TRANSITION = (tris, coord) -> new Tris(tris, coord); public static final BiFunction<Tris, Tris.Coordinate, Tris> TRANSITION = (tris, coord) -> new Tris(tris, coord);
public static final Function<Tris, Integer> GAIN = tris -> { public static final BiFunction<Tris, Tris.Symbol, Integer> GAIN = (tris, player) -> {
var symbol = tris.getNextPlaySymbol();
var count = 0; var count = 0;
// top left // top left
count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,0), tris.get(2,0)); count += TrisAi.value(player, tris.get(0,0), tris.get(1,0), tris.get(2,0));
count += TrisAi.value(symbol, tris.get(0,0), tris.get(0,1), tris.get(0,2)); count += TrisAi.value(player, tris.get(0,0), tris.get(0,1), tris.get(0,2));
// bottom right // bottom right
count += TrisAi.value(symbol, tris.get(2,2), tris.get(1,2), tris.get(0,2)); count += TrisAi.value(player, tris.get(2,2), tris.get(1,2), tris.get(0,2));
count += TrisAi.value(symbol, tris.get(2,2), tris.get(2,1), tris.get(2,0)); count += TrisAi.value(player, tris.get(2,2), tris.get(2,1), tris.get(2,0));
// center diagonals // center diagonals
count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,1), tris.get(2,2)); count += TrisAi.value(player, tris.get(0,0), tris.get(1,1), tris.get(2,2));
count += TrisAi.value(symbol, tris.get(0,2), tris.get(1,1), tris.get(2,0)); count += TrisAi.value(player, tris.get(0,2), tris.get(1,1), tris.get(2,0));
// center horizontal & vertical // center horizontal & vertical
count += TrisAi.value(symbol, tris.get(0,1), tris.get(1,1), tris.get(2,1)); count += TrisAi.value(player, tris.get(0,1), tris.get(1,1), tris.get(2,1));
count += TrisAi.value(symbol, tris.get(1,0), tris.get(1,1), tris.get(1,2)); count += TrisAi.value(player, tris.get(1,0), tris.get(1,1), tris.get(1,2));
// all calculation are done on the NEXT so for the current invert the value return count;
return -count;
}; };
static int value(Tris.Symbol symbol, Tris.Symbol...values) { static int value(Tris.Symbol symbol, Tris.Symbol...values) {
@@ -64,7 +62,7 @@ public class TrisAi {
private Tris tris; private Tris tris;
private MiniMax<Tris, Tris.Coordinate> minimax; private MiniMax<Tris, Tris.Coordinate, Tris.Symbol> minimax;
public TrisAi(Tris tris) { public TrisAi(Tris tris) {
this.minimax = new MiniMax<>(TRANSITION, ACTIONS, GAIN); this.minimax = new MiniMax<>(TRANSITION, ACTIONS, GAIN);
@@ -76,7 +74,8 @@ public class TrisAi {
} }
public void playNext(int lookahead) { public void playNext(int lookahead) {
var action = minimax.next(this.tris, lookahead); var myself = tris.getNextPlaySymbol();
var action = minimax.next(this.tris, lookahead, myself);
tris.play(action.x, action.y); tris.play(action.x, action.y);
} }
} }

View File

@@ -128,47 +128,98 @@ public class TestTris {
// horizontal 1 line X // horizontal 1 line X
var tris = new Tris(); var tris = new Tris();
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,0); tris.play(1,0);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1); tris.play(1,1);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,0); tris.play(0,0);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2); tris.play(1,2);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,0); tris.play(2,0);
assertTrue(tris.haveWinner() == VALUE_X); assertEquals(VALUE_X, tris.haveWinner());
assertTrue(tris.isFinished());
// diagonal \ O // diagonal \ O
tris = new Tris(); tris = new Tris();
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,1); tris.play(2,1);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1); tris.play(1,1);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,0); tris.play(2,0);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,2); tris.play(2,2);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2); tris.play(1,2);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,0); tris.play(0,0);
assertTrue(tris.haveWinner() == VALUE_O); assertEquals(VALUE_O, tris.haveWinner());
assertTrue(tris.isFinished());
// vertical 2 column X // vertical 2 column X
tris = new Tris(); tris = new Tris();
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,0); tris.play(1,0);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,2); tris.play(0,2);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1); tris.play(1,1);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,1); tris.play(0,1);
assertTrue(tris.haveWinner() == EMPTY); assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2); tris.play(1,2);
assertTrue(tris.haveWinner() == VALUE_X); assertEquals(VALUE_X, tris.haveWinner());
assertTrue(tris.isFinished());
// No winner
tris = new Tris();
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,0);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,0);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,0);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,1);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,2);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,1);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,2);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2);
assertEquals(EMPTY, tris.haveWinner());
assertTrue(tris.isFinished());
} }
} }

View File

@@ -87,20 +87,48 @@ public class TestTrisAi {
public void testGain() { public void testGain() {
var tris = new Tris(); var tris = new Tris();
tris.play(0, 0); tris.play(0, 0);
assertEquals(3, TrisAi.GAIN.apply(tris)); assertEquals(3, TrisAi.GAIN.apply(tris, VALUE_X));
tris.play(2, 2); tris.play(2, 2);
assertEquals(0, TrisAi.GAIN.apply(tris)); assertEquals(0, TrisAi.GAIN.apply(tris, VALUE_O));
tris.play(1, 0); tris.play(1, 0);
assertEquals(10, TrisAi.GAIN.apply(tris)); assertEquals(10, TrisAi.GAIN.apply(tris, VALUE_X));
tris.play(1, 2); tris.play(1, 2);
assertEquals(0, TrisAi.GAIN.apply(tris)); assertEquals(0, TrisAi.GAIN.apply(tris, VALUE_O));
tris.play(2, 0); tris.play(2, 0);
assertEquals(92, TrisAi.GAIN.apply(tris)); assertEquals(92, TrisAi.GAIN.apply(tris, VALUE_X));
} }
@Test
public void testNextEz() {
var tris = new Tris();
var ai = new TrisAi(tris);
tris.play(0,0);
tris.play(2,2);
tris.play(1,0);
// block
var nx = new Tris(tris, new Tris.Coordinate(2,0));
ai.playNext();
assertEquals(nx, tris);
// block 2
nx = new Tris(tris, new Tris.Coordinate(2,1));
ai.playNext();
assertEquals(nx, tris);
tris.play(1,1);
tris.play(0,1);
// win
nx = new Tris(tris, new Tris.Coordinate(0,2));
ai.playNext();
assertEquals(nx, tris);
}
} }