Classes | |
struct | depth_lhs< glue_type, T1 > |
Template metaprogram depth_lhs calculates the number of Glue<Tx,Ty, glue_type> instances nested in the left-hand-side argument of Glue<Tx,Ty, glue_type>, i.e. it recursively expands each Tx until the type of Tx is no longer "Glue<..,.., glue_type>" (i.e. until the "glue_type" changes). More... | |
struct | depth_lhs< glue_type, Glue< T1, T2, glue_type > > |
struct | glue_times_redirect< N > |
struct | glue_times_redirect< 3 > |
struct | glue_times_redirect< 4 > |
class | glue_times |
Class which implements the immediate multiplication of two or more matrices. More... | |
class | glue_times_diag |
Functions | |
template<typename T1 , typename T2 > | |
static void | glue_times_redirect::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X) |
template<typename T1 , typename T2 , typename T3 > | |
static void | glue_times_redirect< 3 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< T1, T2, glue_times >, T3, glue_times > &X) |
template<typename T1 , typename T2 , typename T3 , typename T4 > | |
static void | glue_times_redirect< 4 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > &X) |
template<typename T1 , typename T2 > | |
static void | glue_times::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X) |
template<typename T1 > | |
static void | glue_times::apply_inplace (Mat< typename T1::elem_type > &out, const T1 &X) |
template<typename T1 , typename T2 > | |
static arma_hot void | glue_times::apply_inplace_plus (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X, const s32 sign) |
template<typename eT1 , typename eT2 > | |
static void | glue_times::apply_mixed (Mat< typename promote_type< eT1, eT2 >::result > &out, const Mat< eT1 > &X, const Mat< eT2 > &Y) |
Matrix multiplication of two matrices with different element types; the output uses the promoted element type promote_type< eT1, eT2 >::result. | |
template<typename eT > | |
static arma_inline u32 | glue_times::mul_storage_cost (const Mat< eT > &A, const Mat< eT > &B, const bool do_trans_A, const bool do_trans_B) |
template<typename eT > | |
static arma_hot void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_scalar_times) |
template<typename eT > | |
static void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_scalar_times) |
template<typename eT > | |
static void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const Mat< eT > &D, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_trans_D, const bool do_scalar_times) |
template<typename T1 , typename T2 > | |
static arma_hot void | glue_times_diag::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times_diag > &X) |
void glue_times_redirect< N >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 26 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
Referenced by glue_times_redirect< 4 >::apply(), and glue_times_redirect< 3 >::apply().
00027 { 00028 arma_extra_debug_sigprint(); 00029 00030 typedef typename T1::elem_type eT; 00031 00032 const partial_unwrap_check<T1> tmp1(X.A, out); 00033 const partial_unwrap_check<T2> tmp2(X.B, out); 00034 00035 const Mat<eT>& A = tmp1.M; 00036 const Mat<eT>& B = tmp2.M; 00037 00038 const bool do_trans_A = tmp1.do_trans; 00039 const bool do_trans_B = tmp2.do_trans; 00040 00041 const bool use_alpha = tmp1.do_times | tmp2.do_times; 00042 const eT alpha = use_alpha ? (tmp1.val * tmp2.val) : eT(0); 00043 00044 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha); 00045 }
void glue_times_redirect< 3 >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< Glue< T1, T2, glue_times >, T3, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 52 of file glue_times_meat.hpp.
References glue_times_redirect< N >::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
00053 { 00054 arma_extra_debug_sigprint(); 00055 00056 typedef typename T1::elem_type eT; 00057 00058 // there is exactly 3 objects 00059 // hence we can safely expand X as X.A.A, X.A.B and X.B 00060 00061 const partial_unwrap_check<T1> tmp1(X.A.A, out); 00062 const partial_unwrap_check<T2> tmp2(X.A.B, out); 00063 const partial_unwrap_check<T3> tmp3(X.B, out); 00064 00065 const Mat<eT>& A = tmp1.M; 00066 const Mat<eT>& B = tmp2.M; 00067 const Mat<eT>& C = tmp3.M; 00068 00069 const bool do_trans_A = tmp1.do_trans; 00070 const bool do_trans_B = tmp2.do_trans; 00071 const bool do_trans_C = tmp3.do_trans; 00072 00073 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times; 00074 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val) : eT(0); 00075 00076 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); 00077 }
void glue_times_redirect< 4 >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 84 of file glue_times_meat.hpp.
References glue_times_redirect< N >::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
00085 { 00086 arma_extra_debug_sigprint(); 00087 00088 typedef typename T1::elem_type eT; 00089 00090 // there is exactly 4 objects 00091 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B 00092 00093 const partial_unwrap_check<T1> tmp1(X.A.A.A, out); 00094 const partial_unwrap_check<T2> tmp2(X.A.A.B, out); 00095 const partial_unwrap_check<T3> tmp3(X.A.B, out); 00096 const partial_unwrap_check<T4> tmp4(X.B, out); 00097 00098 const Mat<eT>& A = tmp1.M; 00099 const Mat<eT>& B = tmp2.M; 00100 const Mat<eT>& C = tmp3.M; 00101 const Mat<eT>& D = tmp4.M; 00102 00103 const bool do_trans_A = tmp1.do_trans; 00104 const bool do_trans_B = tmp2.do_trans; 00105 const bool do_trans_C = tmp3.do_trans; 00106 const bool do_trans_D = tmp4.do_trans; 00107 00108 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times | tmp4.do_times; 00109 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val * tmp4.val) : eT(0); 00110 00111 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha); 00112 }
void glue_times::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 119 of file glue_times_meat.hpp.
Referenced by apply(), apply_inplace(), apply_inplace_plus(), and apply_mixed().
00120 { 00121 arma_extra_debug_sigprint(); 00122 00123 typedef typename T1::elem_type eT; 00124 00125 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num; 00126 00127 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); 00128 00129 glue_times_redirect<N_mat>::apply(out, X); 00130 }
void glue_times::apply_inplace | ( | Mat< typename T1::elem_type > & | out, | |
const T1 & | X | |||
) | [inline, static, inherited] |
Definition at line 137 of file glue_times_meat.hpp.
References apply(), Mat< eT >::at(), Mat< eT >::colptr(), unwrap_check< T1 >::M, podarray< eT >::memptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by Mat< eT >::operator*=().
00138 { 00139 arma_extra_debug_sigprint(); 00140 00141 typedef typename T1::elem_type eT; 00142 00143 const unwrap_check<T1> tmp(X, out); 00144 const Mat<eT>& B = tmp.M; 00145 00146 arma_debug_assert_mul_size(out, B, "matrix multiply"); 00147 00148 if(out.n_cols == B.n_cols) 00149 { 00150 podarray<eT> tmp(out.n_cols); 00151 eT* tmp_rowdata = tmp.memptr(); 00152 00153 for(u32 out_row=0; out_row < out.n_rows; ++out_row) 00154 { 00155 for(u32 out_col=0; out_col < out.n_cols; ++out_col) 00156 { 00157 tmp_rowdata[out_col] = out.at(out_row,out_col); 00158 } 00159 00160 for(u32 B_col=0; B_col < B.n_cols; ++B_col) 00161 { 00162 const eT* B_coldata = B.colptr(B_col); 00163 00164 eT val = eT(0); 00165 for(u32 i=0; i < B.n_rows; ++i) 00166 { 00167 val += tmp_rowdata[i] * B_coldata[i]; 00168 } 00169 00170 out.at(out_row,B_col) = val; 00171 } 00172 } 00173 00174 } 00175 else 00176 { 00177 const Mat<eT> tmp(out); 00178 glue_times::apply(out, tmp, B, eT(1), false, false, false); 00179 } 00180 00181 }
arma_hot void glue_times::apply_inplace_plus | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X, | |||
const s32 | sign | |||
) | [inline, static, inherited] |
Definition at line 189 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, apply(), arma_assert_same_size(), Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and partial_unwrap_check< T1 >::val.
Referenced by Mat< eT >::operator+=(), and Mat< eT >::operator-=().
00190 { 00191 arma_extra_debug_sigprint(); 00192 00193 typedef typename T1::elem_type eT; 00194 00195 const partial_unwrap_check<T1> tmp1(X.A, out); 00196 const partial_unwrap_check<T2> tmp2(X.B, out); 00197 00198 const Mat<eT>& A = tmp1.M; 00199 const Mat<eT>& B = tmp2.M; 00200 const eT alpha = tmp1.val * tmp2.val * ( (sign > s32(0)) ? eT(1) : eT(-1) ); 00201 00202 const bool do_trans_A = tmp1.do_trans; 00203 const bool do_trans_B = tmp2.do_trans; 00204 const bool use_alpha = tmp1.do_times | tmp2.do_times | (sign < s32(0)); 00205 00206 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply"); 00207 00208 const u32 result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; 00209 const u32 result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; 00210 00211 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "matrix addition"); 00212 00213 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) 00214 { 00215 if(A.n_rows == 1) 00216 { 00217 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00218 } 00219 if(B.n_cols == 1) 00220 { 00221 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00222 } 00223 else 00224 { 00225 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1)); 00226 } 00227 } 00228 else 00229 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) 00230 { 00231 if(A.n_rows == 1) 00232 { 00233 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00234 } 00235 if(B.n_cols == 1) 00236 { 00237 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00238 } 00239 else 00240 { 00241 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1)); 00242 } 00243 } 00244 else 00245 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) 00246 { 00247 if(A.n_cols == 1) 00248 { 00249 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 
00250 } 00251 if(B.n_cols == 1) 00252 { 00253 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00254 } 00255 else 00256 { 00257 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1)); 00258 } 00259 } 00260 else 00261 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) 00262 { 00263 if(A.n_cols == 1) 00264 { 00265 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00266 } 00267 if(B.n_cols == 1) 00268 { 00269 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00270 } 00271 else 00272 { 00273 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1)); 00274 } 00275 } 00276 else 00277 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) 00278 { 00279 if(A.n_rows == 1) 00280 { 00281 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00282 } 00283 if(B.n_rows == 1) 00284 { 00285 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00286 } 00287 else 00288 { 00289 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1)); 00290 } 00291 } 00292 else 00293 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) 00294 { 00295 if(A.n_rows == 1) 00296 { 00297 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00298 } 00299 if(B.n_rows == 1) 00300 { 00301 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00302 } 00303 else 00304 { 00305 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1)); 00306 } 00307 } 00308 else 00309 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) 00310 { 00311 if(A.n_cols == 1) 00312 { 00313 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00314 } 00315 if(B.n_rows == 1) 00316 { 00317 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00318 } 00319 else 00320 { 00321 gemm<true, true, false, true>::apply(out, A, B, alpha, 
eT(1)); 00322 } 00323 } 00324 else 00325 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) 00326 { 00327 if(A.n_cols == 1) 00328 { 00329 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00330 } 00331 if(B.n_rows == 1) 00332 { 00333 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00334 } 00335 else 00336 { 00337 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1)); 00338 } 00339 } 00340 00341 00342 }
void glue_times::apply_mixed | ( | Mat< typename promote_type< eT1, eT2 >::result > & | out, | |
const Mat< eT1 > & | X, | |||
const Mat< eT2 > & | Y | |||
) | [inline, static, inherited] |
Matrix multiplication of two matrices with different element types; the output uses the promoted element type promote_type< eT1, eT2 >::result.
Definition at line 350 of file glue_times_meat.hpp.
References apply(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by operator*().
00351 { 00352 arma_extra_debug_sigprint(); 00353 00354 typedef typename promote_type<eT1,eT2>::result out_eT; 00355 00356 arma_debug_assert_mul_size(X,Y, "matrix multiply"); 00357 00358 out.set_size(X.n_rows,Y.n_cols); 00359 gemm_mixed<>::apply(out, X, Y); 00360 }
arma_inline u32 glue_times::mul_storage_cost | ( | const Mat< eT > & | A, | |
const Mat< eT > & | B, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B | |||
) | [inline, static, inherited] |
Definition at line 367 of file glue_times_meat.hpp.
References Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by apply().
arma_hot void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 382 of file glue_times_meat.hpp.
References gemm< do_trans_A, do_trans_B, use_alpha, use_beta >::apply(), gemv< do_trans_A, use_alpha, use_beta >::apply(), Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and Mat< eT >::set_size().
00391 { 00392 arma_extra_debug_sigprint(); 00393 00394 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply"); 00395 00396 const u32 final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; 00397 const u32 final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; 00398 00399 out.set_size(final_n_rows, final_n_cols); 00400 00401 // TODO: thoroughly test all combinations 00402 00403 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) 00404 { 00405 if(A.n_rows == 1) 00406 { 00407 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); 00408 } 00409 else 00410 if(B.n_cols == 1) 00411 { 00412 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); 00413 } 00414 else 00415 { 00416 gemm<false, false, false, false>::apply(out, A, B); 00417 } 00418 } 00419 else 00420 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) 00421 { 00422 if(A.n_rows == 1) 00423 { 00424 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00425 } 00426 else 00427 if(B.n_cols == 1) 00428 { 00429 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00430 } 00431 else 00432 { 00433 gemm<false, false, true, false>::apply(out, A, B, alpha); 00434 } 00435 } 00436 else 00437 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) 00438 { 00439 if(A.n_cols == 1) 00440 { 00441 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); 00442 } 00443 else 00444 if(B.n_cols == 1) 00445 { 00446 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); 00447 } 00448 else 00449 { 00450 gemm<true, false, false, false>::apply(out, A, B); 00451 } 00452 } 00453 else 00454 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) 00455 { 00456 if(A.n_cols == 1) 00457 { 00458 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00459 } 00460 if(B.n_cols == 1) 00461 { 00462 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 
00463 } 00464 else 00465 { 00466 gemm<true, false, true, false>::apply(out, A, B, alpha); 00467 } 00468 } 00469 else 00470 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) 00471 { 00472 if(A.n_rows == 1) 00473 { 00474 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); 00475 } 00476 if(B.n_rows == 1) 00477 { 00478 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); 00479 } 00480 else 00481 { 00482 gemm<false, true, false, false>::apply(out, A, B); 00483 } 00484 } 00485 else 00486 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) 00487 { 00488 if(A.n_rows == 1) 00489 { 00490 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00491 } 00492 if(B.n_rows == 1) 00493 { 00494 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00495 } 00496 else 00497 { 00498 gemm<false, true, true, false>::apply(out, A, B, alpha); 00499 } 00500 } 00501 else 00502 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) 00503 { 00504 if(A.n_cols == 1) 00505 { 00506 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); 00507 } 00508 if(B.n_rows == 1) 00509 { 00510 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); 00511 } 00512 else 00513 { 00514 gemm<true, true, false, false>::apply(out, A, B); 00515 } 00516 } 00517 else 00518 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) 00519 { 00520 if(A.n_cols == 1) 00521 { 00522 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00523 } 00524 if(B.n_rows == 1) 00525 { 00526 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00527 } 00528 else 00529 { 00530 gemm<true, true, true, false>::apply(out, A, B, alpha); 00531 } 00532 } 00533 }
void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const Mat< eT > & | C, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_trans_C, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 541 of file glue_times_meat.hpp.
References apply(), and mul_storage_cost().
00552 { 00553 arma_extra_debug_sigprint(); 00554 00555 Mat<eT> tmp; 00556 00557 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) ) 00558 { 00559 // out = (A*B)*C 00560 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha); 00561 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false ); 00562 } 00563 else 00564 { 00565 // out = A*(B*C) 00566 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha); 00567 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false ); 00568 } 00569 }
void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const Mat< eT > & | C, | |||
const Mat< eT > & | D, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_trans_C, | |||
const bool | do_trans_D, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 577 of file glue_times_meat.hpp.
References apply(), and mul_storage_cost().
00590 { 00591 arma_extra_debug_sigprint(); 00592 00593 Mat<eT> tmp; 00594 00595 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) ) 00596 { 00597 // out = (A*B*C)*D 00598 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); 00599 00600 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false); 00601 } 00602 else 00603 { 00604 // out = A*(B*C*D) 00605 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha); 00606 00607 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false); 00608 } 00609 }
arma_hot void glue_times_diag::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times_diag > & | X | |||
) | [inline, static, inherited] |
Definition at line 621 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, Mat< eT >::at(), Glue< T1, T2, glue_type >::B, Mat< eT >::colptr(), strip_diagmat< T1 >::do_diagmat, unwrap_check< T1 >::M, strip_diagmat< T1 >::M, Mat< eT >::n_cols, diagmat_proxy_check< T1 >::n_elem, Mat< eT >::n_rows, Mat< eT >::set_size(), and Mat< eT >::zeros().
00622 { 00623 arma_extra_debug_sigprint(); 00624 00625 typedef typename T1::elem_type eT; 00626 00627 const strip_diagmat<T1> S1(X.A); 00628 const strip_diagmat<T2> S2(X.B); 00629 00630 typedef typename strip_diagmat<T1>::stored_type T1_stripped; 00631 typedef typename strip_diagmat<T2>::stored_type T2_stripped; 00632 00633 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) ) 00634 { 00635 const diagmat_proxy_check<T1_stripped> A(S1.M, out); 00636 00637 const unwrap_check<T2> tmp(X.B, out); 00638 const Mat<eT>& B = tmp.M; 00639 00640 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiply"); 00641 00642 out.set_size(A.n_elem, B.n_cols); 00643 00644 for(u32 col=0; col<B.n_cols; ++col) 00645 { 00646 eT* out_coldata = out.colptr(col); 00647 const eT* B_coldata = B.colptr(col); 00648 00649 for(u32 row=0; row<B.n_rows; ++row) 00650 { 00651 out_coldata[row] = A[row] * B_coldata[row]; 00652 } 00653 } 00654 } 00655 else 00656 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) ) 00657 { 00658 const unwrap_check<T1> tmp(X.A, out); 00659 const Mat<eT>& A = tmp.M; 00660 00661 const diagmat_proxy_check<T2_stripped> B(S2.M, out); 00662 00663 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiply"); 00664 00665 out.set_size(A.n_rows, B.n_elem); 00666 00667 for(u32 col=0; col<A.n_cols; ++col) 00668 { 00669 const eT val = B[col]; 00670 00671 eT* out_coldata = out.colptr(col); 00672 const eT* A_coldata = A.colptr(col); 00673 00674 for(u32 row=0; row<A.n_rows; ++row) 00675 { 00676 out_coldata[row] = A_coldata[row] * val; 00677 } 00678 } 00679 } 00680 else 00681 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) ) 00682 { 00683 const diagmat_proxy_check<T1_stripped> A(S1.M, out); 00684 const diagmat_proxy_check<T2_stripped> B(S2.M, out); 00685 00686 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiply"); 00687 00688 out.zeros(A.n_elem, A.n_elem); 00689 00690 for(u32 i=0; i<A.n_elem; 
++i) 00691 { 00692 out.at(i,i) = A[i] * B[i]; 00693 } 00694 } 00695 }