001    /*
002     * Created on Jul 25, 2005
003     *
004     */
005    package aima.test.learningtest;
006    
007    import java.util.Hashtable;
008    
009    import junit.framework.TestCase;
010    import aima.learning.framework.DataSet;
011    import aima.learning.framework.DataSetFactory;
012    import aima.util.Util;
013    
014    /**
015     * @author Ravi Mohan
016     * 
017     */
018    
019    public class InformationAndGainTest extends TestCase {
020            public void testInformationCalculation() {
021                    double[] fairCoinProbabilities = new double[] { 0.5, 0.5 };
022                    double[] loadedCoinProbabilities = new double[] { 0.01, 0.99 };
023    
024                    assertEquals(1.0, Util.information(fairCoinProbabilities));
025                    assertEquals(0.08079313589591118, Util
026                                    .information(loadedCoinProbabilities));
027            }
028    
029            public void testBasicDataSetInformationCalculation() throws Exception {
030                    DataSet ds = DataSetFactory.getRestaurantDataSet();
031                    double infoForTargetAttribute = ds.getInformationFor();// this should
032                    // be the
033                    // generic
034                    // distribution
035                    assertEquals(1.0, infoForTargetAttribute);
036            }
037    
038            public void testDataSetSplit() throws Exception {
039                    DataSet ds = DataSetFactory.getRestaurantDataSet();
040                    Hashtable<String, DataSet> hash = ds.splitByAttribute("patrons");// this
041                    // should
042                    // be
043                    // the
044                    // generic
045                    // distribution
046                    assertEquals(3, hash.keySet().size());
047                    assertEquals(6, hash.get("Full").size());
048                    assertEquals(2, hash.get("None").size());
049                    assertEquals(4, hash.get("Some").size());
050    
051            }
052    
053            public void testGainCalculation() throws Exception {
054                    DataSet ds = DataSetFactory.getRestaurantDataSet();
055                    Hashtable<String, DataSet> hash = ds.splitByAttribute("patrons");
056                    double gain = ds.calculateGainFor("patrons");
057                    assertEquals(0.541, gain, 0.001);
058                    gain = ds.calculateGainFor("type");
059                    assertEquals(0.0, gain, 0.001);
060            }
061    
062    }