// Copyright (C) 2015 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_ #define DLIB_BOTTOM_uP_CLUSTER_Hh_ #include <queue> #include <map> #include "bottom_up_cluster_abstract.h" #include "../algs.h" #include "../matrix.h" #include "../disjoint_subsets.h" #include "../graph_utils.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace buc_impl { inline void merge_sets ( matrix<double>& dists, unsigned long dest, unsigned long src ) { for (long r = 0; r < dists.nr(); ++r) dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src)); } struct compare_dist { bool operator() ( const sample_pair& a, const sample_pair& b ) const { return a.distance() > b.distance(); } }; } // ---------------------------------------------------------------------------------------- template < typename EXP > unsigned long bottom_up_cluster ( const matrix_exp<EXP>& dists_, std::vector<unsigned long>& labels, unsigned long min_num_clusters, double max_dist = std::numeric_limits<double>::infinity() ) { matrix<double> dists = matrix_cast<double>(dists_); // make sure requires clause is not broken DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0, "\t unsigned long bottom_up_cluster()" << "\n\t Invalid inputs were given to this function." << "\n\t dists.nr(): " << dists.nr() << "\n\t dists.nc(): " << dists.nc() << "\n\t min_num_clusters: " << min_num_clusters ); using namespace buc_impl; labels.resize(dists.nr()); disjoint_subsets sets; sets.set_size(dists.nr()); if (labels.size() == 0) return 0; // push all the edges in the graph into a priority queue so the best edges to merge // come first. std::priority_queue<sample_pair, std::vector<sample_pair>, compare_dist> que; for (long r = 0; r < dists.nr(); ++r) for (long c = r+1; c < dists.nc(); ++c) que.push(sample_pair(r,c,dists(r,c))); // Now start merging nodes. for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter) { // find the next best thing to merge. double best_dist = que.top().distance(); unsigned long a = sets.find_set(que.top().index1()); unsigned long b = sets.find_set(que.top().index2()); que.pop(); // we have been merging and modifying the distances, so make sure this distance // is still valid and these guys haven't been merged already. while(a == b || best_dist < dists(a,b)) { // Haven't merged it yet, so put it back in with updated distance for // reconsideration later. if (a != b) que.push(sample_pair(a, b, dists(a, b))); best_dist = que.top().distance(); a = sets.find_set(que.top().index1()); b = sets.find_set(que.top().index2()); que.pop(); } // now merge these sets if the best distance is small enough if (best_dist > max_dist) break; unsigned long news = sets.merge_sets(a,b); unsigned long olds = (news==a)?b:a; merge_sets(dists, news, olds); } // figure out which cluster each element is in. Also make sure the labels are // contiguous. std::map<unsigned long, unsigned long> relabel; for (unsigned long r = 0; r < labels.size(); ++r) { unsigned long l = sets.find_set(r); // relabel to make contiguous if (relabel.count(l) == 0) { unsigned long next = relabel.size(); relabel[l] = next; } labels[r] = relabel[l]; } return relabel.size(); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BOTTOM_uP_CLUSTER_Hh_