op_dot_meat.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 //! \addtogroup op_dot
00018 //! @{
00019 
00020 
00021 
00022 //! for two arrays
00023 template<typename eT>
00024 inline
00025 arma_hot
00026 arma_pure
00027 eT
00028 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B)
00029   {
00030   arma_extra_debug_sigprint();
00031   
00032   eT val1 = eT(0);
00033   eT val2 = eT(0);
00034   
00035   u32 i,j;
00036   for(i=0, j=1; j<n_elem; i+=2, j+=2)
00037     {
00038     val1 += A[i] * B[i];
00039     val2 += A[j] * B[j];
00040     }
00041   
00042   if(i < n_elem)
00043     {
00044     val1 += A[i] * B[i];
00045     }
00046   
00047   return val1+val2;
00048   }
00049 
00050 
00051 
00052 //! for three arrays
00053 template<typename eT>
00054 inline
00055 arma_hot
00056 arma_pure
00057 eT
00058 op_dot::direct_dot(const u32 n_elem, const eT* const A, const eT* const B, const eT* C)
00059   {
00060   arma_extra_debug_sigprint();
00061   
00062   eT val = eT(0);
00063   
00064   for(u32 i=0; i<n_elem; ++i)
00065     {
00066     val += A[i] * B[i] * C[i];
00067     }
00068 
00069   return val;
00070   }
00071 
00072 
00073 
00074 template<typename T1, typename T2>
00075 arma_inline
00076 arma_hot
00077 typename T1::elem_type
00078 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00079   {
00080   arma_extra_debug_sigprint();
00081   
00082   if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00083     {
00084     return op_dot::apply_unwrap(X,Y);
00085     }
00086   else
00087     {
00088     return op_dot::apply_proxy(X,Y);
00089     }
00090   }
00091 
00092 
00093 
00094 template<typename T1, typename T2>
00095 arma_inline
00096 arma_hot
00097 typename T1::elem_type
00098 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00099   {
00100   arma_extra_debug_sigprint();
00101   
00102   typedef typename T1::elem_type eT;
00103   
00104   const unwrap<T1> tmp1(X.get_ref());
00105   const unwrap<T2> tmp2(Y.get_ref());
00106   
00107   const Mat<eT>& A = tmp1.M;
00108   const Mat<eT>& B = tmp2.M;
00109   
00110   arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00111   
00112   return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
00113   }
00114 
00115 
00116 
00117 template<typename T1, typename T2>
00118 inline
00119 arma_hot
00120 typename T1::elem_type
00121 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00122   {
00123   arma_extra_debug_sigprint();
00124   
00125   typedef typename T1::elem_type eT;
00126   
00127   const Proxy<T1> A(X.get_ref());
00128   const Proxy<T2> B(Y.get_ref());
00129   
00130   arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
00131   
00132   const u32 n_elem = A.n_elem;
00133   eT val = eT(0);
00134   
00135   for(u32 i=0; i<n_elem; ++i)
00136     {
00137     val += A[i] * B[i];
00138     }
00139   
00140   return val;
00141   }
00142 
00143 
00144 
00145 //
00146 
00147 
00148 
00149 template<typename T1, typename T2>
00150 arma_inline
00151 arma_hot
00152 typename T1::elem_type
00153 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00154   {
00155   arma_extra_debug_sigprint();
00156   
00157   if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
00158     {
00159     return op_norm_dot::apply_unwrap(X,Y);
00160     }
00161   else
00162     {
00163     return op_norm_dot::apply_proxy(X,Y);
00164     }
00165   }
00166 
00167 
00168 
00169 template<typename T1, typename T2>
00170 inline
00171 arma_hot
00172 typename T1::elem_type
00173 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00174   {
00175   arma_extra_debug_sigprint();
00176   
00177   typedef typename T1::elem_type eT;
00178   
00179   const unwrap<T1> tmp1(X.get_ref());
00180   const unwrap<T2> tmp2(Y.get_ref());
00181   
00182   const Mat<eT>& A = tmp1.M;
00183   const Mat<eT>& B = tmp2.M;
00184 
00185   arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00186   
00187   const eT* A_mem = A.memptr();
00188   const eT* B_mem = B.memptr();
00189   
00190   const u32 N = A.n_elem;
00191   
00192   eT acc1 = eT(0);
00193   eT acc2 = eT(0);
00194   eT acc3 = eT(0);
00195   
00196   for(u32 i=0; i<N; ++i)
00197     {
00198     const eT tmpA = A_mem[i];
00199     const eT tmpB = B_mem[i];
00200     
00201     acc1 += tmpA * tmpA;
00202     acc2 += tmpB * tmpB;
00203     acc3 += tmpA * tmpB;
00204     }
00205     
00206   return acc3 / ( std::sqrt(acc1 * acc2) );   // TODO: this only makes sense for eT = float, double or complex
00207   }
00208 
00209 
00210 
00211 template<typename T1, typename T2>
00212 inline
00213 arma_hot
00214 typename T1::elem_type
00215 op_norm_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
00216   {
00217   arma_extra_debug_sigprint();
00218   
00219   typedef typename T1::elem_type eT;
00220   
00221   const Proxy<T1> A(X.get_ref());
00222   const Proxy<T2> B(Y.get_ref());
00223 
00224   arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
00225   
00226   const u32 N = A.n_elem;
00227   
00228   eT acc1 = eT(0);
00229   eT acc2 = eT(0);
00230   eT acc3 = eT(0);
00231   
00232   for(u32 i=0; i<N; ++i)
00233     {
00234     const eT tmpA = A[i];
00235     const eT tmpB = B[i];
00236     
00237     acc1 += tmpA * tmpA;
00238     acc2 += tmpB * tmpB;
00239     acc3 += tmpA * tmpB;
00240     }
00241     
00242   return acc3 / ( std::sqrt(acc1 * acc2) );   // TODO: this only makes sense for eT = float, double or complex
00243   }
00244 
00245 
00246 
00247 //! @}