/*
 * Created on Jul 31, 2005
 *
 */
package aima.learning.learners;

import java.util.Hashtable;
import java.util.List;

import aima.learning.framework.DataSet;
import aima.learning.framework.Example;
import aima.learning.framework.Learner;
import aima.util.Table;
import aima.util.Util;

/**
 * Ensemble learner that combines a fixed list of {@link Learner}s by a
 * weighted majority vote, with the vote weights derived AdaBoost-style from
 * each learner's weighted training error.
 *
 * <p>
 * Training walks the learners in list order. Each learner is trained, its
 * weighted error on the training set is measured, the weights of the examples
 * it classified correctly are scaled down, and the learner receives the vote
 * weight <code>log((1 - error) / error)</code>. The error is clamped to
 * <code>[MIN_ERROR, 1 - MIN_ERROR]</code> so that vote weights are always
 * finite. When a learner's error drops below <code>MIN_ERROR</code> training
 * stops: that learner gets the largest possible vote weight and every learner
 * after it gets weight zero, which removes it from the vote.
 *
 * @author Ravi Mohan
 *
 */
public class AdaBoostLearner implements Learner {

    /**
     * Errors below this threshold are treated as "perfect" and end training;
     * it is also the clamp that keeps log((1 - error) / error) finite.
     */
    private static final double MIN_ERROR = 0.0001;

    private final List<Learner> learners;

    /** Data set supplied at construction; used to look up the target values. */
    private final DataSet dataSet;

    /** One weight per training example, indexed like the data set. */
    private double[] exampleWeights;

    /** Vote weight of each learner in the weighted majority. */
    private Hashtable<Learner, Double> learnerWeights;

    /**
     * Creates an ensemble over the given learners. Until {@link #train} is
     * called every learner votes with weight 1.0.
     *
     * @param learners
     *            the learners to combine; must not be empty
     * @param ds
     *            the data set whose target attribute defines the values that
     *            can be predicted; must contain at least one example
     * @throws IllegalArgumentException
     *             if the data set has no examples or the learner list is
     *             empty
     */
    public AdaBoostLearner(List<Learner> learners, DataSet ds) {
        this.learners = learners;
        this.dataSet = ds;

        initializeExampleWeights(ds.examples.size());
        initializeHypothesisWeights(learners.size());
    }

    /**
     * Trains the learners in list order and assigns each one its vote weight.
     * Every call starts from uniform example weights and fresh learner
     * weights, so training twice does not compound the weights of an earlier
     * run.
     *
     * @param ds
     *            the training data; must contain at least one example
     * @throws IllegalArgumentException
     *             if the data set has no examples
     */
    public void train(DataSet ds) {
        initializeExampleWeights(ds.examples.size());
        initializeHypothesisWeights(learners.size());

        boolean ensembleComplete = false;
        for (Learner learner : learners) {
            if (ensembleComplete) {
                // Never trained in this run: take it out of the vote.
                learnerWeights.put(learner, 0.0);
                continue;
            }

            learner.train(ds);

            double rawError = calculateError(ds, learner);
            // Clamp so that neither the logarithm below nor the ratio used in
            // adjustExampleWeights can become infinite or NaN (an error of
            // exactly 0 or 1, or marginally above 1 through rounding).
            double error = Math.max(MIN_ERROR, Math.min(rawError,
                    1.0 - MIN_ERROR));
            learnerWeights.put(learner, Math.log((1.0 - error) / error));

            if (rawError < MIN_ERROR) {
                // This hypothesis already fits the weighted data; it keeps
                // the maximum weight and the remaining learners are dropped.
                ensembleComplete = true;
            } else {
                adjustExampleWeights(ds, learner, error);
            }
        }
    }

    /**
     * Predicts the target value of an example by weighted majority vote.
     *
     * @param e
     *            the example to classify
     * @return the target value with the highest summed vote weight
     */
    public String predict(Example e) {
        return weightedMajority(e);
    }

    /**
     * Collects the weighted votes of all learners for the example and returns
     * the winning target value. The candidate values come from the data set
     * given to the constructor.
     */
    private String weightedMajority(Example e) {
        List<String> targetValues = dataSet.getPossibleAttributeValues(dataSet
                .getTargetAttributeName());

        Table<String, Learner, Double> table = createTargetValueLearnerTable(
                targetValues, e);
        return getTargetValueWithTheMaximumVotes(targetValues, table);
    }

