gmm.cpp

#include <itpp/srccode/gmm.h>
#include <itpp/srccode/vqtrain.h>
#include <itpp/base/math/elem_math.h>
#include <itpp/base/matfunc.h>
#include <itpp/base/specmat.h>
#include <itpp/base/random.h>
#include <itpp/base/timing.h>
#include <iostream>
#include <fstream>

namespace itpp
{

GMM::GMM()
{
  d = 0;
  M = 0;
}

GMM::GMM(std::string filename)
{
  load(filename);
}

GMM::GMM(int M_in, int d_in)
{
  M = M_in;
  d = d_in;
  m = zeros(M * d);
  sigma = zeros(M * d);
  w = 1. / M * ones(M);

  for (int i = 0;i < M;i++) {
    w(i) = 1.0 / M;
  }
  compute_internals();
}

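// Initialize the mixture from a VQ codebook: the codebook vectors become the
// mixture means, the weights are set uniform, and every mixture gets the same
// diagonal covariance, taken from the average outer product of the codebook vectors.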
void GMM::init_from_vq(const vec &codebook, int dim)
{

  mat  C(dim, dim);
  int  i;
  vec  v;

  d = dim;
  M = codebook.length() / dim;

  m = codebook;
  w = ones(M) / double(M);

  C.clear();
  for (i = 0;i < M;i++) {
    v = codebook.mid(i * d, d);
    C = C + outer_product(v, v);
  }
  C = 1. / M * C;
  sigma.set_length(M*d);
  for (i = 0;i < M;i++) {
    sigma.replace_mid(i*d, diag(C));
  }

  compute_internals();
}

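// Set the model directly: w_in holds the mixture weights, and each column of
// m_in / sigma_in holds the mean / diagonal covariance of one mixture component.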
void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
{
  int  i, j;
  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M*d);
  sigma.set_length(M*d);
  for (i = 0;i < M;i++) {
    for (j = 0;j < d;j++) {
      m(i*d + j) = m_in(j, i);
      sigma(i*d + j) = sigma_in(j, i);
    }
  }
  w = w_in;

  compute_internals();
}

void GMM::set_mean(const mat &m_in)
{
  int  i, j;

  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M*d);
  for (i = 0;i < M;i++) {
    for (j = 0;j < d;j++) {
      m(i*d + j) = m_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_mean(int i, const vec &means, bool compflag)
{
  m.replace_mid(i*length(means), means);
  if (compflag) compute_internals();
}

void GMM::set_covariance(const mat &sigma_in)
{
  int  i, j;

  d = sigma_in.rows();
  M = sigma_in.cols();

  sigma.set_length(M*d);
  for (i = 0;i < M;i++) {
    for (j = 0;j < d;j++) {
      sigma(i*d + j) = sigma_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_covariance(int i, const vec &covariances, bool compflag)
{
  sigma.replace_mid(i*length(covariances), covariances);
  if (compflag) compute_internals();
}

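// Keep only the first d_new dimensions of every mean and covariance vector.
// Because the covariances are diagonal, marginalizing reduces to truncation.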
void GMM::marginalize(int d_new)
{
  it_error_if(d_new > d, "GMM.marginalize: cannot change to a larger dimension");

  vec  mnew(d_new*M), sigmanew(d_new*M);
  int  i, j;

  for (i = 0;i < M;i++) {
    for (j = 0;j < d_new;j++) {
      mnew(i*d_new + j) = m(i * d + j);
      sigmanew(i*d_new + j) = sigma(i * d + j);
    }
  }
  m = mnew;
  sigma = sigmanew;
  d = d_new;

  compute_internals();
}

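// Merge another GMM into this one. The weights of the two models are scaled in
// proportion to their respective numbers of mixtures and then renormalized.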
void GMM::join(const GMM &newgmm)
{
  if (d == 0) {
    w = newgmm.w;
    m = newgmm.m;
    sigma = newgmm.sigma;
    d = newgmm.d;
    M = newgmm.M;
  }
  else {
    it_error_if(d != newgmm.d, "GMM.join: cannot join GMMs of different dimension");

    w = concat(double(M) / (M + newgmm.M) * w, double(newgmm.M) / (M + newgmm.M) * newgmm.w);
    w = w / sum(w);
    m = concat(m, newgmm.m);
    sigma = concat(sigma, newgmm.sigma);

    M = M + newgmm.M;
  }
  compute_internals();
}

void GMM::clear()
{
  w.set_length(0);
  m.set_length(0);
  sigma.set_length(0);
  d = 0;
  M = 0;
}

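// Save the model as plain text: a header line "M d", followed by the M weights
// (one per line), the M mean vectors (one row each), and the M diagonal
// covariance vectors (one row each). load() below reads the same format.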
void GMM::save(std::string filename)
{
  std::ofstream f(filename.c_str());
  int   i, j;

  f << M << " " << d << std::endl ;
  for (i = 0;i < w.length();i++) {
    f << w(i) << std::endl ;
  }
  for (i = 0;i < M;i++) {
    f << m(i*d) ;
    for (j = 1;j < d;j++) {
      f << " " << m(i*d + j) ;
    }
    f << std::endl ;
  }
  for (i = 0;i < M;i++) {
    f << sigma(i*d) ;
    for (j = 1;j < d;j++) {
      f << " " << sigma(i*d + j) ;
    }
    f << std::endl ;
  }
}

void GMM::load(std::string filename)
{
  std::ifstream GMMFile(filename.c_str());
  int   i, j;

  it_error_if(!GMMFile, std::string("GMM::load : cannot open file ") + filename);

  GMMFile >> M >> d ;

  w.set_length(M);
  for (i = 0;i < M;i++) {
    GMMFile >> w(i) ;
  }
  m.set_length(M*d);
  for (i = 0;i < M;i++) {
    for (j = 0;j < d;j++) {
      GMMFile >> m(i*d + j) ;
    }
  }
  sigma.set_length(M*d);
  for (i = 0;i < M;i++) {
    for (j = 0;j < d;j++) {
      GMMFile >> sigma(i*d + j) ;
    }
  }
  compute_internals();
  std::cout << "  mixtures:" << M << "  dim:" << d << std::endl ;
}

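// Mixture density f(x) = sum_i w(i) * N(x; m_i, diag(sigma_i)).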
double GMM::likelihood(const vec &x)
{
  double fx = 0;
  int  i;

  for (i = 0;i < M;i++) {
    fx += w(i) * likelihood_aposteriori(x, i);
  }
  return fx;
}

vec GMM::likelihood_aposteriori(const vec &x)
{
  vec  v(M);
  int  i;

  for (i = 0;i < M;i++) {
    v(i) = w(i) * likelihood_aposteriori(x, i);
  }
  return v;
}

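// Density of a single diagonal-covariance Gaussian component (without the weight):
//   N(x; m_i, diag(sigma_i))
//     = (2*pi)^(-d/2) * prod_j sigma(i*d+j)^(-1/2)
//       * exp(-sum_j (x(j) - m(i*d+j))^2 / (2*sigma(i*d+j)))
// normweight and normexp hold the precomputed factors (see compute_internals()).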
double GMM::likelihood_aposteriori(const vec &x, int mixture)
{
  int  j;
  double s;

  it_error_if(d != x.length(), "GMM::likelihood_aposteriori : dimensions do not match");
  s = 0;
  for (j = 0;j < d;j++) {
    s += normexp(mixture * d + j) * sqr(x(j) - m(mixture * d + j));
  }
  return normweight(mixture)*std::exp(s);
}

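// Precompute the per-component normalization factors used when evaluating the
// Gaussian densities: normweight(i) = (2*pi)^(-d/2) / sqrt(prod_j sigma(i*d+j))
// and normexp(i*d+j) = -1 / (2*sigma(i*d+j)).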
void GMM::compute_internals()
{
  int  i, j;
  double s;
  double constant = 1.0 / std::pow(2 * pi, d / 2.0);

  normweight.set_length(M);
  normexp.set_length(M*d);

  for (i = 0;i < M;i++) {
    s = 1;
    for (j = 0;j < d;j++) {
      normexp(i*d + j) = -0.5 / sigma(i * d + j);  // check time
      s *= sigma(i * d + j);
    }
    normweight(i) = constant / std::sqrt(s);
  }

}

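// Draw one sample from the mixture: pick a component by inverting the cumulative
// weight distribution at a uniform random number, then draw from that component's
// diagonal Gaussian. Note that the cumulative weights are cached in a static
// variable on the first call, so later changes to the weights are not picked up.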
vec GMM::draw_sample()
{
  static bool first = true;
  static vec cumweight;
  double u = randu();
  int  k;

  if (first) {
    first = false;
    cumweight = cumsum(w);
    it_error_if(std::abs(cumweight(length(cumweight) - 1) - 1) > 1e-6, "weights do not sum to 1");
    cumweight(length(cumweight) - 1) = 1;
  }
  k = 0;
  while (u > cumweight(k)) k++;

  return elem_mult(sqrt(sigma.mid(k*d, d)), randn(d)) + m.mid(k*d, d);
}

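// Train a GMM with M diagonal-covariance mixtures on TrainingData using the EM
// algorithm. The means are initialized with vqtrain(), all covariances with half
// the average power of the training vectors, and the weights uniformly. Each EM
// iteration accumulates the posterior probabilities (E-step) and re-estimates the
// weights, means and variances (M-step); iteration stops after NOITER rounds or
// when the relative change of the per-sample log-likelihood falls below 1e-6.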
GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
{
  mat   mean;
  int   i, j, d = TrainingData(0).length();
  vec   sig;
  GMM   gmm(M, d);
  vec   m(d*M);
  vec   sigma(d*M);
  vec   w(M);
  vec   normweight(M);
  vec   normexp(d*M);
  double  LL = 0, LLold, fx;
  double  constant = 1.0 / std::pow(2 * pi, d / 2.0);
  int   T = TrainingData.length();
  vec   x1;
  int   t, n;
  vec   msum(d*M);
  vec   sigmasum(d*M);
  vec   wsum(M);
  vec   p_aposteriori(M);
  vec   x2;
  double  s;
  vec   temp1, temp2;
  //double  MINIMUM_VARIANCE=0.03;

  //-----------initialization-----------------------------------

  mean = vqtrain(TrainingData, M, 200000, 0.5, VERBOSE);
  for (i = 0;i < M;i++) gmm.set_mean(i, mean.get_col(i), false);
  // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
  sig = zeros(d);
  for (i = 0;i < TrainingData.length();i++) sig += sqr(TrainingData(i));
  sig /= TrainingData.length();
  for (i = 0;i < M;i++) gmm.set_covariance(i, 0.5*sig, false);

  gmm.set_weight(1.0 / M*ones(M));

  //-----------optimization-----------------------------------

  tic();
  for (i = 0;i < M;i++) {
    temp1 = gmm.get_mean(i);
    temp2 = gmm.get_covariance(i);
    for (j = 0;j < d;j++) {
      m(i*d + j) = temp1(j);
      sigma(i*d + j) = temp2(j);
    }
    w(i) = gmm.get_weight(i);
  }
  for (n = 0;n < NOITER;n++) {
    for (i = 0;i < M;i++) {
      s = 1;
      for (j = 0;j < d;j++) {
        normexp(i*d + j) = -0.5 / sigma(i * d + j);  // check time
        s *= sigma(i * d + j);
      }
      normweight(i) = constant * w(i) / std::sqrt(s);
    }
    LLold = LL;
    wsum.clear();
    msum.clear();
    sigmasum.clear();
    LL = 0;
    for (t = 0;t < T;t++) {
      x1 = TrainingData(t);
      x2 = sqr(x1);
      fx = 0;
      for (i = 0;i < M;i++) {
        s = 0;
        for (j = 0;j < d;j++) {
          s += normexp(i * d + j) * sqr(x1(j) - m(i * d + j));
        }
        p_aposteriori(i) = normweight(i) * std::exp(s);
        fx += p_aposteriori(i);
      }
      p_aposteriori /= fx;
      LL = LL + std::log(fx);

      for (i = 0;i < M;i++) {
        wsum(i) += p_aposteriori(i);
        for (j = 0;j < d;j++) {
          msum(i*d + j) += p_aposteriori(i) * x1(j);
          sigmasum(i*d + j) += p_aposteriori(i) * x2(j);
        }
      }
    }
    for (i = 0;i < M;i++) {
      for (j = 0;j < d;j++) {
        m(i*d + j) = msum(i * d + j) / wsum(i);
        sigma(i*d + j) = sigmasum(i * d + j) / wsum(i) - sqr(m(i * d + j));
      }
      w(i) = wsum(i) / T;
    }
    LL = LL / T;

    if (std::abs((LL - LLold) / LL) < 1e-6) break;
    if (VERBOSE) {
      std::cout << n << ":   " << LL << "   " << std::abs((LL - LLold) / LL) << "   " << toc() <<  std::endl ;
      std::cout << "---------------------------------------" << std::endl ;
      tic();
    }
    else {
      std::cout << n << ": LL =  " << LL << "   " << std::abs((LL - LLold) / LL) << "\r" ;
      std::cout.flush();
    }
  }
  for (i = 0;i < M;i++) {
    gmm.set_mean(i, m.mid(i*d, d), false);
    gmm.set_covariance(i, sigma.mid(i*d, d), false);
  }
  gmm.set_weight(w);
  return gmm;
}

} // namespace itpp

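A minimal usage sketch (not part of gmm.cpp): it assumes GMM and gmmtrain are declared in <itpp/srccode/gmm.h> as used above, and it generates synthetic training data purely for illustration.

#include <itpp/srccode/gmm.h>
#include <itpp/base/array.h>
#include <itpp/base/random.h>
#include <iostream>

using namespace itpp;

int main()
{
  // Synthetic training set: 1000 two-dimensional Gaussian vectors (illustration only).
  Array<vec> data(1000);
  for (int t = 0; t < data.length(); t++)
    data(t) = randn(2);

  // Fit a 4-component mixture with at most 100 EM iterations, non-verbose.
  GMM gmm = gmmtrain(data, 4, 100, false);

  // Evaluate the mixture density at one training vector and draw a new sample.
  std::cout << "f(x0) = " << gmm.likelihood(data(0)) << std::endl;
  vec sample = gmm.draw_sample();

  return 0;
}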