001    /*
002     * Created on Jul 31, 2005
003     *
004     */
005    package aima.learning.learners;
006    
007    import java.util.Hashtable;
008    import java.util.List;
009    
010    import aima.learning.framework.DataSet;
011    import aima.learning.framework.Example;
012    import aima.learning.framework.Learner;
013    import aima.util.Table;
014    import aima.util.Util;
015    
016    /**
017     * @author Ravi Mohan
018     * 
019     */
020    public class AdaBoostLearner implements Learner {
021    
022            private List<Learner> learners;
023    
024            private DataSet dataSet;
025    
026            private double[] exampleWeights;
027    
028            private Hashtable<Learner, Double> learnerWeights;
029    
030            public AdaBoostLearner(List<Learner> learners, DataSet ds) {
031                    this.learners = learners;
032                    this.dataSet = ds;
033    
034                    initializeExampleWeights(ds.examples.size());
035                    initializeHypothesisWeights(learners.size());
036            }
037    
038            public void train(DataSet ds) {
039                    initializeExampleWeights(ds.examples.size());
040    
041                    for (Learner learner : learners) {
042                            learner.train(ds);
043    
044                            double error = calculateError(ds, learner);
045                            if (error < 0.0001) {
046                                    break;
047                            }
048    
049                            adjustExampleWeights(ds, learner, error);
050    
051                            double newHypothesisWeight = learnerWeights.get(learner)
052                                            * Math.log((1.0 - error) / error);
053                            learnerWeights.put(learner, newHypothesisWeight);
054                    }
055    
056            }
057    
058            public String predict(Example e) {
059                    return weightedMajority(e);
060            }
061    
062            private String weightedMajority(Example e) {
063                    List<String> targetValues = dataSet.getPossibleAttributeValues(dataSet
064                                    .getTargetAttributeName());
065    
066                    Table<String, Learner, Double> table = createTargetValueLearnerTable(
067                                    targetValues, e);
068                    return getTargetValueWithTheMaximumVotes(targetValues, table);
069            }
070    
071            private Table<String, Learner, Double> createTargetValueLearnerTable(
072                            List<String> targetValues, Example e) {
073                    // create a table with target-attribute values as rows and learners as
074                    // columns and cells containing the weighted votes of each Learner for a
075                    // target value
076                    // Learner1 Learner2 Laerner3 .......
077                    // Yes 0.83 0.5 0
078                    // No 0 0 0.6
079    
080                    Table<String, Learner, Double> table = new Table<String, Learner, Double>(
081                                    targetValues, learners);
082                    // initialize table
083                    for (Learner l : learners) {
084                            for (String s : targetValues) {
085                                    table.set(s, l, 0.0);
086                            }
087                    }
088                    for (Learner learner : learners) {
089                            String predictedValue = learner.predict(e);
090                            for (String v : targetValues) {
091                                    if (predictedValue.equals(v)) {
092                                            table.set(v, learner, table.get(v, learner)
093                                                            + learnerWeights.get(learner) * 1);
094                                    }
095                            }
096                    }
097                    return table;
098            }
099    
100            private String getTargetValueWithTheMaximumVotes(List<String> targetValues,
101                            Table<String, Learner, Double> table) {
102                    String targetValueWithMaxScore = targetValues.get(0);
103                    double score = scoreOfValue(targetValueWithMaxScore, table, learners);
104                    for (String value : targetValues) {
105                            double scoreOfValue = scoreOfValue(value, table, learners);
106                            if (scoreOfValue > score) {
107                                    targetValueWithMaxScore = value;
108                                    score = scoreOfValue;
109                            }
110                    }
111                    return targetValueWithMaxScore;
112            }
113    
114            public int[] test(DataSet ds) {
115                    int[] results = new int[] { 0, 0 };
116    
117                    for (Example e : ds.examples) {
118                            if (e.targetValue().equals(predict(e))) {
119                                    results[0] = results[0] + 1;
120                            } else {
121                                    results[1] = results[1] + 1;
122                            }
123                    }
124                    return results;
125            }
126    
127            private void initializeExampleWeights(int size) {
128                    if (size == 0) {
129                            throw new RuntimeException(
130                                            "cannot initialize Ensemble learning with Empty Dataset");
131                    }
132                    double value = 1.0 / (1.0 * size);
133                    exampleWeights = new double[size];
134                    for (int i = 0; i < size; i++) {
135                            exampleWeights[i] = value;
136                    }
137    
138            }
139    
140            private void initializeHypothesisWeights(int size) {
141                    if (size == 0) {
142                            throw new RuntimeException(
143                                            "cannot initialize Ensemble learning with Zero Learners");
144                    }
145    
146                    learnerWeights = new Hashtable<Learner, Double>();
147                    for (Learner le : learners) {
148                            learnerWeights.put(le, 1.0);
149                    }
150            }
151    
152            private double calculateError(DataSet ds, Learner l) {
153                    double error = 0.0;
154                    for (int i = 0; i < ds.examples.size(); i++) {
155                            Example e = ds.getExample(i);
156                            if (!(l.predict(e).equals(e.targetValue()))) {
157                                    error = error + exampleWeights[i];
158                            }
159                    }
160                    return error;
161            }
162    
163            private void adjustExampleWeights(DataSet ds, Learner l, double error) {
164                    double epsilon = error / (1.0 - error);
165                    for (int j = 0; j < ds.examples.size(); j++) {
166                            Example e = ds.getExample(j);
167                            if ((l.predict(e).equals(e.targetValue()))) {
168                                    exampleWeights[j] = exampleWeights[j] * epsilon;
169                            }
170                    }
171                    exampleWeights = Util.normalize(exampleWeights);
172    
173            }
174    
175            private double scoreOfValue(String targetValue,
176                            Table<String, Learner, Double> table, List<Learner> learners) {
177                    double score = 0.0;
178                    for (Learner l : learners) {
179                            score += table.get(targetValue, l);
180                    }
181                    return score;
182            }
183    
184    }