/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math.stat.regression;

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;

/**
 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
 * <p>OLS assumes that the covariance matrix of the error vector is diagonal,
 * with equal variances on the diagonal:</p>
 * <p>
 * u ~ N(0, σ<sup>2</sup>I)
 * </p>
 *
 * <p>The regression coefficients, b, satisfy the normal equations:</p>
 * <p>
 * X<sup>T</sup> X b = X<sup>T</sup> y
 * </p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
 * of the X matrix. (See {@link QRDecompositionImpl} for details on the
 * decomposition algorithm.)
 * </p>
 * <p>X<sup>T</sup> X b = X<sup>T</sup> y <br/>
 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup> y <br/>
 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
 * R b = Q<sup>T</sup> y
 * </p>
 * <p>Given Q and R, the last equation is solved by back-substitution.</p>
 *
 * @version $Revision: 825925 $ $Date: 2009-10-16 11:11:47 -0400 (Fri, 16 Oct 2009) $
 * @since 2.0
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Cached QR decomposition of the X matrix */
    private QRDecomposition qr = null;

    /**
     * Loads model x and y sample data, overriding any previous sample.
     *
     * Computes and caches the QR decomposition of the X matrix.
     * @param y the [n,1] array representing the y sample
     * @param x the [n,k] array representing the x sample
     * @throws IllegalArgumentException if the x and y array data are not
     *         compatible for the regression
     */
    public void newSampleData(double[] y, double[][] x) {
        validateSampleData(x, y);
        newYSampleData(y);
        newXSampleData(x);
    }

    /**
     * {@inheritDoc}
     *
     * Computes and caches the QR decomposition of the X matrix.
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars) {
        super.newSampleData(data, nobs, nvars);
        qr = new QRDecompositionImpl(X);
    }
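    /*
     * Usage sketch (an illustrative addition, not part of the original
     * class): fitting a two-parameter model y = b0 + b1*x1 on hypothetical
     * data. Note that in this version of the library the design matrix is
     * used exactly as supplied, so an intercept term requires an explicit
     * column of 1s in x.
     *
     *   OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
     *   double[] y = new double[] {3d, 5d, 7d, 9d, 11d};
     *   double[][] x = new double[][] {
     *       {1d, 1d},
     *       {1d, 2d},
     *       {1d, 3d},
     *       {1d, 4d},
     *       {1d, 5d}
     *   };
     *   regression.newSampleData(y, x);  // validates the data and caches QR(X)
     *   double[] beta = regression.estimateRegressionParameters();  // ~ {1, 2}
     *   double[] residuals = regression.estimateResiduals();        // ~ all 0
     */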
    /**
     * <p>Compute the "hat" matrix.
     * </p>
     * <p>The hat matrix is defined in terms of the design matrix X
     * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>.
     * </p>
     * <p>The implementation here uses the QR decomposition to compute the
     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
     * p-dimensional identity matrix augmented by 0's. This computational
     * formula is from "The Hat Matrix in Regression and ANOVA",
     * David C. Hoaglin and Roy E. Welsch,
     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
     * </p>
     *
     * @return the hat matrix
     */
    public RealMatrix calculateHat() {
        // Create the augmented identity matrix: I_p in the upper-left corner,
        // zeros everywhere else
        RealMatrix Q = qr.getQ();
        final int p = qr.getR().getColumnDimension();
        final int n = Q.getColumnDimension();
        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
        double[][] augIData = augI.getDataRef();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j && i < p) {
                    augIData[i][j] = 1d;
                } else {
                    augIData[i][j] = 0d;
                }
            }
        }

        // Compute and return the hat matrix
        return Q.multiply(augI).multiply(Q.transpose());
    }

    /**
     * Loads new x sample data, overriding any previous sample.
     *
     * Computes and caches the QR decomposition of the X matrix.
     * @param x the [n,k] array representing the x sample
     */
    @Override
    protected void newXSampleData(double[][] x) {
        this.X = new Array2DRowRealMatrix(x);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * Calculates the regression coefficients using OLS.
     *
     * @return beta
     */
    @Override
    protected RealVector calculateBeta() {
        return qr.getSolver().solve(Y);
    }

    /**
     * <p>Calculates the variance of the beta estimates by OLS.
     * </p>
     * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
     * </p>
     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
     * R included, where p = the length of the beta vector.</p>
     *
     * @return The beta variance
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        int p = X.getColumnDimension();
        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1, 0, p - 1);
        RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
        return Rinv.multiply(Rinv.transpose());
    }

    /**
     * <p>Calculates the variance of the y values by OLS.
     * </p>
     * <p>Var(y) = Tr(u<sup>T</sup>u)/(n - k)
     * </p>
     * @return The Y variance
     */
    @Override
    protected double calculateYVariance() {
        RealVector residuals = calculateResiduals();
        return residuals.dotProduct(residuals) /
               (X.getRowDimension() - X.getColumnDimension());
    }

}
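/*
 * Verification sketch (an illustrative addition, not part of the original
 * source): the class javadoc reduces the normal equations to R b = Q^T y,
 * and calculateBeta() above delegates that back-substitution to the QR
 * solver. The same reduced system can be solved directly with the library's
 * linear algebra classes (all from org.apache.commons.math.linear); the data
 * values below are hypothetical.
 *
 *   RealMatrix X = new Array2DRowRealMatrix(new double[][] {
 *       {1d, 1d}, {1d, 2d}, {1d, 3d}, {1d, 4d}, {1d, 5d}
 *   });
 *   RealVector y = new ArrayRealVector(new double[] {3d, 5d, 7d, 9d, 11d});
 *
 *   // The QR solver minimizes ||y - X b||, i.e. it solves R b = Q^T y by
 *   // back-substitution, so b should come out as approximately {1, 2}.
 *   QRDecomposition qr = new QRDecompositionImpl(X);
 *   RealVector b = qr.getSolver().solve(y);
 */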