001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.stat.regression;
018    
019    import org.apache.commons.math.linear.Array2DRowRealMatrix;
020    import org.apache.commons.math.linear.LUDecompositionImpl;
021    import org.apache.commons.math.linear.QRDecomposition;
022    import org.apache.commons.math.linear.QRDecompositionImpl;
023    import org.apache.commons.math.linear.RealMatrix;
024    import org.apache.commons.math.linear.RealVector;
025    
026    /**
027     * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
028     * multiple linear regression model.</p>
029     *
030     * <p>OLS assumes the covariance matrix of the error to be diagonal and with
031     * equal variance.</p>
032     * <p>
033     * u ~ N(0, &sigma;<sup>2</sup>I)
034     * </p>
035     *
036     * <p>The regression coefficients, b, satisfy the normal equations:
037     * <p>
038     * X<sup>T</sup> X b = X<sup>T</sup> y
039     * </p>
040     *
041     * <p>To solve the normal equations, this implementation uses QR decomposition
042     * of the X matrix. (See {@link QRDecompositionImpl} for details on the
043     * decomposition algorithm.)
044     * </p>
045     * <p>X<sup>T</sup>X b = X<sup>T</sup> y <br/>
046     * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y <br/>
047     * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
048     * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
049     * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
050     * R b = Q<sup>T</sup> y
051     * </p>
 052     * Given Q and R, the last equation is solved by back-substitution.</p>
053     *
054     * @version $Revision: 825925 $ $Date: 2009-10-16 11:11:47 -0400 (Fri, 16 Oct 2009) $
055     * @since 2.0
056     */
057    public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
058    
059        /** Cached QR decomposition of X matrix */
060        private QRDecomposition qr = null;
061    
062        /**
063         * Loads model x and y sample data, overriding any previous sample.
064         *
065         * Computes and caches QR decomposition of the X matrix.
066         * @param y the [n,1] array representing the y sample
067         * @param x the [n,k] array representing the x sample
068         * @throws IllegalArgumentException if the x and y array data are not
069         *             compatible for the regression
070         */
071        public void newSampleData(double[] y, double[][] x) {
072            validateSampleData(x, y);
073            newYSampleData(y);
074            newXSampleData(x);
075        }
076    
077        /**
078         * {@inheritDoc}
079         *
080         * Computes and caches QR decomposition of the X matrix
081         */
082        @Override
083        public void newSampleData(double[] data, int nobs, int nvars) {
084            super.newSampleData(data, nobs, nvars);
085            qr = new QRDecompositionImpl(X);
086        }
087    
088        /**
089         * <p>Compute the "hat" matrix.
090         * </p>
091         * <p>The hat matrix is defined in terms of the design matrix X
092         *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
093         * </p>
094         * <p>The implementation here uses the QR decomposition to compute the
095         * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
096         * p-dimensional identity matrix augmented by 0's.  This computational
097         * formula is from "The Hat Matrix in Regression and ANOVA",
098         * David C. Hoaglin and Roy E. Welsch,
099         * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
100         *
101         * @return the hat matrix
102         */
103        public RealMatrix calculateHat() {
104            // Create augmented identity matrix
105            RealMatrix Q = qr.getQ();
106            final int p = qr.getR().getColumnDimension();
107            final int n = Q.getColumnDimension();
108            Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
109            double[][] augIData = augI.getDataRef();
110            for (int i = 0; i < n; i++) {
111                for (int j =0; j < n; j++) {
112                    if (i == j && i < p) {
113                        augIData[i][j] = 1d;
114                    } else {
115                        augIData[i][j] = 0d;
116                    }
117                }
118            }
119    
120            // Compute and return Hat matrix
121            return Q.multiply(augI).multiply(Q.transpose());
122        }
123    
124        /**
125         * Loads new x sample data, overriding any previous sample
126         *
127         * @param x the [n,k] array representing the x sample
128         */
129        @Override
130        protected void newXSampleData(double[][] x) {
131            this.X = new Array2DRowRealMatrix(x);
132            qr = new QRDecompositionImpl(X);
133        }
134    
135        /**
136         * Calculates regression coefficients using OLS.
137         *
138         * @return beta
139         */
140        @Override
141        protected RealVector calculateBeta() {
142            return qr.getSolver().solve(Y);
143        }
144    
145        /**
146         * <p>Calculates the variance on the beta by OLS.
147         * </p>
148         * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
149         * </p>
150         * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
151         * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
152         * R included, where p = the length of the beta vector.</p>
153         *
154         * @return The beta variance
155         */
156        @Override
157        protected RealMatrix calculateBetaVariance() {
158            int p = X.getColumnDimension();
159            RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
160            RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
161            return Rinv.multiply(Rinv.transpose());
162        }
163    
164    
165        /**
166         * <p>Calculates the variance on the Y by OLS.
167         * </p>
168         * <p> Var(y) = Tr(u<sup>T</sup>u)/(n - k)
169         * </p>
170         * @return The Y variance
171         */
172        @Override
173        protected double calculateYVariance() {
174            RealVector residuals = calculateResiduals();
175            return residuals.dotProduct(residuals) /
176                   (X.getRowDimension() - X.getColumnDimension());
177        }
178    
179    }