001    /*
002     * Created on Apr 9, 2005
003     *
004     */
005    package aima.test.learningtest;
006    
007    import java.io.FileNotFoundException;
008    import java.util.Arrays;
009    import java.util.List;
010    
011    import junit.framework.TestCase;
012    import aima.learning.framework.DataSet;
013    import aima.learning.framework.DataSetFactory;
014    import aima.learning.framework.DataSetSpecification;
015    import aima.learning.framework.Example;
016    import aima.learning.neural.IrisDataSetNumerizer;
017    import aima.learning.neural.Numerizer;
018    import aima.util.Pair;
019    
020    /**
021     * @author Ravi Mohan
022     * 
023     */
024    public class DataSetTest extends TestCase {
025            private static final String NO = "No";
026    
027            private static final String YES = "Yes";
028    
029            DataSetSpecification spec;
030    
031            public void testLoadsDatasetFile() throws Exception {
032    
033                    DataSet ds = DataSetFactory.getRestaurantDataSet();
034                    assertEquals(12, ds.size());
035    
036                    Example first = ds.getExample(0);
037                    assertEquals(YES, first.getAttributeValueAsString("alternate"));
038                    assertEquals("$$$", first.getAttributeValueAsString("price"));
039                    assertEquals("0-10", first.getAttributeValueAsString("wait_estimate"));
040                    assertEquals(YES, first.getAttributeValueAsString("will_wait"));
041                    assertEquals(YES, first.targetValue());
042            }
043    
044            public void testThrowsExceptionForNonExistentFile()
045                            throws FileNotFoundException {
046                    try {
047                            DataSet ds = new DataSetFactory().fromFile("nonexistent", null,
048                                            null);
049                            fail("should have thrown Exception");
050                    } catch (Exception ex) {
051    
052                    }
053    
054            }
055    
056            public void testLoadsIrisDataSetWithNumericAndStringAttributes()
057                            throws Exception {
058                    DataSet ds = DataSetFactory.getIrisDataSet();
059                    Example first = ds.getExample(0);
060                    assertEquals("5.1", first.getAttributeValueAsString("sepal_length"));
061            }
062    
063            public void testNonDestructiveRemoveExample() throws Exception {
064                    DataSet ds1 = DataSetFactory.getRestaurantDataSet();
065                    DataSet ds2 = ds1.removeExample(ds1.getExample(0));
066                    assertEquals(12, ds1.size());
067                    assertEquals(11, ds2.size());
068            }
069    
070            public void testNumerizesAndDeNumerizesIrisDataSetExample1()
071                            throws Exception {
072                    DataSet ds = DataSetFactory.getIrisDataSet();
073                    Example first = ds.getExample(0);
074                    Numerizer n = new IrisDataSetNumerizer();
075                    Pair<List<Double>, List<Double>> io = n.numerize(first);
076    
077                    assertEquals(Arrays.asList(5.1, 3.5, 1.4, 0.2), io.getFirst());
078                    assertEquals(Arrays.asList(0.0, 0.0, 1.0), io.getSecond());
079    
080                    String plant_category = n.denumerize(Arrays.asList(0.0, 0.0, 1.0));
081                    assertEquals("setosa", plant_category);
082            }
083    
084            public void testNumerizesAndDeNumerizesIrisDataSetExample2()
085                            throws Exception {
086                    DataSet ds = DataSetFactory.getIrisDataSet();
087                    Example first = ds.getExample(51);
088                    Numerizer n = new IrisDataSetNumerizer();
089                    Pair<List<Double>, List<Double>> io = n.numerize(first);
090    
091                    assertEquals(Arrays.asList(6.4, 3.2, 4.5, 1.5), io.getFirst());
092                    assertEquals(Arrays.asList(0.0, 1.0, 0.0), io.getSecond());
093    
094                    String plant_category = n.denumerize(Arrays.asList(0.0, 1.0, 0.0));
095                    assertEquals("versicolor", plant_category);
096            }
097    
098            public void testNumerizesAndDeNumerizesIrisDataSetExample3()
099                            throws Exception {
100                    DataSet ds = DataSetFactory.getIrisDataSet();
101                    Example first = ds.getExample(100);
102                    Numerizer n = new IrisDataSetNumerizer();
103                    Pair<List<Double>, List<Double>> io = n.numerize(first);
104    
105                    assertEquals(Arrays.asList(6.3, 3.3, 6.0, 2.5), io.getFirst());
106                    assertEquals(Arrays.asList(1.0, 0.0, 0.0), io.getSecond());
107    
108                    String plant_category = n.denumerize(Arrays.asList(1.0, 0.0, 0.0));
109                    assertEquals("virginica", plant_category);
110            }
111    
112    }