gemm_emul_simple< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm]

Partial emulation of ATLAS/BLAS gemm(), non-cached version. Matrix 'C' is assumed to have already been set to the correct size (i.e. with any transposes taken into account). More...

#include <gemm.hpp>

List of all members.

Static Public Member Functions

template<typename eT >
static arma_hot void apply (Mat< eT > &C, const Mat< eT > &A, const Mat< eT > &B, const eT alpha=eT(1), const eT beta=eT(0))

Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_emul_simple< do_trans_A, do_trans_B, use_alpha, use_beta >

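Partial emulation of ATLAS/BLAS gemm(), non-cached version. It computes C = alpha*op(A)*op(B) + beta*C, where op(X) is either X or trans(X) as selected by do_trans_A / do_trans_B, and the alpha and beta terms are applied only when use_alpha / use_beta are true. Matrix 'C' is assumed to have already been set to the correct size (i.e. with any transposes taken into account).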

Definition at line 211 of file gemm.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename eT >
static arma_hot void gemm_emul_simple< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< eT > &  C,
const Mat< eT > &  A,
const Mat< eT > &  B,
const eT  alpha = eT(1),
const eT  beta = eT(0) 
) [inline, static]

Definition at line 221 of file gemm.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.

00228     {
00229     arma_extra_debug_sigprint();
00230     
00231     const u32 A_n_rows = A.n_rows;
00232     const u32 A_n_cols = A.n_cols;
00233     
00234     const u32 B_n_rows = B.n_rows;
00235     const u32 B_n_cols = B.n_cols;
00236     
00237     if( (do_trans_A == false) && (do_trans_B == false) )
00238       {
00239       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00240         {
00241         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00242           {
00243           const eT* B_coldata = B.colptr(col_B);
00244           
00245           eT acc = eT(0);
00246           for(u32 i = 0; i < B_n_rows; ++i)
00247             {
00248             acc += A.at(row_A,i) * B_coldata[i];
00249             }
00250           
00251           if( (use_alpha == false) && (use_beta == false) )
00252             {
00253             C.at(row_A,col_B) = acc;
00254             }
00255           else
00256           if( (use_alpha == true) && (use_beta == false) )
00257             {
00258             C.at(row_A,col_B) = alpha * acc;
00259             }
00260           else
00261           if( (use_alpha == false) && (use_beta == true) )
00262             {
00263             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
00264             }
00265           else
00266           if( (use_alpha == true) && (use_beta == true) )
00267             {
00268             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00269             }
00270           }
00271         }
00272       }
00273     else
00274     if( (do_trans_A == true) && (do_trans_B == false) )
00275       {
00276       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00277         {
00278         // col_A is interpreted as row_A when storing the results in matrix C
00279         
00280         const eT* A_coldata = A.colptr(col_A);
00281         
00282         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00283           {
00284           const eT* B_coldata = B.colptr(col_B);
00285           
00286           eT acc = eT(0);
00287           for(u32 i=0; i < B_n_rows; ++i)
00288             {
00289             acc += A_coldata[i] * B_coldata[i];
00290             }
00291         
00292           if( (use_alpha == false) && (use_beta == false) )
00293             {
00294             C.at(col_A,col_B) = acc;
00295             }
00296           else
00297           if( (use_alpha == true) && (use_beta == false) )
00298             {
00299             C.at(col_A,col_B) = alpha * acc;
00300             }
00301           else
00302           if( (use_alpha == false) && (use_beta == true) )
00303             {
00304             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00305             }
00306           else
00307           if( (use_alpha == true) && (use_beta == true) )
00308             {
00309             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00310             }
00311           
00312           }
00313         }
00314       }
00315     else
00316     if( (do_trans_A == false) && (do_trans_B == true) )
00317       {
00318       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00319         {
00320         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00321           {
00322           eT acc = eT(0);
00323           for(u32 i = 0; i < B_n_cols; ++i)
00324             {
00325             acc += A.at(row_A,i) * B.at(row_B,i);
00326             }
00327           
00328           if( (use_alpha == false) && (use_beta == false) )
00329             {
00330             C.at(row_A,row_B) = acc;
00331             }
00332           else
00333           if( (use_alpha == true) && (use_beta == false) )
00334             {
00335             C.at(row_A,row_B) = alpha * acc;
00336             }
00337           else
00338           if( (use_alpha == false) && (use_beta == true) )
00339             {
00340             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00341             }
00342           else
00343           if( (use_alpha == true) && (use_beta == true) )
00344             {
00345             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00346             }
00347           }
00348         }
00349       }
00350     else
00351     if( (do_trans_A == true) && (do_trans_B == true) )
00352       {
00353       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00354         {
00355         
00356         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00357           {
00358           const eT* A_coldata = A.colptr(col_A);
00359           
00360           eT acc = eT(0);
00361           for(u32 i=0; i < A_n_rows; ++i)
00362             {
00363             acc += B.at(row_B,i) * A_coldata[i];
00364             }
00365         
00366           if( (use_alpha == false) && (use_beta == false) )
00367             {
00368             C.at(col_A,row_B) = acc;
00369             }
00370           else
00371           if( (use_alpha == true) && (use_beta == false) )
00372             {
00373             C.at(col_A,row_B) = alpha * acc;
00374             }
00375           else
00376           if( (use_alpha == false) && (use_beta == true) )
00377             {
00378             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00379             }
00380           else
00381           if( (use_alpha == true) && (use_beta == true) )
00382             {
00383             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00384             }
00385           
00386           }
00387         }
00388       
00389       }
00390     }