001    package aima.util;
002    
003    /**
004     * LU Decomposition.
005     * <P>
006     * For an m-by-n matrix A with m >= n, the LU decomposition is an m-by-n unit
007     * lower triangular matrix L, an n-by-n upper triangular matrix U, and a
008     * permutation vector piv of length m so that A(piv,:) = L*U. If m < n, then L
009     * is m-by-m and U is m-by-n.
010     * <P>
011     * The LU decompostion with pivoting always exists, even if the matrix is
012     * singular, so the constructor will never fail. The primary use of the LU
013     * decomposition is in the solution of square systems of simultaneous linear
014     * equations. This will fail if isNonsingular() returns false.
015     */
016    
017    public class LUDecomposition implements java.io.Serializable {
018    
019            /*
020             * ------------------------ Class variables ------------------------
021             */
022    
023            /**
024             * Array for internal storage of decomposition.
025             * 
026             * @serial internal array storage.
027             */
028            private final double[][] LU;
029    
030            /**
031             * Row and column dimensions, and pivot sign.
032             * 
033             * @serial column dimension.
034             * @serial row dimension.
035             * @serial pivot sign.
036             */
037            private final int m, n;
038    
039            private int pivsign;
040    
041            /**
042             * Internal storage of pivot vector.
043             * 
044             * @serial pivot vector.
045             */
046            private final int[] piv;
047    
048            /*
049             * ------------------------ Constructor ------------------------
050             */
051    
052            /**
053             * LU Decomposition, a structure to access L, U and piv.
054             * 
055             * @param A
056             *            Rectangular matrix
057             */
058    
059            public LUDecomposition(Matrix A) {
060    
061                    // Use a "left-looking", dot-product, Crout/Doolittle algorithm.
062    
063                    LU = A.getArrayCopy();
064                    m = A.getRowDimension();
065                    n = A.getColumnDimension();
066                    piv = new int[m];
067                    for (int i = 0; i < m; i++) {
068                            piv[i] = i;
069                    }
070                    pivsign = 1;
071                    double[] LUrowi;
072                    double[] LUcolj = new double[m];
073    
074                    // Outer loop.
075    
076                    for (int j = 0; j < n; j++) {
077    
078                            // Make a copy of the j-th column to localize references.
079    
080                            for (int i = 0; i < m; i++) {
081                                    LUcolj[i] = LU[i][j];
082                            }
083    
084                            // Apply previous transformations.
085    
086                            for (int i = 0; i < m; i++) {
087                                    LUrowi = LU[i];
088    
089                                    // Most of the time is spent in the following dot product.
090    
091                                    int kmax = Math.min(i, j);
092                                    double s = 0.0;
093                                    for (int k = 0; k < kmax; k++) {
094                                            s += LUrowi[k] * LUcolj[k];
095                                    }
096    
097                                    LUrowi[j] = LUcolj[i] -= s;
098                            }
099    
100                            // Find pivot and exchange if necessary.
101    
102                            int p = j;
103                            for (int i = j + 1; i < m; i++) {
104                                    if (Math.abs(LUcolj[i]) > Math.abs(LUcolj[p])) {
105                                            p = i;
106                                    }
107                            }
108                            if (p != j) {
109                                    for (int k = 0; k < n; k++) {
110                                            double t = LU[p][k];
111                                            LU[p][k] = LU[j][k];
112                                            LU[j][k] = t;
113                                    }
114                                    int k = piv[p];
115                                    piv[p] = piv[j];
116                                    piv[j] = k;
117                                    pivsign = -pivsign;
118                            }
119    
120                            // Compute multipliers.
121    
122                            if (j < m & LU[j][j] != 0.0) {
123                                    for (int i = j + 1; i < m; i++) {
124                                            LU[i][j] /= LU[j][j];
125                                    }
126                            }
127                    }
128            }
129    
130            /*
131             * ------------------------ Temporary, experimental code.
132             * ------------------------ *\
133             * 
134             * \** LU Decomposition, computed by Gaussian elimination. <P> This
135             * constructor computes L and U with the "daxpy"-based elimination algorithm
136             * used in LINPACK and MATLAB. In Java, we suspect the dot-product, Crout
137             * algorithm will be faster. We have temporarily included this constructor
138             * until timing experiments confirm this suspicion. <P> @param A Rectangular
139             * matrix @param linpackflag Use Gaussian elimination. Actual value ignored.
140             * @return Structure to access L, U and piv. \
141             * 
142             * public LUDecomposition (Matrix A, int linpackflag) { // Initialize. LU =
143             * A.getArrayCopy(); m = A.getRowDimension(); n = A.getColumnDimension();
144             * piv = new int[m]; for (int i = 0; i < m; i++) { piv[i] = i; } pivsign =
145             * 1; // Main loop. for (int k = 0; k < n; k++) { // Find pivot. int p = k;
146             * for (int i = k+1; i < m; i++) { if (Math.abs(LU[i][k]) >
147             * Math.abs(LU[p][k])) { p = i; } } // Exchange if necessary. if (p != k) {
148             * for (int j = 0; j < n; j++) { double t = LU[p][j]; LU[p][j] = LU[k][j];
149             * LU[k][j] = t; } int t = piv[p]; piv[p] = piv[k]; piv[k] = t; pivsign =
150             * -pivsign; } // Compute multipliers and eliminate k-th column. if
151             * (LU[k][k] != 0.0) { for (int i = k+1; i < m; i++) { LU[i][k] /= LU[k][k];
152             * for (int j = k+1; j < n; j++) { LU[i][j] -= LU[i][k]*LU[k][j]; } } } } } \*
153             * ------------------------ End of temporary code. ------------------------
154             */
155    
156            /*
157             * ------------------------ Public Methods ------------------------
158             */
159    
160            /**
161             * Is the matrix nonsingular?
162             * 
163             * @return true if U, and hence A, is nonsingular.
164             */
165    
166            public boolean isNonsingular() {
167                    for (int j = 0; j < n; j++) {
168                            if (LU[j][j] == 0)
169                                    return false;
170                    }
171                    return true;
172            }
173    
174            /**
175             * Return lower triangular factor
176             * 
177             * @return L
178             */
179    
180            public Matrix getL() {
181                    Matrix X = new Matrix(m, n);
182                    double[][] L = X.getArray();
183                    for (int i = 0; i < m; i++) {
184                            for (int j = 0; j < n; j++) {
185                                    if (i > j) {
186                                            L[i][j] = LU[i][j];
187                                    } else if (i == j) {
188                                            L[i][j] = 1.0;
189                                    } else {
190                                            L[i][j] = 0.0;
191                                    }
192                            }
193                    }
194                    return X;
195            }
196    
197            /**
198             * Return upper triangular factor
199             * 
200             * @return U
201             */
202    
203            public Matrix getU() {
204                    Matrix X = new Matrix(n, n);
205                    double[][] U = X.getArray();
206                    for (int i = 0; i < n; i++) {
207                            for (int j = 0; j < n; j++) {
208                                    if (i <= j) {
209                                            U[i][j] = LU[i][j];
210                                    } else {
211                                            U[i][j] = 0.0;
212                                    }
213                            }
214                    }
215                    return X;
216            }
217    
218            /**
219             * Return pivot permutation vector
220             * 
221             * @return piv
222             */
223    
224            public int[] getPivot() {
225                    int[] p = new int[m];
226                    for (int i = 0; i < m; i++) {
227                            p[i] = piv[i];
228                    }
229                    return p;
230            }
231    
232            /**
233             * Return pivot permutation vector as a one-dimensional double array
234             * 
235             * @return (double) piv
236             */
237    
238            public double[] getDoublePivot() {
239                    double[] vals = new double[m];
240                    for (int i = 0; i < m; i++) {
241                            vals[i] = piv[i];
242                    }
243                    return vals;
244            }
245    
246            /**
247             * Determinant
248             * 
249             * @return det(A)
250             * @exception IllegalArgumentException
251             *                Matrix must be square
252             */
253    
254            public double det() {
255                    if (m != n) {
256                            throw new IllegalArgumentException("Matrix must be square.");
257                    }
258                    double d = pivsign;
259                    for (int j = 0; j < n; j++) {
260                            d *= LU[j][j];
261                    }
262                    return d;
263            }
264    
265            /**
266             * Solve A*X = B
267             * 
268             * @param B
269             *            A Matrix with as many rows as A and any number of columns.
270             * @return X so that L*U*X = B(piv,:)
271             * @exception IllegalArgumentException
272             *                Matrix row dimensions must agree.
273             * @exception RuntimeException
274             *                Matrix is singular.
275             */
276    
277            public Matrix solve(Matrix B) {
278                    if (B.getRowDimension() != m) {
279                            throw new IllegalArgumentException(
280                                            "Matrix row dimensions must agree.");
281                    }
282                    if (!this.isNonsingular()) {
283                            throw new RuntimeException("Matrix is singular.");
284                    }
285    
286                    // Copy right hand side with pivoting
287                    int nx = B.getColumnDimension();
288                    Matrix Xmat = B.getMatrix(piv, 0, nx - 1);
289                    double[][] X = Xmat.getArray();
290    
291                    // Solve L*Y = B(piv,:)
292                    for (int k = 0; k < n; k++) {
293                            for (int i = k + 1; i < n; i++) {
294                                    for (int j = 0; j < nx; j++) {
295                                            X[i][j] -= X[k][j] * LU[i][k];
296                                    }
297                            }
298                    }
299                    // Solve U*X = Y;
300                    for (int k = n - 1; k >= 0; k--) {
301                            for (int j = 0; j < nx; j++) {
302                                    X[k][j] /= LU[k][k];
303                            }
304                            for (int i = 0; i < k; i++) {
305                                    for (int j = 0; j < nx; j++) {
306                                            X[i][j] -= X[k][j] * LU[i][k];
307                                    }
308                            }
309                    }
310                    return Xmat;
311            }
312    }