package aima.probability.decision.cellworld;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.List;

import aima.probability.Randomizer;
import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.probability.decision.MDPRewardFunction;
import aima.probability.decision.MDPSource;
import aima.probability.decision.MDPTransitionModel;
import aima.util.Pair;

/**
 * A rectangular grid of {@link Cell}s that can be exposed as an MDP whose
 * states are {@link CellWorldPosition}s and whose actions are the direction
 * strings {@link #LEFT}, {@link #RIGHT}, {@link #UP} and {@link #DOWN}.
 *
 * <p>
 * Coordinates are 1-based. Cells are created as
 * {@code new Cell(row, col, reward)} and looked up by comparing
 * {@code getX()} with the first coordinate and {@code getY()} with the
 * second. {@code UP}/{@code DOWN} change the first coordinate by +1/-1 and
 * {@code RIGHT}/{@code LEFT} change the second coordinate by +1/-1.
 *
 * <p>
 * Movement is stochastic: the commanded direction is followed with
 * probability 0.8 and each of the two perpendicular directions with
 * probability 0.1. A move into a blocked cell or off the grid leaves the
 * position unchanged.
 *
 * @author Ravi Mohan
 *
 */

public class CellWorld implements MDPSource<CellWorldPosition, String> {
    public static final String LEFT = "left";

    public static final String RIGHT = "right";

    public static final String UP = "up";

    public static final String DOWN = "down";

    public static final String NO_OP = "no_op";

    // Package-private on purpose: visibility kept as in the original so
    // same-package code that reads these lists keeps compiling.
    List<Cell> blockedCells, allCells;

    private int numberOfRows;

    private int numberOfColumns;

    private List<Cell> terminalStates;

    private Cell initialState;

    /**
     * Creates a grid in which every cell carries the same reward.
     *
     * @param numberOfRows
     *            number of rows (first coordinate runs 1..numberOfRows)
     * @param numberOfColumns
     *            number of columns (second coordinate runs
     *            1..numberOfColumns)
     * @param initialReward
     *            reward assigned to every cell
     * @throws RuntimeException
     *             if the grid has no cell (1, 4), because the initial state
     *             is fixed to that cell
     */
    public CellWorld(int numberOfRows, int numberOfColumns, double initialReward) {
        allCells = new ArrayList<Cell>();
        blockedCells = new ArrayList<Cell>();
        terminalStates = new ArrayList<Cell>();

        this.numberOfRows = numberOfRows;
        this.numberOfColumns = numberOfColumns;

        for (int row = 1; row <= numberOfRows; row++) {
            for (int col = 1; col <= numberOfColumns; col++) {
                allCells.add(new Cell(row, col, initialReward));
            }
        }

        // NOTE(review): the initial state is hard-coded to (1, 4); grids
        // with fewer than four columns cannot be constructed.
        initialState = getCellAt(1, 4);
    }

    /**
     * Marks the cell at (i, j) as blocked.
     *
     * @throws RuntimeException
     *             if there is no cell at (i, j)
     */
    public void markBlocked(int i, int j) {
        blockedCells.add(getCellAt(i, j));
    }