    /**
     * Builds the vote table for one example.
     *
     * @return a table with target values as rows and learners as columns;
     *         each cell holds the weight the learner contributed to that
     *         value, or 0.0 if it did not vote for it
     */
    private Table<String, Learner, Double> createTargetValueLearnerTable(
            List<String> targetValues, Example e) {
        // create a table with target-attribute values as rows and learners as
        // columns and cells containing the weighted votes of each Learner for
        // a target value
        //        Learner1  Learner2  Learner3 .......
        // Yes    0.83      0.5       0
        // No     0         0         0.6

        Table<String, Learner, Double> table = new Table<String, Learner, Double>(
                targetValues, learners);
        // initialize table
        for (Learner l : learners) {
            for (String s : targetValues) {
                table.set(s, l, 0.0);
            }
        }
        for (Learner learner : learners) {
            double weight = learnerWeights.get(learner);
            if (weight == 0.0) {
                // A zero-weight learner cannot change any score, and it may
                // be one that training never reached, so do not query it.
                continue;
            }
            String predictedValue = learner.predict(e);
            for (String v : targetValues) {
                // Compare from the target value so that a null prediction is
                // simply counted as "no vote".
                if (v.equals(predictedValue)) {
                    table.set(v, learner, table.get(v, learner) + weight);
                }
            }
        }
        return table;
    }

    /**
     * Returns the target value whose row in the vote table has the highest
     * total; on a tie the value that comes first in the list wins.
     */
    private String getTargetValueWithTheMaximumVotes(List<String> targetValues,
            Table<String, Learner, Double> table) {
        String targetValueWithMaxScore = targetValues.get(0);
        double score = scoreOfValue(targetValueWithMaxScore, table, learners);
        for (String value : targetValues) {
            double scoreOfValue = scoreOfValue(value, table, learners);
            if (scoreOfValue > score) {
                targetValueWithMaxScore = value;
                score = scoreOfValue;
            }
        }
        return targetValueWithMaxScore;
    }

    /**
     * Classifies every example of the data set and counts the outcomes.
     *
     * @param ds
     *            the examples to classify
     * @return a two-element array: index 0 is the number of examples whose
     *         prediction equals their target value, index 1 the number of
     *         examples where it does not
     */
    public int[] test(DataSet ds) {
        int[] results = new int[] { 0, 0 };

        for (Example e : ds.examples) {
            if (e.targetValue().equals(predict(e))) {
                results[0] = results[0] + 1;
            } else {
                results[1] = results[1] + 1;
            }
        }
        return results;
    }

    /**
     * Gives each of the <code>size</code> examples the same weight 1/size.
     *
     * @throws IllegalArgumentException
     *             if <code>size</code> is zero
     */
    private void initializeExampleWeights(int size) {
        if (size == 0) {
            throw new IllegalArgumentException(
                    "cannot initialize Ensemble learning with Empty Dataset");
        }
        double value = 1.0 / (1.0 * size);
        exampleWeights = new double[size];
        for (int i = 0; i < size; i++) {
            exampleWeights[i] = value;
        }

    }

    /**
     * Gives every learner the neutral vote weight 1.0.
     *
     * @throws IllegalArgumentException
     *             if <code>size</code> is zero
     */
    private void initializeHypothesisWeights(int size) {
        if (size == 0) {
            throw new IllegalArgumentException(
                    "cannot initialize Ensemble learning with Zero Learners");
        }

        learnerWeights = new Hashtable<Learner, Double>();
        for (Learner le : learners) {
            learnerWeights.put(le, 1.0);
        }
    }

    /**
     * Sums the current weights of the examples the learner misclassifies.
     */
    private double calculateError(DataSet ds, Learner l) {
        double error = 0.0;
        for (int i = 0; i < ds.examples.size(); i++) {
            Example e = ds.getExample(i);
            if (!(l.predict(e).equals(e.targetValue()))) {
                error = error + exampleWeights[i];
            }
        }
        return error;
    }

    /**
     * Multiplies the weight of every correctly classified example by
     * error / (1 - error) and then normalizes the weights, which shifts
     * weight towards the examples the learner got wrong whenever the error
     * is below one half.
     */
    private void adjustExampleWeights(DataSet ds, Learner l, double error) {
        double epsilon = error / (1.0 - error);
        for (int j = 0; j < ds.examples.size(); j++) {
            Example e = ds.getExample(j);
            if ((l.predict(e).equals(e.targetValue()))) {
                exampleWeights[j] = exampleWeights[j] * epsilon;
            }
        }
        exampleWeights = Util.normalize(exampleWeights);

    }

    /**
     * Adds up the votes of all learners for one target value.
     */
    private double scoreOfValue(String targetValue,
            Table<String, Learner, Double> table, List<Learner> learners) {
        double score = 0.0;
        for (Learner l : learners) {
            score += table.get(targetValue, l);
        }
        return score;
    }

}