001    /*
002     * Created on Apr 9, 2005
003     *
004     */
005    package aima.learning.framework;
006    
007    /**
008     * @author Ravi Mohan
009     * 
010     */
011    import java.util.Hashtable;
012    import java.util.Iterator;
013    import java.util.LinkedList;
014    import java.util.List;
015    
016    import aima.util.Util;
017    
018    public class DataSet {
019            protected DataSet() {
020    
021            }
022    
023            public List<Example> examples;
024    
025            public DataSetSpecification specification;
026    
027            public DataSet(DataSetSpecification spec) {
028                    examples = new LinkedList<Example>();
029                    this.specification = spec;
030            }
031    
032            public void add(Example e) {
033                    examples.add(e);
034            }
035    
036            public int size() {
037                    return examples.size();
038            }
039    
040            public Example getExample(int number) {
041                    return examples.get(number);
042            }
043    
044            public DataSet removeExample(Example e) {
045                    DataSet ds = new DataSet(specification);
046                    for (Example eg : examples) {
047                            if (!(e.equals(eg))) {
048                                    ds.add(eg);
049                            }
050                    }
051                    return ds;
052            }
053    
054            public double getInformationFor() {
055                    String attributeName = specification.getTarget();
056                    Hashtable<String, Integer> counts = new Hashtable<String, Integer>();
057                    for (Example e : examples) {
058    
059                            String val = e.getAttributeValueAsString(attributeName);
060                            if (counts.containsKey(val)) {
061                                    counts.put(val, counts.get(val) + 1);
062                            } else {
063                                    counts.put(val, 1);
064                            }
065                    }
066    
067                    double totalTargetAttributeCount = 0;
068    
069                    double[] data = new double[counts.keySet().size()];
070                    Iterator<Integer> iter = counts.values().iterator();
071                    for (int i = 0; i < data.length; i++) {
072                            data[i] = iter.next();
073                    }
074                    data = Util.normalize(data);
075    
076                    return Util.information(data);
077            }
078    
079            public Hashtable<String, DataSet> splitByAttribute(String attributeName) {
080                    Hashtable<String, DataSet> results = new Hashtable<String, DataSet>();
081                    for (Example e : examples) {
082                            String val = e.getAttributeValueAsString(attributeName);
083                            if (results.containsKey(val)) {
084                                    results.get(val).add(e);
085                            } else {
086                                    DataSet ds = new DataSet(specification);
087                                    ds.add(e);
088                                    results.put(val, ds);
089                            }
090                    }
091                    return results;
092            }
093    
094            public double calculateGainFor(String parameterName) {
095                    Hashtable<String, DataSet> hash = splitByAttribute(parameterName);
096                    double totalSize = examples.size();
097                    double remainder = 0.0;
098                    for (String parameterValue : hash.keySet()) {
099                            double reducedDataSetSize = hash.get(parameterValue).examples
100                                            .size();
101                            remainder += (reducedDataSetSize / totalSize)
102                                            * hash.get(parameterValue).getInformationFor();
103                    }
104                    return getInformationFor() - remainder;
105            }
106    
107            @Override
108            public boolean equals(Object o) {
109                    if (this == o) {
110                            return true;
111                    }
112                    if ((o == null) || (this.getClass() != o.getClass())) {
113                            return false;
114                    }
115                    DataSet other = (DataSet) o;
116                    return examples.equals(other.examples);
117            }
118    
119            @Override
120            public int hashCode() {
121                    return 0;
122            }
123    
124            public Iterator<Example> iterator() {
125                    return examples.iterator();
126            }
127    
128            public DataSet copy() {
129                    DataSet ds = new DataSet(specification);
130                    for (Example e : examples) {
131                            ds.add(e);
132                    }
133                    return ds;
134            }
135    
136            public List<String> getAttributeNames() {
137                    return specification.getAttributeNames();
138            }
139    
140            public String getTargetAttributeName() {
141                    return specification.getTarget();
142            }
143    
144            public DataSet emptyDataSet() {
145                    return new DataSet(specification);
146            }
147    
148            /**
149             * @param specification
150             *            The specification to set. USE SPARINGLY for testing etc ..
151             *            makes no semantic sense
152             */
153            public void setSpecification(DataSetSpecification specification) {
154                    this.specification = specification;
155            }
156    
157            public List<String> getPossibleAttributeValues(String attributeName) {
158                    return specification.getPossibleAttributeValues(attributeName);
159            }
160    
161            public DataSet matchingDataSet(String attributeName, String attributeValue) {
162                    DataSet ds = new DataSet(specification);
163                    for (Example e : examples) {
164                            if (e.getAttributeValueAsString(attributeName).equals(
165                                            attributeValue)) {
166                                    ds.add(e);
167                            }
168                    }
169                    return ds;
170            }
171    
172            public List<String> getNonTargetAttributes() {
173                    return Util.removeFrom(getAttributeNames(), getTargetAttributeName());
174            }
175    
176    }