/*
 * Created on Apr 9, 2005
 *
 */
package aima.test.learningtest;

import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.List;

import junit.framework.TestCase;
import aima.learning.framework.DataSet;
import aima.learning.framework.DataSetFactory;
import aima.learning.framework.DataSetSpecification;
import aima.learning.framework.Example;
import aima.learning.neural.IrisDataSetNumerizer;
import aima.learning.neural.Numerizer;
import aima.util.Pair;

/**
 * Tests for loading data sets via {@link DataSetFactory}, non-destructive
 * example removal, and numerizing/denumerizing iris examples with
 * {@link IrisDataSetNumerizer}.
 *
 * @author Ravi Mohan
 *
 */
public class DataSetTest extends TestCase {
	private static final String NO = "No";

	private static final String YES = "Yes";

	// Not used by the tests in this class; kept so existing references
	// (if any) in the same package still compile.
	DataSetSpecification spec;

	/**
	 * The restaurant data set loads 12 examples, and the first example has
	 * the expected attribute and target values.
	 */
	public void testLoadsDatasetFile() throws Exception {

		DataSet ds = DataSetFactory.getRestaurantDataSet();
		assertEquals(12, ds.size());

		Example first = ds.getExample(0);
		assertEquals(YES, first.getAttributeValueAsString("alternate"));
		assertEquals("$$$", first.getAttributeValueAsString("price"));
		assertEquals("0-10", first.getAttributeValueAsString("wait_estimate"));
		assertEquals(YES, first.getAttributeValueAsString("will_wait"));
		assertEquals(YES, first.targetValue());
	}

	/**
	 * Loading a file that does not exist must throw. The check accepts any
	 * {@link Exception}. {@code fail()} raises an {@code Error}, so it is not
	 * swallowed by this catch.
	 */
	public void testThrowsExceptionForNonExistentFile()
			throws FileNotFoundException {
		try {
			new DataSetFactory().fromFile("nonexistent", null, null);
			fail("should have thrown Exception");
		} catch (Exception expected) {
			// Expected: the loader cannot open "nonexistent".
		}
	}

	/** The first iris example exposes its numeric attribute as a string. */
	public void testLoadsIrisDataSetWithNumericAndStringAttributes()
			throws Exception {
		DataSet ds = DataSetFactory.getIrisDataSet();
		Example first = ds.getExample(0);
		assertEquals("5.1", first.getAttributeValueAsString("sepal_length"));
	}

	/**
	 * {@link DataSet#removeExample} returns a new, smaller data set and leaves
	 * the original untouched.
	 */
	public void testNonDestructiveRemoveExample() throws Exception {
		DataSet original = DataSetFactory.getRestaurantDataSet();
		DataSet reduced = original.removeExample(original.getExample(0));
		assertEquals(12, original.size());
		assertEquals(11, reduced.size());
	}

	public void testNumerizesAndDeNumerizesIrisDataSetExample1()
			throws Exception {
		assertIrisRoundTrip(0, Arrays.asList(5.1, 3.5, 1.4, 0.2),
				Arrays.asList(0.0, 0.0, 1.0), "setosa");
	}

	public void testNumerizesAndDeNumerizesIrisDataSetExample2()
			throws Exception {
		assertIrisRoundTrip(51, Arrays.asList(6.4, 3.2, 4.5, 1.5),
				Arrays.asList(0.0, 1.0, 0.0), "versicolor");
	}

	public void testNumerizesAndDeNumerizesIrisDataSetExample3()
			throws Exception {
		assertIrisRoundTrip(100, Arrays.asList(6.3, 3.3, 6.0, 2.5),
				Arrays.asList(1.0, 0.0, 0.0), "virginica");
	}

	/**
	 * Numerizes the iris example at {@code index} and checks both vectors.
	 * It then checks that denumerizing the expected output vector yields
	 * {@code expectedCategory}.
	 *
	 * @param index
	 *            position of the example in the iris data set
	 * @param expectedInputs
	 *            expected numeric attribute vector (first element of the pair)
	 * @param expectedOutputs
	 *            expected numeric target vector (second element of the pair)
	 * @param expectedCategory
	 *            plant category the target vector should denumerize to
	 */
	private void assertIrisRoundTrip(int index, List<Double> expectedInputs,
			List<Double> expectedOutputs, String expectedCategory)
			throws Exception {
		DataSet ds = DataSetFactory.getIrisDataSet();
		Example example = ds.getExample(index);
		Numerizer numerizer = new IrisDataSetNumerizer();
		Pair<List<Double>, List<Double>> io = numerizer.numerize(example);

		assertEquals(expectedInputs, io.getFirst());
		assertEquals(expectedOutputs, io.getSecond());

		String plantCategory = numerizer.denumerize(expectedOutputs);
		assertEquals(expectedCategory, plantCategory);
	}

}