TrisAi now works

- fixed MiniMax algorithm
- refactored MiniMax class
- added Tris check if finished & equals
- added edge cases to trisTest
- added test cases to trisTestAI
This commit is contained in:
2023-12-14 15:54:42 +01:00
parent 5d028fe22b
commit a310a013c4
5 changed files with 185 additions and 90 deletions

View File

@@ -8,30 +8,31 @@ import java.util.function.Function;
* Algoritmo MiniMax per i giochi a due concorrenti.
*
* @author Berack96
* @param State la classe degli stati in cui si trova il problema da risolvere
* @param Action la classe di azioni che si possono compiere da uno stato e l'altro
* @param <State> la classe degli stati in cui si trova il problema da risolvere
* @param <Action> la classe di azioni che si possono compiere da uno stato e l'altro
* @param <Player> la classe che indica il giocatore, essa serve da dare in input alla funzione di gain
*/
public class MiniMax<State, Action> {
public class MiniMax<State, Action, Player> {
private BiFunction<State, Action, State> transition;
private BiFunction<State, Player, Integer> playerGain;
private Function<State, Action[]> actions;
private Function<State, Integer> maxGain;
/**
* Crea una nuova istanza dell'algoritmo MiniMax per i giochi a due concorrenti.
* Questo costruttore richiede delle particolari funzioni in input:
* - transizione per passare da uno stato all'altro
* - actions per avere una lista di azioni da poter svolgere dato uno stato
* - maxGain per avere un gain dato uno stato (questa funzione DEVE guardare il gain dato solo da max, dato che min sarà -maxGain())
* - playerGain per avere un gain dato uno stato (questa funzione DEVE guardare il gain dato solo dal giocatore in input, dato che per l'altro sarà -playerGain())
*
* @param transition la funzione di transizione che permette al gioco di avanzare.
* @param actions le possibili azioni disponibili da uno stato.
* @param maxGain il gain che max ottiene ad essere in quello stato.
* @param playerGain il gain che il giocatore ottiene ad essere in quello stato.
*/
public MiniMax(BiFunction<State, Action, State> transition, Function<State, Action[]> actions, Function<State, Integer> maxGain) {
public MiniMax(BiFunction<State, Action, State> transition, Function<State, Action[]> actions, BiFunction<State, Player, Integer> playerGain) {
this.transition = Objects.requireNonNull(transition);
this.actions = Objects.requireNonNull(actions);
this.maxGain = Objects.requireNonNull(maxGain);
this.playerGain = Objects.requireNonNull(playerGain);
}
/**
@@ -43,61 +44,56 @@ public class MiniMax<State, Action> {
* Si consigliano valori inferiori a 10.
*
* @param state lo stato corrente
* @param player il giocatore che deve fare la mossa
* @param lookahead quante mosse guardare nel futuro
* @return la migliore mossa localmente
*/
public Action next(State state, int lookahead) {
return nextImpl(state, null, lookahead, true).action;
public Action next(State state, int lookahead, Player player) {
if(lookahead < 1) throw new IllegalArgumentException("Lookahead must be at least 1, otherwise it is useless");
Action best = null;
var bestGain = Integer.MIN_VALUE;
for(var action : this.actions.apply(state)) {
var nextState = this.transition.apply(state, action);
var nextValue = this.expectedGain(nextState, player, lookahead-1, false, bestGain);
if(nextValue > bestGain) {
bestGain = nextValue;
best = action;
}
}
return best;
}
/**
* Implementazione ricorsiva dell'algoritmo minimax
* @param state lo stato corrente
* @param action l'azione fatta per arrivarci
* @param depth la profondità a cui arrivare per controllare le mosse
* @param max se sta giocando max o min
* @return la migliore mossa locale con anche il valore guadagnato
* @param player il giocatore che dovrà scegliere la mossa migliore
* @param max se a questa profondità sta giocando max o min
* @return il guadagno maggiore locale
*/
private BestAction nextImpl(State state, Action action, int depth, boolean max) {
var availableMoves = this.actions.apply(state);
BestAction best = null;
private int expectedGain(State state, Player player, int depth, boolean max, int currentBest) {
var value = this.playerGain.apply(state, player);
if(depth == 0)
return value;
if(depth == 0 || availableMoves.length == 0) {
var gain = this.maxGain.apply(state);
return new BestAction(action, max? gain : -gain);
var actions = this.actions.apply(state);
if(actions.length == 0)
return value;
value = max? Integer.MIN_VALUE:Integer.MAX_VALUE;
for(var action : actions) {
var nextState = this.transition.apply(state, action);
var nextValue = this.expectedGain(nextState, player, depth-1, !max, currentBest);
var condition = (max? value > nextValue : value < nextValue);
value = condition? value:nextValue;
}
for(var move: availableMoves) {
var nextState = this.transition.apply(state, move);
var localBest = this.nextImpl(nextState, move, depth-1, !max);
best = BestAction.getBest(best, localBest, max);
}
if(action != null) best.action = action;
return best;
}
/**
* Classe di appoggio per restituire il gain e l'azione migliore localmente
* Ha anche un metodo statico utile per confrontare due azioni.
*/
private class BestAction {
Action action;
int value;
BestAction(Action action, int value) {
this.action = action;
this.value = value;
}
static BestAction getBest(BestAction n1, BestAction n2, boolean max) {
if(n1 == null) return n2;
if(n2 == null) return n1;
var condition = (max? n1.value >= n2.value : n1.value <= n2.value);
return condition? n1:n2;
}
return value;
}
}

View File

@@ -21,7 +21,7 @@ public class Tris implements Iterable<Tris.Symbol> {
public static class Coordinate {
public final int x;
public final int y;
private Coordinate(int x, int y) {
Coordinate(int x, int y) {
this.x = x;
this.y = y;
}
@@ -144,6 +144,21 @@ public class Tris implements Iterable<Tris.Symbol> {
return res;
}
/**
* Indica se il gioco è finito.
* Il gioco finisce se si ha un vincitore o se non ci sono più caselle vuote.
*
* @return vero se iol gioco è finito
*/
public boolean isFinished() {
if(haveWinner() != EMPTY) return true;
for(var symbol : this.tris)
if(symbol == EMPTY)
return false;
return true;
}
/**
* Indica se si ha un vincitore e restituisce chi ha vinto.
* @return EMPTY se non c'è ancora un vincitore, altrimenti restituisci il vincitore
@@ -208,6 +223,12 @@ public class Tris implements Iterable<Tris.Symbol> {
return builder.toString();
}
@Override
public boolean equals(Object obj) {
if(!obj.getClass().isInstance(this)) return false;
return Arrays.equals(this.tris, ((Tris) obj).tris);
}
@Override
public Iterator<Symbol> iterator() {
return new Iterator<Symbol>() {

View File

@@ -14,28 +14,26 @@ public class TrisAi {
public static final Function<Tris, Tris.Coordinate[]> ACTIONS = tris -> tris.availablePlays();
public static final BiFunction<Tris, Tris.Coordinate, Tris> TRANSITION = (tris, coord) -> new Tris(tris, coord);
public static final Function<Tris, Integer> GAIN = tris -> {
var symbol = tris.getNextPlaySymbol();
public static final BiFunction<Tris, Tris.Symbol, Integer> GAIN = (tris, player) -> {
var count = 0;
// top left
count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,0), tris.get(2,0));
count += TrisAi.value(symbol, tris.get(0,0), tris.get(0,1), tris.get(0,2));
count += TrisAi.value(player, tris.get(0,0), tris.get(1,0), tris.get(2,0));
count += TrisAi.value(player, tris.get(0,0), tris.get(0,1), tris.get(0,2));
// bottom right
count += TrisAi.value(symbol, tris.get(2,2), tris.get(1,2), tris.get(0,2));
count += TrisAi.value(symbol, tris.get(2,2), tris.get(2,1), tris.get(2,0));
count += TrisAi.value(player, tris.get(2,2), tris.get(1,2), tris.get(0,2));
count += TrisAi.value(player, tris.get(2,2), tris.get(2,1), tris.get(2,0));
// center diagonals
count += TrisAi.value(symbol, tris.get(0,0), tris.get(1,1), tris.get(2,2));
count += TrisAi.value(symbol, tris.get(0,2), tris.get(1,1), tris.get(2,0));
count += TrisAi.value(player, tris.get(0,0), tris.get(1,1), tris.get(2,2));
count += TrisAi.value(player, tris.get(0,2), tris.get(1,1), tris.get(2,0));
// center horizontal & vertical
count += TrisAi.value(symbol, tris.get(0,1), tris.get(1,1), tris.get(2,1));
count += TrisAi.value(symbol, tris.get(1,0), tris.get(1,1), tris.get(1,2));
count += TrisAi.value(player, tris.get(0,1), tris.get(1,1), tris.get(2,1));
count += TrisAi.value(player, tris.get(1,0), tris.get(1,1), tris.get(1,2));
// all calculation are done on the NEXT so for the current invert the value
return -count;
return count;
};
static int value(Tris.Symbol symbol, Tris.Symbol...values) {
@@ -64,7 +62,7 @@ public class TrisAi {
private Tris tris;
private MiniMax<Tris, Tris.Coordinate> minimax;
private MiniMax<Tris, Tris.Coordinate, Tris.Symbol> minimax;
public TrisAi(Tris tris) {
this.minimax = new MiniMax<>(TRANSITION, ACTIONS, GAIN);
@@ -76,7 +74,8 @@ public class TrisAi {
}
public void playNext(int lookahead) {
var action = minimax.next(this.tris, lookahead);
var myself = tris.getNextPlaySymbol();
var action = minimax.next(this.tris, lookahead, myself);
tris.play(action.x, action.y);
}
}

View File

@@ -128,47 +128,98 @@ public class TestTris {
// horizontal 1 line X
var tris = new Tris();
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,0);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,0);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,0);
assertTrue(tris.haveWinner() == VALUE_X);
assertEquals(VALUE_X, tris.haveWinner());
assertTrue(tris.isFinished());
// diagonal \ O
tris = new Tris();
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,1);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,0);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,2);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,0);
assertTrue(tris.haveWinner() == VALUE_O);
assertEquals(VALUE_O, tris.haveWinner());
assertTrue(tris.isFinished());
// vertical 2 column X
tris = new Tris();
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,0);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,2);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,1);
assertTrue(tris.haveWinner() == EMPTY);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2);
assertTrue(tris.haveWinner() == VALUE_X);
assertEquals(VALUE_X, tris.haveWinner());
assertTrue(tris.isFinished());
// No winner
tris = new Tris();
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,0);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,0);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,0);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,1);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,1);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(0,2);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,1);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(2,2);
assertEquals(EMPTY, tris.haveWinner());
assertFalse(tris.isFinished());
tris.play(1,2);
assertEquals(EMPTY, tris.haveWinner());
assertTrue(tris.isFinished());
}
}

View File

@@ -87,20 +87,48 @@ public class TestTrisAi {
public void testGain() {
var tris = new Tris();
tris.play(0, 0);
assertEquals(3, TrisAi.GAIN.apply(tris));
assertEquals(3, TrisAi.GAIN.apply(tris, VALUE_X));
tris.play(2, 2);
assertEquals(0, TrisAi.GAIN.apply(tris));
assertEquals(0, TrisAi.GAIN.apply(tris, VALUE_O));
tris.play(1, 0);
assertEquals(10, TrisAi.GAIN.apply(tris));
assertEquals(10, TrisAi.GAIN.apply(tris, VALUE_X));
tris.play(1, 2);
assertEquals(0, TrisAi.GAIN.apply(tris));
assertEquals(0, TrisAi.GAIN.apply(tris, VALUE_O));
tris.play(2, 0);
assertEquals(92, TrisAi.GAIN.apply(tris));
assertEquals(92, TrisAi.GAIN.apply(tris, VALUE_X));
}
@Test
public void testNextEz() {
var tris = new Tris();
var ai = new TrisAi(tris);
tris.play(0,0);
tris.play(2,2);
tris.play(1,0);
// block
var nx = new Tris(tris, new Tris.Coordinate(2,0));
ai.playNext();
assertEquals(nx, tris);
// block 2
nx = new Tris(tris, new Tris.Coordinate(2,1));
ai.playNext();
assertEquals(nx, tris);
tris.play(1,1);
tris.play(0,1);
// win
nx = new Tris(tris, new Tris.Coordinate(0,2));
ai.playNext();
assertEquals(nx, tris);
}
}