    /**
     * Returns true when (i, j) lies outside the grid or has been marked
     * blocked.
     */
    private boolean isBlocked(int i, int j) {
        if ((i < 1) || (i > numberOfRows) || (j < 1) || (j > numberOfColumns)) {
            return true;
        }
        for (Cell c : blockedCells) {
            if ((c.getX() == i) && (c.getY() == j)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Finds the cell whose coordinates are (i, j).
     *
     * @throws RuntimeException
     *             if no such cell exists
     */
    private Cell getCellAt(int i, int j) {
        for (Cell c : allCells) {
            if ((c.getX() == i) && (c.getY() == j)) {
                return c;
            }
        }
        throw new RuntimeException("No Cell found at " + i + " , " + j);
    }

    /**
     * Performs one stochastic move from (i, j).
     *
     * @param i
     *            first coordinate of the starting cell
     * @param j
     *            second coordinate of the starting cell
     * @param direction
     *            commanded direction
     * @param r
     *            source of the random draw that selects the actual direction
     * @return the resulting position; the starting cell's own position when
     *         that cell is a registered terminal state
     * @throws RuntimeException
     *             if there is no cell at (i, j) or the direction is not one
     *             of the four movement directions
     */
    public CellWorldPosition moveProbabilisticallyFrom(int i, int j,
            String direction, Randomizer r) {
        Cell c = getCellAt(i, j);
        if (terminalStates.contains(c)) {
            return c.position();
        }
        return moveFrom(i, j, determineDirectionOfActualMovement(direction, r));
    }

    /**
     * Deterministically moves one step from (i, j) in the given direction.
     *
     * @throws RuntimeException
     *             if the direction is not LEFT, RIGHT, UP or DOWN
     */
    private CellWorldPosition moveFrom(int i, int j, String direction) {
        if (direction.equals(LEFT)) {
            return moveLeftFrom(i, j);
        }
        if (direction.equals(RIGHT)) {
            return moveRightFrom(i, j);
        }
        if (direction.equals(UP)) {
            return moveUpFrom(i, j);
        }
        if (direction.equals(DOWN)) {
            return moveDownFrom(i, j);
        }
        throw new RuntimeException("Unable to move " + direction + " from " + i
                + " , " + j);
    }

    /** Position-based overload of {@link #moveFrom(int, int, String)}. */
    private CellWorldPosition moveFrom(CellWorldPosition startingPosition,
            String direction) {
        return moveFrom(startingPosition.getX(), startingPosition.getY(),
                direction);
    }

    /**
     * Maps a draw to the direction actually taken.
     *
     * <ul>
     * <li>prob &lt; 0.8: the commanded direction;</li>
     * <li>0.8 &lt;= prob &lt; 0.9: UP for LEFT/RIGHT, LEFT for UP/DOWN;</li>
     * <li>otherwise: DOWN for LEFT/RIGHT, RIGHT for UP/DOWN.</li>
     * </ul>
     *
     * @throws RuntimeException
     *             if prob is not below 0.8 and the commanded direction is not
     *             one of the four movement directions
     */
    private String determineDirectionOfActualMovement(
            String commandedDirection, double prob) {
        if (prob < 0.8) {
            return commandedDirection;
        }

        boolean horizontal = commandedDirection.equals(LEFT)
                || commandedDirection.equals(RIGHT);
        boolean vertical = commandedDirection.equals(UP)
                || commandedDirection.equals(DOWN);

        // The lower bound is inclusive: previously a draw of exactly 0.8
        // fell through to the last band instead of this one.
        if (prob < 0.9) {
            if (horizontal) {
                return UP;
            }
            if (vertical) {
                return LEFT;
            }
        } else {
            if (horizontal) {
                return DOWN;
            }
            if (vertical) {
                return RIGHT;
            }
        }
        throw new RuntimeException(
                "Unable to determine direction when command = "
                        + commandedDirection + " and probability = " + prob);
    }

    /** Draws from the randomizer and maps the draw to an actual direction. */
    private String determineDirectionOfActualMovement(
            String commandedDirection, Randomizer r) {
        return determineDirectionOfActualMovement(commandedDirection,
                r.nextDouble());
    }

    /** Moves to (i, j - 1) unless that cell is blocked or off the grid. */
    private CellWorldPosition moveLeftFrom(int i, int j) {
        if (isBlocked(i, j - 1)) {
            return new CellWorldPosition(i, j);
        }
        return new CellWorldPosition(i, j - 1);
    }

    /** Moves to (i, j + 1) unless that cell is blocked or off the grid. */
    private CellWorldPosition moveRightFrom(int i, int j) {
        if (isBlocked(i, j + 1)) {
            return new CellWorldPosition(i, j);
        }
        return new CellWorldPosition(i, j + 1);
    }

    /** Moves to (i + 1, j) unless that cell is blocked or off the grid. */
    private CellWorldPosition moveUpFrom(int i, int j) {
        if (isBlocked(i + 1, j)) {
            return new CellWorldPosition(i, j);
        }
        return new CellWorldPosition(i + 1, j);
    }

    /** Moves to (i - 1, j) unless that cell is blocked or off the grid. */
    private CellWorldPosition moveDownFrom(int i, int j) {
        if (isBlocked(i - 1, j)) {
            return new CellWorldPosition(i, j);
        }
        return new CellWorldPosition(i - 1, j);
    }

    /**
     * Sets the reward of the cell at (i, j).
     *
     * @throws RuntimeException
     *             if there is no cell at (i, j)
     */
    public void setReward(int i, int j, double reward) {
        getCellAt(i, j).setReward(reward);
    }

    /**
     * Returns a new list with every cell that has not been marked blocked,
     * in creation order.
     */
    public List<Cell> unblockedCells() {
        List<Cell> res = new ArrayList<Cell>();
        for (Cell c : allCells) {
            if (!blockedCells.contains(c)) {
                res.add(c);
            }
        }
        return res;
    }

    /**
     * Returns true when the (first, second) coordinates in the pair are
     * blocked or lie outside the grid.
     */
    public boolean isBlocked(Pair<Integer, Integer> p) {
        return isBlocked(p.getFirst(), p.getSecond());
    }

    /**
     * Returns the probability of reaching {@code endingPosition} when
     * starting from {@code startingPosition} and commanding
     * {@code actionDesired}. The commanded move contributes 0.8 and each
     * perpendicular move contributes 0.1; contributions of moves that end in
     * the same position are summed. This method does not consult the
     * registered terminal states.
     *
     * <p>
     * Public ONLY for testing; do not use in client code.
     *
     * @throws RuntimeException
     *             if {@code actionDesired} is not one of the four movement
     *             directions
     */
    public double getTransitionProbability(CellWorldPosition startingPosition,
            String actionDesired, CellWorldPosition endingPosition) {

        // 0.85 and 0.95 are representative draws from the two
        // perpendicular-move bands.
        String firstRightAngledAction = determineDirectionOfActualMovement(
                actionDesired, 0.85);
        String secondRightAngledAction = determineDirectionOfActualMovement(
                actionDesired, 0.95);

        Hashtable<String, CellWorldPosition> actionsToPositions = new Hashtable<String, CellWorldPosition>();
        actionsToPositions.put(actionDesired, moveFrom(startingPosition,
                actionDesired));
        actionsToPositions.put(firstRightAngledAction, moveFrom(
                startingPosition, firstRightAngledAction));
        actionsToPositions.put(secondRightAngledAction, moveFrom(
                startingPosition, secondRightAngledAction));

        Hashtable<CellWorldPosition, Double> positionsToProbability = new Hashtable<CellWorldPosition, Double>();
        for (CellWorldPosition p : actionsToPositions.values()) {
            positionsToProbability.put(p, 0.0);
        }

        // Accumulate per resulting position: several moves may end in the
        // same place (e.g. when bumping into walls).
        for (String action : actionsToPositions.keySet()) {
            CellWorldPosition position = actionsToPositions.get(action);
            double value = positionsToProbability.get(position);
            if (action.equals(actionDesired)) {
                positionsToProbability.put(position, value + 0.8);
            } else { // right angled steps
                positionsToProbability.put(position, value + 0.1);
            }
        }

        if (positionsToProbability.containsKey(endingPosition)) {
            return positionsToProbability.get(endingPosition);
        }
        return 0.0;
    }

    /**
     * Builds the transition model: for every non-final state and every
     * action, records each non-zero probability of ending in an unblocked
     * position. The model is constructed with the positions of the
     * registered terminal states.
     */
    public MDPTransitionModel<CellWorldPosition, String> getTransitionModel() {
        MDPTransitionModel<CellWorldPosition, String> mtm = new MDPTransitionModel<CellWorldPosition, String>(
                positionsOf(terminalStates));

        List<String> actions = Arrays.asList(UP, DOWN, LEFT, RIGHT);

        // Loop-invariant, so computed once instead of per (state, action).
        List<CellWorldPosition> candidateEndings = unblockedPositions();

        for (CellWorldPosition startingPosition : getNonFinalStates()) {
            for (String actionDesired : actions) {
                // too much work? should just cycle through neighbouring
                // cells instead of all cells.
                for (CellWorldPosition endingPosition : candidateEndings) {
                    double transitionProbability = getTransitionProbability(
                            startingPosition, actionDesired, endingPosition);
                    if (transitionProbability != 0.0) {
                        mtm.setTransitionProbability(startingPosition,
                                actionDesired, endingPosition,
                                transitionProbability);
                    }
                }
            }
        }
        return mtm;
    }

    /**
     * Builds a reward function holding the current reward of every unblocked
     * cell, keyed by the cell's position.
     */
    public MDPRewardFunction<CellWorldPosition> getRewardFunction() {
        MDPRewardFunction<CellWorldPosition> result = new MDPRewardFunction<CellWorldPosition>();
        for (Cell c : unblockedCells()) {
            result.setReward(c.position(), c.getReward());
        }
        return result;
    }

    /** Returns a new list with the positions of all unblocked cells. */
    public List<CellWorldPosition> unblockedPositions() {
        return positionsOf(unblockedCells());
    }

    /** Wraps this world in an {@link MDP}. */
    public MDP<CellWorldPosition, String> asMdp() {
        return new MDP<CellWorldPosition, String>(this);
    }

    /**
     * Returns the unblocked positions that are not final states (see
     * {@link #getFinalStates()}).
     */
    public List<CellWorldPosition> getNonFinalStates() {
        List<CellWorldPosition> nonFinalPositions = unblockedPositions();
        for (CellWorldPosition finalPosition : getFinalStates()) {
            nonFinalPositions.remove(finalPosition);
        }
        return nonFinalPositions;
    }

    /**
     * Returns the positions of the registered terminal states, in
     * registration order.
     *
     * <p>
     * When no terminal state has been registered, the positions (2, 4) and
     * (3, 4) are returned, which is what this method always returned before
     * it started honouring {@link #setTerminalState(int, int)}.
     *
     * @throws RuntimeException
     *             if no terminal state is registered and the grid lacks cell
     *             (2, 4) or (3, 4)
     */
    public List<CellWorldPosition> getFinalStates() {
        if (!terminalStates.isEmpty()) {
            return positionsOf(terminalStates);
        }
        // Legacy default kept for callers that never register terminals.
        List<CellWorldPosition> finalPositions = new ArrayList<CellWorldPosition>();
        finalPositions.add(getCellAt(2, 4).position());
        finalPositions.add(getCellAt(3, 4).position());
        return finalPositions;
    }

    /**
     * Registers the cell at (i, j) as a terminal state.
     *
     * @throws RuntimeException
     *             if there is no cell at (i, j)
     */
    public void setTerminalState(int i, int j) {
        setTerminalState(new CellWorldPosition(i, j));
    }

    /**
     * Registers the cell at the given position as a terminal state.
     * Registering the same cell again has no effect.
     *
     * @throws RuntimeException
     *             if there is no cell at that position
     */
    public void setTerminalState(CellWorldPosition position) {
        Cell cell = getCellAt(position.getX(), position.getY());
        if (!terminalStates.contains(cell)) {
            terminalStates.add(cell);
        }
    }

    /** Returns the position of the initial state, the cell at (1, 4). */
    public CellWorldPosition getInitialState() {
        return initialState.position();
    }

    /**
     * Executes an action stochastically from the given position and reports
     * the resulting position together with the reward of the cell reached.
     *
     * @throws RuntimeException
     *             if the position is not on the grid or the action is not
     *             one of the four movement directions
     */
    public MDPPerception<CellWorldPosition> execute(CellWorldPosition position,
            String action, Randomizer r) {
        CellWorldPosition pos = moveProbabilisticallyFrom(position.getX(),
                position.getY(), action, r);
        double reward = getCellAt(pos.getX(), pos.getY()).getReward();
        return new MDPPerception<CellWorldPosition>(pos, reward);
    }

    /** Returns the four movement actions: LEFT, RIGHT, UP, DOWN. */
    public List<String> getAllActions() {
        return Arrays.asList(LEFT, RIGHT, UP, DOWN);
    }

    /** Collects the positions of the given cells, preserving their order. */
    private List<CellWorldPosition> positionsOf(List<Cell> cells) {
        List<CellWorldPosition> positions = new ArrayList<CellWorldPosition>();
        for (Cell c : cells) {
            positions.add(c.position());
        }
        return positions;
    }

}