001    /*
002     * Created on Aug 6, 2005
003     *
004     */
005    package aima.learning.neural;
006    
007    import java.util.ArrayList;
008    import java.util.List;
009    import java.util.Arrays;
010    
011    import aima.learning.framework.Example;
012    import aima.util.Pair;
013    
014    /**
015     * @author Ravi Mohan
016     * 
017     */
018    
019    public class IrisDataSetNumerizer implements Numerizer {
020    
021            public Pair<List<Double>, List<Double>> numerize(Example e) {
022                    List<Double> input = new ArrayList<Double>();
023                    List<Double> desiredOutput = new ArrayList<Double>();
024    
025                    double sepal_length = e.getAttributeValueAsDouble("sepal_length");
026                    double sepal_width = e.getAttributeValueAsDouble("sepal_width");
027                    double petal_length = e.getAttributeValueAsDouble("petal_length");
028                    double petal_width = e.getAttributeValueAsDouble("petal_width");
029    
030                    input.add(sepal_length);
031                    input.add(sepal_width);
032                    input.add(petal_length);
033                    input.add(petal_width);
034    
035                    String plant_category_string = e
036                                    .getAttributeValueAsString("plant_category");
037    
038                    desiredOutput = convertCategoryToListOfDoubles(plant_category_string);
039    
040                    Pair<List<Double>, List<Double>> io = new Pair<List<Double>, List<Double>>(
041                                    input, desiredOutput);
042    
043                    return io;
044            }
045    
046            public String denumerize(List<Double> outputValue) {
047                    List<Double> rounded = new ArrayList<Double>();
048                    for (Double d : outputValue) {
049                            rounded.add(round(d));
050                    }
051                    if (rounded.equals(Arrays.asList(0.0, 0.0, 1.0))) {
052                            return "setosa";
053                    } else if (rounded.equals(Arrays.asList(0.0, 1.0, 0.0))) {
054                            return "versicolor";
055                    } else if (rounded.equals(Arrays.asList(1.0, 0.0, 0.0))) {
056                            return "virginica";
057                    } else {
058                            return "unknown";
059                    }
060            }
061    
062            private double round(Double d) {
063                    if (d < 0) {
064                            return 0.0;
065                    }
066                    if (d > 1) {
067                            return 1.0;
068                    } else {
069                            return Math.round(d);
070                    }
071            }
072    
073            private List<Double> convertCategoryToListOfDoubles(
074                            String plant_category_string) {
075                    if (plant_category_string.equals("setosa")) {
076                            return Arrays.asList(0.0, 0.0, 1.0);
077                    } else if (plant_category_string.equals("versicolor")) {
078                            return Arrays.asList(0.0, 1.0, 0.0);
079                    } else if (plant_category_string.equals("virginica")) {
080                            return Arrays.asList(1.0, 0.0, 0.0);
081                    } else {
082                            throw new RuntimeException("invalid plant category");
083                    }
084            }
085    
086    }