001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.linear;
019    
020    import org.apache.commons.math.MathRuntimeException;
021    
022    /**
023     * Calculates the LUP-decomposition of a square matrix.
024     * <p>The LUP-decomposition of a matrix A consists of three matrices
025     * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
026     * upper triangular and P is a permutation matrix. All matrices are
027     * m&times;m.</p>
028     * <p>As shown by the presence of the P matrix, this decomposition is
029     * implemented using partial pivoting.</p>
030     *
031     * @version $Revision: 885278 $ $Date: 2009-11-29 16:47:51 -0500 (Sun, 29 Nov 2009) $
032     * @since 2.0
033     */
034    public class LUDecompositionImpl implements LUDecomposition {
035    
036        /** Default bound to determine effective singularity in LU decomposition */
037        private static final double DEFAULT_TOO_SMALL = 10E-12;
038    
039        /** Message for vector length mismatch. */
040        private static final String VECTOR_LENGTH_MISMATCH_MESSAGE =
041            "vector length mismatch: got {0} but expected {1}";
042    
043        /** Entries of LU decomposition. */
044        private double lu[][];
045    
046        /** Pivot permutation associated with LU decomposition */
047        private int[] pivot;
048    
049        /** Parity of the permutation associated with the LU decomposition */
050        private boolean even;
051    
052        /** Singularity indicator. */
053        private boolean singular;
054    
055        /** Cached value of L. */
056        private RealMatrix cachedL;
057    
058        /** Cached value of U. */
059        private RealMatrix cachedU;
060    
061        /** Cached value of P. */
062        private RealMatrix cachedP;
063    
064        /**
065         * Calculates the LU-decomposition of the given matrix.
066         * @param matrix The matrix to decompose.
067         * @exception InvalidMatrixException if matrix is not square
068         */
069        public LUDecompositionImpl(RealMatrix matrix)
070            throws InvalidMatrixException {
071            this(matrix, DEFAULT_TOO_SMALL);
072        }
073    
074        /**
075         * Calculates the LU-decomposition of the given matrix.
076         * @param matrix The matrix to decompose.
077         * @param singularityThreshold threshold (based on partial row norm)
078         * under which a matrix is considered singular
079         * @exception NonSquareMatrixException if matrix is not square
080         */
081        public LUDecompositionImpl(RealMatrix matrix, double singularityThreshold)
082            throws NonSquareMatrixException {
083    
084            if (!matrix.isSquare()) {
085                throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
086            }
087    
088            final int m = matrix.getColumnDimension();
089            lu = matrix.getData();
090            pivot = new int[m];
091            cachedL = null;
092            cachedU = null;
093            cachedP = null;
094    
095            // Initialize permutation array and parity
096            for (int row = 0; row < m; row++) {
097                pivot[row] = row;
098            }
099            even     = true;
100            singular = false;
101    
102            // Loop over columns
103            for (int col = 0; col < m; col++) {
104    
105                double sum = 0;
106    
107                // upper
108                for (int row = 0; row < col; row++) {
109                    final double[] luRow = lu[row];
110                    sum = luRow[col];
111                    for (int i = 0; i < row; i++) {
112                        sum -= luRow[i] * lu[i][col];
113                    }
114                    luRow[col] = sum;
115                }
116    
117                // lower
118                int max = col; // permutation row
119                double largest = Double.NEGATIVE_INFINITY;
120                for (int row = col; row < m; row++) {
121                    final double[] luRow = lu[row];
122                    sum = luRow[col];
123                    for (int i = 0; i < col; i++) {
124                        sum -= luRow[i] * lu[i][col];
125                    }
126                    luRow[col] = sum;
127    
128                    // maintain best permutation choice
129                    if (Math.abs(sum) > largest) {
130                        largest = Math.abs(sum);
131                        max = row;
132                    }
133                }
134    
135                // Singularity check
136                if (Math.abs(lu[max][col]) < singularityThreshold) {
137                    singular = true;
138                    return;
139                }
140    
141                // Pivot if necessary
142                if (max != col) {
143                    double tmp = 0;
144                    final double[] luMax = lu[max];
145                    final double[] luCol = lu[col];
146                    for (int i = 0; i < m; i++) {
147                        tmp = luMax[i];
148                        luMax[i] = luCol[i];
149                        luCol[i] = tmp;
150                    }
151                    int temp = pivot[max];
152                    pivot[max] = pivot[col];
153                    pivot[col] = temp;
154                    even = !even;
155                }
156    
157                // Divide the lower elements by the "winning" diagonal elt.
158                final double luDiag = lu[col][col];
159                for (int row = col + 1; row < m; row++) {
160                    lu[row][col] /= luDiag;
161                }
162            }
163    
164        }
165    
166        /** {@inheritDoc} */
167        public RealMatrix getL() {
168            if ((cachedL == null) && !singular) {
169                final int m = pivot.length;
170                cachedL = MatrixUtils.createRealMatrix(m, m);
171                for (int i = 0; i < m; ++i) {
172                    final double[] luI = lu[i];
173                    for (int j = 0; j < i; ++j) {
174                        cachedL.setEntry(i, j, luI[j]);
175                    }
176                    cachedL.setEntry(i, i, 1.0);
177                }
178            }
179            return cachedL;
180        }
181    
182        /** {@inheritDoc} */
183        public RealMatrix getU() {
184            if ((cachedU == null) && !singular) {
185                final int m = pivot.length;
186                cachedU = MatrixUtils.createRealMatrix(m, m);
187                for (int i = 0; i < m; ++i) {
188                    final double[] luI = lu[i];
189                    for (int j = i; j < m; ++j) {
190                        cachedU.setEntry(i, j, luI[j]);
191                    }
192                }
193            }
194            return cachedU;
195        }
196    
197        /** {@inheritDoc} */
198        public RealMatrix getP() {
199            if ((cachedP == null) && !singular) {
200                final int m = pivot.length;
201                cachedP = MatrixUtils.createRealMatrix(m, m);
202                for (int i = 0; i < m; ++i) {
203                    cachedP.setEntry(i, pivot[i], 1.0);
204                }
205            }
206            return cachedP;
207        }
208    
209        /** {@inheritDoc} */
210        public int[] getPivot() {
211            return pivot.clone();
212        }
213    
214        /** {@inheritDoc} */
215        public double getDeterminant() {
216            if (singular) {
217                return 0;
218            } else {
219                final int m = pivot.length;
220                double determinant = even ? 1 : -1;
221                for (int i = 0; i < m; i++) {
222                    determinant *= lu[i][i];
223                }
224                return determinant;
225            }
226        }
227    
228        /** {@inheritDoc} */
229        public DecompositionSolver getSolver() {
230            return new Solver(lu, pivot, singular);
231        }
232    
233        /** Specialized solver. */
234        private static class Solver implements DecompositionSolver {
235    
236            /** Entries of LU decomposition. */
237            private final double lu[][];
238    
239            /** Pivot permutation associated with LU decomposition. */
240            private final int[] pivot;
241    
242            /** Singularity indicator. */
243            private final boolean singular;
244    
245            /**
246             * Build a solver from decomposed matrix.
247             * @param lu entries of LU decomposition
248             * @param pivot pivot permutation associated with LU decomposition
249             * @param singular singularity indicator
250             */
251            private Solver(final double[][] lu, final int[] pivot, final boolean singular) {
252                this.lu       = lu;
253                this.pivot    = pivot;
254                this.singular = singular;
255            }
256    
257            /** {@inheritDoc} */
258            public boolean isNonSingular() {
259                return !singular;
260            }
261    
262            /** {@inheritDoc} */
263            public double[] solve(double[] b)
264                throws IllegalArgumentException, InvalidMatrixException {
265    
266                final int m = pivot.length;
267                if (b.length != m) {
268                    throw MathRuntimeException.createIllegalArgumentException(
269                            VECTOR_LENGTH_MISMATCH_MESSAGE, b.length, m);
270                }
271                if (singular) {
272                    throw new SingularMatrixException();
273                }
274    
275                final double[] bp = new double[m];
276    
277                // Apply permutations to b
278                for (int row = 0; row < m; row++) {
279                    bp[row] = b[pivot[row]];
280                }
281    
282                // Solve LY = b
283                for (int col = 0; col < m; col++) {
284                    final double bpCol = bp[col];
285                    for (int i = col + 1; i < m; i++) {
286                        bp[i] -= bpCol * lu[i][col];
287                    }
288                }
289    
290                // Solve UX = Y
291                for (int col = m - 1; col >= 0; col--) {
292                    bp[col] /= lu[col][col];
293                    final double bpCol = bp[col];
294                    for (int i = 0; i < col; i++) {
295                        bp[i] -= bpCol * lu[i][col];
296                    }
297                }
298    
299                return bp;
300    
301            }
302    
303            /** {@inheritDoc} */
304            public RealVector solve(RealVector b)
305                throws IllegalArgumentException, InvalidMatrixException {
306                try {
307                    return solve((ArrayRealVector) b);
308                } catch (ClassCastException cce) {
309    
310                    final int m = pivot.length;
311                    if (b.getDimension() != m) {
312                        throw MathRuntimeException.createIllegalArgumentException(
313                                VECTOR_LENGTH_MISMATCH_MESSAGE, b.getDimension(), m);
314                    }
315                    if (singular) {
316                        throw new SingularMatrixException();
317                    }
318    
319                    final double[] bp = new double[m];
320    
321                    // Apply permutations to b
322                    for (int row = 0; row < m; row++) {
323                        bp[row] = b.getEntry(pivot[row]);
324                    }
325    
326                    // Solve LY = b
327                    for (int col = 0; col < m; col++) {
328                        final double bpCol = bp[col];
329                        for (int i = col + 1; i < m; i++) {
330                            bp[i] -= bpCol * lu[i][col];
331                        }
332                    }
333    
334                    // Solve UX = Y
335                    for (int col = m - 1; col >= 0; col--) {
336                        bp[col] /= lu[col][col];
337                        final double bpCol = bp[col];
338                        for (int i = 0; i < col; i++) {
339                            bp[i] -= bpCol * lu[i][col];
340                        }
341                    }
342    
343                    return new ArrayRealVector(bp, false);
344    
345                }
346            }
347    
348            /** Solve the linear equation A &times; X = B.
349             * <p>The A matrix is implicit here. It is </p>
350             * @param b right-hand side of the equation A &times; X = B
351             * @return a vector X such that A &times; X = B
352             * @exception IllegalArgumentException if matrices dimensions don't match
353             * @exception InvalidMatrixException if decomposed matrix is singular
354             */
355            public ArrayRealVector solve(ArrayRealVector b)
356                throws IllegalArgumentException, InvalidMatrixException {
357                return new ArrayRealVector(solve(b.getDataRef()), false);
358            }
359    
360            /** {@inheritDoc} */
361            public RealMatrix solve(RealMatrix b)
362                throws IllegalArgumentException, InvalidMatrixException {
363    
364                final int m = pivot.length;
365                if (b.getRowDimension() != m) {
366                    throw MathRuntimeException.createIllegalArgumentException(
367                            "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
368                            b.getRowDimension(), b.getColumnDimension(), m, "n");
369                }
370                if (singular) {
371                    throw new SingularMatrixException();
372                }
373    
374                final int nColB = b.getColumnDimension();
375    
376                // Apply permutations to b
377                final double[][] bp = new double[m][nColB];
378                for (int row = 0; row < m; row++) {
379                    final double[] bpRow = bp[row];
380                    final int pRow = pivot[row];
381                    for (int col = 0; col < nColB; col++) {
382                        bpRow[col] = b.getEntry(pRow, col);
383                    }
384                }
385    
386                // Solve LY = b
387                for (int col = 0; col < m; col++) {
388                    final double[] bpCol = bp[col];
389                    for (int i = col + 1; i < m; i++) {
390                        final double[] bpI = bp[i];
391                        final double luICol = lu[i][col];
392                        for (int j = 0; j < nColB; j++) {
393                            bpI[j] -= bpCol[j] * luICol;
394                        }
395                    }
396                }
397    
398                // Solve UX = Y
399                for (int col = m - 1; col >= 0; col--) {
400                    final double[] bpCol = bp[col];
401                    final double luDiag = lu[col][col];
402                    for (int j = 0; j < nColB; j++) {
403                        bpCol[j] /= luDiag;
404                    }
405                    for (int i = 0; i < col; i++) {
406                        final double[] bpI = bp[i];
407                        final double luICol = lu[i][col];
408                        for (int j = 0; j < nColB; j++) {
409                            bpI[j] -= bpCol[j] * luICol;
410                        }
411                    }
412                }
413    
414                return new Array2DRowRealMatrix(bp, false);
415    
416            }
417    
418            /** {@inheritDoc} */
419            public RealMatrix getInverse() throws InvalidMatrixException {
420                return solve(MatrixUtils.createRealIdentityMatrix(pivot.length));
421            }
422    
423        }
424    
425    }