001    package aima.probability.decision.cellworld;
002    
003    import java.util.ArrayList;
004    import java.util.Arrays;
005    import java.util.Hashtable;
006    import java.util.List;
007    
008    import aima.probability.Randomizer;
009    import aima.probability.decision.MDP;
010    import aima.probability.decision.MDPPerception;
011    import aima.probability.decision.MDPRewardFunction;
012    import aima.probability.decision.MDPSource;
013    import aima.probability.decision.MDPTransitionModel;
014    import aima.util.Pair;
015    
016    /**
017     * @author Ravi Mohan
018     * 
019     */
020    
021    public class CellWorld implements MDPSource<CellWorldPosition, String> {
022            public static final String LEFT = "left";
023    
024            public static final String RIGHT = "right";
025    
026            public static final String UP = "up";
027    
028            public static final String DOWN = "down";
029    
030            public static final String NO_OP = "no_op";
031    
032            List<Cell> blockedCells, allCells;
033    
034            private int numberOfRows;
035    
036            private int numberOfColumns;
037    
038            private List<Cell> terminalStates;
039    
040            private Cell initialState;
041    
042            public CellWorld(int numberOfRows, int numberOfColumns, double initialReward) {
043                    allCells = new ArrayList<Cell>();
044                    blockedCells = new ArrayList<Cell>();
045    
046                    terminalStates = new ArrayList<Cell>();
047    
048                    this.numberOfRows = numberOfRows;
049                    this.numberOfColumns = numberOfColumns;
050    
051                    for (int row = 1; row <= numberOfRows; row++) {
052                            for (int col = 1; col <= numberOfColumns; col++) {
053                                    allCells.add(new Cell(row, col, initialReward));
054                            }
055                    }
056    
057                    initialState = getCellAt(1, 4);
058            }
059    
060            public void markBlocked(int i, int j) {
061                    blockedCells.add(getCellAt(i, j));
062    
063            }
064    
065            private boolean isBlocked(int i, int j) {
066                    if ((i < 1) || (i > numberOfRows) || (j < 1) || (j > numberOfColumns)) {
067                            return true;
068                    }
069                    for (Cell c : blockedCells) {
070                            if ((c.getX() == i) && (c.getY() == j)) {
071                                    return true;
072                            }
073                    }
074                    return false;
075            }
076    
077            private Cell getCellAt(int i, int j) {
078                    for (Cell c : allCells) {
079                            if ((c.getX() == i) && (c.getY() == j)) {
080                                    return c;
081                            }
082                    }
083                    throw new RuntimeException("No Cell found at " + i + " , " + j);
084            }
085    
086            public CellWorldPosition moveProbabilisticallyFrom(int i, int j,
087                            String direction, Randomizer r) {
088                    Cell c = getCellAt(i, j);
089                    if (terminalStates.contains(c)) {
090                            return c.position();
091                    }
092                    return moveFrom(i, j, determineDirectionOfActualMovement(direction, r));
093    
094            }
095    
096            private CellWorldPosition moveFrom(int i, int j, String direction) {
097                    if (direction.equals(LEFT)) {
098                            return moveLeftFrom(i, j);
099                    }
100                    if (direction.equals(RIGHT)) {
101                            return moveRightFrom(i, j);
102                    }
103                    if (direction.equals(UP)) {
104                            return moveUpFrom(i, j);
105                    }
106                    if (direction.equals(DOWN)) {
107                            return moveDownFrom(i, j);
108                    }
109                    throw new RuntimeException("Unable to move " + direction + " from " + i
110                                    + " , " + j);
111            }
112    
113            private CellWorldPosition moveFrom(CellWorldPosition startingPosition,
114                            String direction) {
115                    return moveFrom(startingPosition.getX(), startingPosition.getY(),
116                                    direction);
117            }
118    
119            private String determineDirectionOfActualMovement(
120                            String commandedDirection, double prob) {
121                    if (prob < 0.8) {
122                            return commandedDirection;
123                    } else if ((prob > 0.8) && (prob < 0.9)) {
124                            if ((commandedDirection.equals(LEFT))
125                                            || (commandedDirection.equals(RIGHT))) {
126                                    return UP;
127                            }
128                            if ((commandedDirection.equals(UP))
129                                            || (commandedDirection.equals(DOWN))) {
130                                    return LEFT;
131                            }
132                    } else { // 0.9 < prob < 1.0
133                            if ((commandedDirection.equals(LEFT))
134                                            || (commandedDirection.equals(RIGHT))) {
135                                    return DOWN;
136                            }
137                            if ((commandedDirection.equals(UP))
138                                            || (commandedDirection.equals(DOWN))) {
139                                    return RIGHT;
140                            }
141                    }
142                    throw new RuntimeException(
143                                    "Unable to determine direction when command =  "
144                                                    + commandedDirection + " and probability = " + prob);
145    
146            }
147    
148            private String determineDirectionOfActualMovement(
149                            String commandedDirection, Randomizer r) {
150                    return determineDirectionOfActualMovement(commandedDirection, r
151                                    .nextDouble());
152    
153            }
154    
155            private CellWorldPosition moveLeftFrom(int i, int j) {
156                    if (isBlocked(i, j - 1)) {
157                            return new CellWorldPosition(i, j);
158                    }
159                    return new CellWorldPosition(i, j - 1);
160            }
161    
162            private CellWorldPosition moveRightFrom(int i, int j) {
163                    if (isBlocked(i, j + 1)) {
164                            return new CellWorldPosition(i, j);
165                    }
166                    return new CellWorldPosition(i, j + 1);
167            }
168    
169            private CellWorldPosition moveUpFrom(int i, int j) {
170                    if (isBlocked(i + 1, j)) {
171                            return new CellWorldPosition(i, j);
172                    }
173                    return new CellWorldPosition(i + 1, j);
174            }
175    
176            private CellWorldPosition moveDownFrom(int i, int j) {
177                    if (isBlocked(i - 1, j)) {
178                            return new CellWorldPosition(i, j);
179                    }
180                    return new CellWorldPosition(i - 1, j);
181            }
182    
183            public void setReward(int i, int j, double reward) {
184                    Cell c = getCellAt(i, j);
185                    c.setReward(reward);
186    
187            }
188    
189            public List<Cell> unblockedCells() {
190                    List<Cell> res = new ArrayList<Cell>();
191                    for (Cell c : allCells) {
192                            if (!(blockedCells.contains(c))) {
193                                    res.add(c);
194                            }
195                    }
196                    return res;
197            }
198    
199            public boolean isBlocked(Pair<Integer, Integer> p) {
200                    return isBlocked(p.getFirst(), p.getSecond());
201            }
202    
203            // what is the probability of starting from position p1 taking action a and
204            // reaaching position p2
205            // method is public ONLY for testing do not use in client code.
206            public double getTransitionProbability(CellWorldPosition startingPosition,
207                            String actionDesired, CellWorldPosition endingPosition) {
208    
209                    String firstRightAngledAction = determineDirectionOfActualMovement(
210                                    actionDesired, 0.85);
211                    String secondRightAngledAction = determineDirectionOfActualMovement(
212                                    actionDesired, 0.95);
213    
214                    Hashtable<String, CellWorldPosition> actionsToPositions = new Hashtable<String, CellWorldPosition>();
215                    actionsToPositions.put(actionDesired, moveFrom(startingPosition,
216                                    actionDesired));
217                    actionsToPositions.put(firstRightAngledAction, moveFrom(
218                                    startingPosition, firstRightAngledAction));
219                    actionsToPositions.put(secondRightAngledAction, moveFrom(
220                                    startingPosition, secondRightAngledAction));
221    
222                    Hashtable<CellWorldPosition, Double> positionsToProbability = new Hashtable<CellWorldPosition, Double>();
223                    for (CellWorldPosition p : actionsToPositions.values()) {
224                            positionsToProbability.put(p, 0.0);
225                    }
226    
227                    for (String action : actionsToPositions.keySet()) {
228                            CellWorldPosition position = actionsToPositions.get(action);
229                            double value = positionsToProbability.get(position);
230                            if (action.equals(actionDesired)) {
231                                    positionsToProbability.put(position, value + 0.8);
232                            } else { // right angled steps
233                                    positionsToProbability.put(position, value + 0.1);
234                            }
235    
236                    }
237    
238                    if (positionsToProbability.keySet().contains(endingPosition)) {
239                            return positionsToProbability.get(endingPosition);
240                    } else {
241                            return 0.0;
242                    }
243    
244            }
245    
246            public MDPTransitionModel<CellWorldPosition, String> getTransitionModel() {
247                    List<CellWorldPosition> terminalPositions = new ArrayList<CellWorldPosition>();
248                    for (Cell tc : terminalStates) {
249                            terminalPositions.add(tc.position());
250                    }
251                    MDPTransitionModel<CellWorldPosition, String> mtm = new MDPTransitionModel<CellWorldPosition, String>(
252                                    terminalPositions);
253    
254                    List<String> actions = Arrays.asList(new String[] { UP, DOWN, LEFT,
255                                    RIGHT });
256    
257                    for (CellWorldPosition startingPosition : getNonFinalStates()) {
258                            for (String actionDesired : actions) {
259                                    for (Cell target : unblockedCells()) { // too much work? should
260                                            // just cycle through
261                                            // neighbouring cells
262                                            // instead of all cells.
263                                            CellWorldPosition endingPosition = target.position();
264                                            double transitionProbability = getTransitionProbability(
265                                                            startingPosition, actionDesired, endingPosition);
266                                            if (!(transitionProbability == 0.0)) {
267    
268                                                    mtm.setTransitionProbability(startingPosition,
269                                                                    actionDesired, endingPosition,
270                                                                    transitionProbability);
271                                            }
272                                    }
273                            }
274                    }
275                    return mtm;
276            }
277    
278            public MDPRewardFunction<CellWorldPosition> getRewardFunction() {
279    
280                    MDPRewardFunction<CellWorldPosition> result = new MDPRewardFunction<CellWorldPosition>();
281                    for (Cell c : unblockedCells()) {
282                            CellWorldPosition pos = c.position();
283                            double reward = c.getReward();
284                            result.setReward(pos, reward);
285                    }
286    
287                    return result;
288            }
289    
290            public List<CellWorldPosition> unblockedPositions() {
291                    List<CellWorldPosition> result = new ArrayList<CellWorldPosition>();
292                    for (Cell c : unblockedCells()) {
293                            result.add(c.position());
294                    }
295                    return result;
296            }
297    
298            public MDP<CellWorldPosition, String> asMdp() {
299    
300                    return new MDP<CellWorldPosition, String>(this);
301            }
302    
303            public List<CellWorldPosition> getNonFinalStates() {
304                    List<CellWorldPosition> nonFinalPositions = unblockedPositions();
305                    nonFinalPositions.remove(getCellAt(2, 4).position());
306                    nonFinalPositions.remove(getCellAt(3, 4).position());
307                    return nonFinalPositions;
308            }
309    
310            public List<CellWorldPosition> getFinalStates() {
311                    List<CellWorldPosition> finalPositions = new ArrayList<CellWorldPosition>();
312                    finalPositions.add(getCellAt(2, 4).position());
313                    finalPositions.add(getCellAt(3, 4).position());
314                    return finalPositions;
315            }
316    
317            public void setTerminalState(int i, int j) {
318                    setTerminalState(new CellWorldPosition(i, j));
319    
320            }
321    
322            public void setTerminalState(CellWorldPosition position) {
323                    terminalStates.add(getCellAt(position.getX(), position.getY()));
324    
325            }
326    
327            public CellWorldPosition getInitialState() {
328                    return initialState.position();
329            }
330    
331            public MDPPerception<CellWorldPosition> execute(CellWorldPosition position,
332                            String action, Randomizer r) {
333                    CellWorldPosition pos = moveProbabilisticallyFrom(position.getX(),
334                                    position.getY(), action, r);
335                    double reward = getCellAt(pos.getX(), pos.getY()).getReward();
336                    return new MDPPerception<CellWorldPosition>(pos, reward);
337            }
338    
339            public List<String> getAllActions() {
340    
341                    return Arrays.asList(new String[] { LEFT, RIGHT, UP, DOWN });
342            }
343    
344    }