/*
 * Created on Aug 6, 2005
 *
 */
package aima.learning.neural;

import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;

import aima.learning.framework.Example;
import aima.util.Pair;

/**
 * Converts Iris data set {@link Example}s to and from the numeric vectors used
 * by the neural network code.
 * <p>
 * The input vector holds the four measurements {@code sepal_length},
 * {@code sepal_width}, {@code petal_length} and {@code petal_width}, in that
 * order. The {@code plant_category} attribute is encoded as a one-hot vector of
 * length three: {@code virginica = (1,0,0)}, {@code versicolor = (0,1,0)},
 * {@code setosa = (0,0,1)}.
 *
 * @author Ravi Mohan
 *
 */
public class IrisDataSetNumerizer implements Numerizer {

	/** Attribute names read as network inputs, in input-vector order. */
	private static final String[] INPUT_ATTRIBUTES = { "sepal_length",
			"sepal_width", "petal_length", "petal_width" };

	/**
	 * Plant categories, indexed by the position of the 1.0 in their one-hot
	 * encoding. This single table drives both encoding and decoding so the two
	 * directions cannot drift apart.
	 */
	private static final String[] CATEGORIES = { "virginica", "versicolor",
			"setosa" };

	/**
	 * Converts an Iris example into an (input, desired output) pair.
	 *
	 * @param e
	 *            an example carrying the four measurement attributes and a
	 *            {@code plant_category} attribute
	 * @return the measurement vector paired with the one-hot category vector
	 * @throws IllegalArgumentException
	 *             if the plant category is not one of the known categories
	 *             (including {@code null})
	 */
	public Pair<List<Double>, List<Double>> numerize(Example e) {
		List<Double> input = new ArrayList<Double>(INPUT_ATTRIBUTES.length);
		for (String attributeName : INPUT_ATTRIBUTES) {
			double value = e.getAttributeValueAsDouble(attributeName);
			input.add(value);
		}

		List<Double> desiredOutput = convertCategoryToListOfDoubles(e
				.getAttributeValueAsString("plant_category"));

		return new Pair<List<Double>, List<Double>>(input, desiredOutput);
	}

	/**
	 * Maps a network output vector back to a plant category name.
	 * <p>
	 * Each output value is clamped to [0, 1] and rounded to 0 or 1. If the
	 * resulting vector equals one category's one-hot encoding, that category's
	 * name is returned.
	 *
	 * @param outputValue
	 *            the raw network outputs
	 * @return "setosa", "versicolor" or "virginica", or "unknown" if the
	 *         rounded vector matches no category
	 */
	public String denumerize(List<Double> outputValue) {
		List<Double> rounded = new ArrayList<Double>(outputValue.size());
		for (Double d : outputValue) {
			rounded.add(round(d));
		}
		for (int i = 0; i < CATEGORIES.length; i++) {
			if (rounded.equals(oneHot(i))) {
				return CATEGORIES[i];
			}
		}
		return "unknown";
	}

	/**
	 * Clamps a value to [0, 1], then rounds it to the nearest integer (ties
	 * round up, per {@link Math#round(double)}).
	 */
	private double round(Double d) {
		if (d < 0) {
			return 0.0;
		}
		if (d > 1) {
			return 1.0;
		}
		return Math.round(d);
	}

	/**
	 * Encodes a plant category name as its one-hot vector.
	 *
	 * @param plant_category_string
	 *            the category name; may be {@code null}
	 * @return a new, mutable one-hot vector for the category
	 * @throws IllegalArgumentException
	 *             if the name is not a known category (including {@code null})
	 */
	private List<Double> convertCategoryToListOfDoubles(
			String plant_category_string) {
		for (int i = 0; i < CATEGORIES.length; i++) {
			// Compare from the constant side so a null category is rejected
			// with a clear message instead of a NullPointerException.
			if (CATEGORIES[i].equals(plant_category_string)) {
				return oneHot(i);
			}
		}
		throw new IllegalArgumentException("invalid plant category: "
				+ plant_category_string);
	}

	/**
	 * Builds a fresh one-hot vector of length {@code CATEGORIES.length} with
	 * 1.0 at {@code index} and 0.0 elsewhere. A new list is returned on every
	 * call so callers can never mutate shared state.
	 */
	private static List<Double> oneHot(int index) {
		Double[] vector = new Double[CATEGORIES.length];
		Arrays.fill(vector, 0.0);
		vector[index] = 1.0;
		return new ArrayList<Double>(Arrays.asList(vector));
	}

}