37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
84 SamplerOptions return_opt;
86 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
143 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
150 typedef detail::DecisionTree DecisionTree_t;
157 typedef LabelType LabelT;
225 template<
class TopologyIterator,
class ParameterIterator>
227 TopologyIterator topology_begin,
228 ParameterIterator parameter_begin,
232 trees_(treeCount, DecisionTree_t(problem_spec)),
233 ext_param_(problem_spec),
236 for(
unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
238 trees_[k].topology_ = *topology_begin;
239 trees_[k].parameters_ = *parameter_begin;
258 vigra_precondition(ext_param_.used() ==
true,
259 "RandomForest::ext_param(): "
260 "Random forest has not been trained yet.");
276 vigra_precondition(ext_param_.used() ==
false,
277 "RandomForest::set_ext_param():"
278 "Random forest has been trained! Call reset()"
279 "before specifying new extrinsic parameters.");
303 DecisionTree_t
const &
tree(
int index)
const
305 return trees_[index];
310 DecisionTree_t &
tree(
int index)
312 return trees_[index];
322 return ext_param_.column_count_;
333 return ext_param_.column_count_;
341 return ext_param_.class_count_;
348 return options_.tree_count_;
353 template<
class U,
class C1,
366 bool adjust_thresholds=
false);
368 template <
class U,
class C1,
class U2,
class C2>
373 onlineLearn(features,
383 template<
class U,
class C1,
389 void reLearnTree(MultiArrayView<2,U,C1>
const & features,
390 MultiArrayView<2,U2,C2>
const & response,
397 template<
class U,
class C1,
class U2,
class C2>
398 void reLearnTree(MultiArrayView<2, U, C1>
const & features,
399 MultiArrayView<2, U2, C2>
const & labels,
402 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
447 template <
class U,
class C1,
453 void learn( MultiArrayView<2, U, C1>
const & features,
454 MultiArrayView<2, U2,C2>
const & response,
458 Random_t
const & random);
460 template <
class U,
class C1,
465 void learn( MultiArrayView<2, U, C1>
const & features,
466 MultiArrayView<2, U2,C2>
const & response,
472 RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
481 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
482 void learn( MultiArrayView<2, U, C1>
const & features,
483 MultiArrayView<2, U2,C2>
const & labels,
493 template <
class U,
class C1,
class U2,
class C2,
494 class Visitor_t,
class Split_t>
495 void learn( MultiArrayView<2, U, C1>
const & features,
496 MultiArrayView<2, U2,C2>
const & labels,
525 template <
class U,
class C1,
class U2,
class C2>
553 template <
class U,
class C,
class Stop>
556 template <
class U,
class C>
567 template <
class U,
class C>
568 LabelType
predictLabel(MultiArrayView<2, U, C>
const & features,
569 ArrayVectorView<double> prior)
const;
581 template <
class U,
class C1,
class T,
class C2>
585 vigra_precondition(features.
shape(0) == labels.
shape(0),
586 "RandomForest::predictLabels(): Label array has wrong size.");
587 for(
int k=0; k<features.
shape(0); ++k)
589 vigra_precondition(!detail::contains_nan(
rowVector(features, k)),
590 "RandomForest::predictLabels(): NaN in feature matrix.");
605 template <
class U,
class C1,
class T,
class C2>
608 LabelType nanLabel)
const
610 vigra_precondition(features.
shape(0) == labels.
shape(0),
611 "RandomForest::predictLabels(): Label array has wrong size.");
612 for(
int k=0; k<features.
shape(0); ++k)
614 if(detail::contains_nan(
rowVector(features, k)))
615 labels(k,0) = nanLabel;
630 template <
class U,
class C1,
class T,
class C2,
class Stop>
635 vigra_precondition(features.
shape(0) == labels.
shape(0),
636 "RandomForest::predictLabels(): Label array has wrong size.");
637 for(
int k=0; k<features.
shape(0); ++k)
652 template <
class U,
class C1,
class T,
class C2,
class Stop>
656 template <
class T1,
class T2,
class C>
666 template <
class U,
class C1,
class T,
class C2>
673 template <
class U,
class C1,
class T,
class C2>
683 template <
class LabelType,
class PreprocessorTag>
684 template<
class U,
class C1,
690 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1>
const & features,
691 MultiArrayView<2,U2,C2>
const & response,
697 bool adjust_thresholds)
699 online_visitor_.activate();
700 online_visitor_.adjust_thresholds=adjust_thresholds;
704 typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
705 typedef UniformIntRandomFunctor<Random_t>
712 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
713 Default_Stop_t default_stop(options_);
714 typename RF_CHOOSER(Stop_t)::type stop
715 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
716 Default_Split_t default_split;
717 typename RF_CHOOSER(Split_t)::type split
718 = RF_CHOOSER(Split_t)::choose(split_, default_split);
719 rf::visitors::StopVisiting stopvisiting;
720 typedef rf::visitors::detail::VisitorNode
721 <rf::visitors::OnlineLearnVisitor,
722 typename RF_CHOOSER(Visitor_t)::type>
725 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
727 vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
733 ext_param_.class_count_=0;
734 Preprocessor_t preprocessor( features, response,
735 options_, ext_param_);
738 RandFunctor_t randint ( random);
741 split.set_external_parameters(ext_param_);
742 stop.set_external_parameters(ext_param_);
746 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
752 for(
int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
754 online_visitor_.tree_id=ii;
755 poisson_sampler.sample();
756 std::map<int,int> leaf_parents;
757 leaf_parents.clear();
759 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
761 int sample=poisson_sampler[s];
762 online_visitor_.current_label=preprocessor.response()(sample,0);
763 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
764 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
768 online_visitor_.add_to_index_list(ii,leaf,sample);
771 if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
773 leaf_parents[leaf]=online_visitor_.last_node_id;
778 std::map<int,int>::iterator leaf_iterator;
779 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
781 int leaf=leaf_iterator->first;
782 int parent=leaf_iterator->second;
783 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
784 ArrayVector<Int32> indeces;
786 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
787 StackEntry_t stack_entry(indeces.begin(),
789 ext_param_.class_count_);
794 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
796 stack_entry.leftParent=parent;
800 vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
801 stack_entry.rightParent=parent;
805 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
807 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
820 online_visitor_.deactivate();
823 template<
class LabelType,
class PreprocessorTag>
824 template<
class U,
class C1,
845 ext_param_.class_count_=0;
853 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
855 typename RF_CHOOSER(Stop_t)::type stop
856 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
858 typename RF_CHOOSER(Split_t)::type split
859 = RF_CHOOSER(Split_t)::choose(split_, default_split);
863 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
865 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
867 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
868 online_visitor_.activate();
871 RandFunctor_t randint ( random);
877 Preprocessor_t preprocessor( features, response,
878 options_, ext_param_);
881 split.set_external_parameters(ext_param_);
882 stop.set_external_parameters(ext_param_);
889 preprocessor.strata().end(),
890 detail::make_sampler_opt(options_)
891 .sampleSize(ext_param().actual_msample_),
898 first_stack_entry( sampler.sampledIndices().begin(),
899 sampler.sampledIndices().end(),
900 ext_param_.class_count_);
902 .set_oob_range( sampler.oobIndices().begin(),
903 sampler.oobIndices().end());
904 online_visitor_.reset_tree(treeId);
905 online_visitor_.tree_id=treeId;
906 trees_[treeId].reset();
908 .learn( preprocessor.features(),
909 preprocessor.response(),
916 .visit_after_tree( *
this,
922 online_visitor_.deactivate();
925 template <
class LabelType,
class PreprocessorTag>
926 template <
class U,
class C1,
938 Random_t
const & random)
949 vigra_precondition(features.
shape(0) == response.
shape(0),
950 "RandomForest::learn(): shape mismatch between features and response.");
957 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
959 typename RF_CHOOSER(Stop_t)::type stop
960 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
962 typename RF_CHOOSER(Split_t)::type split
963 = RF_CHOOSER(Split_t)::choose(split_, default_split);
967 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
969 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
971 if(options_.prepare_online_learning_)
972 online_visitor_.activate();
974 online_visitor_.deactivate();
978 RandFunctor_t randint ( random);
985 Preprocessor_t preprocessor( features, response,
986 options_, ext_param_);
989 split.set_external_parameters(ext_param_);
990 stop.set_external_parameters(ext_param_);
994 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
997 preprocessor.strata().end(),
998 detail::make_sampler_opt(options_)
999 .sampleSize(ext_param().actual_msample_),
1002 visitor.visit_at_beginning(*
this, preprocessor);
1005 for(
int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
1011 first_stack_entry( sampler.sampledIndices().begin(),
1012 sampler.sampledIndices().end(),
1013 ext_param_.class_count_);
1015 .set_oob_range( sampler.oobIndices().begin(),
1016 sampler.oobIndices().end());
1018 .learn( preprocessor.features(),
1019 preprocessor.response(),
1026 .visit_after_tree( *
this,
1033 visitor.visit_at_end(*
this, preprocessor);
1035 online_visitor_.deactivate();
1041 template <
class LabelType,
class Tag>
1042 template <
class U,
class C,
class Stop>
1046 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1047 "RandomForestn::predictLabel():"
1048 " Too few columns in feature matrix.");
1049 vigra_precondition(
rowCount(features) == 1,
1050 "RandomForestn::predictLabel():"
1051 " Feature matrix must have a singlerow.");
1054 predictProbabilities(features, probabilities, stop);
1055 ext_param_.to_classlabel(
argMax(probabilities), d);
1061 template <
class LabelType,
class PreprocessorTag>
1062 template <
class U,
class C>
1067 using namespace functor;
1068 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1069 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1070 vigra_precondition(
rowCount(features) == 1,
1071 "RandomForestn::predictLabel():"
1072 " Feature matrix must have a single row.");
1073 Matrix<double> prob(1,ext_param_.class_count_);
1074 predictProbabilities(features, prob);
1075 std::transform( prob.begin(), prob.end(),
1076 priors.
begin(), prob.begin(),
1079 ext_param_.to_classlabel(
argMax(prob), d);
1083 template<
class LabelType,
class PreprocessorTag>
1084 template <
class T1,
class T2,
class C>
1093 "RandomFroest::predictProbabilities():"
1094 " Feature matrix and probability matrix size mismatch.");
1097 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1098 "RandomForestn::predictProbabilities():"
1099 " Too few columns in feature matrix.");
1101 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1102 "RandomForestn::predictProbabilities():"
1103 " Probability matrix must have as many columns as there are classes.");
1106 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1109 for(
int k=0; k<options_.tree_count_; ++k)
1111 set_id=(set_id+1) % predictionSet.indices[0].size();
1112 typedef std::set<SampleRange<T1> > my_set;
1113 typedef typename my_set::iterator set_it;
1116 std::vector<std::pair<int,set_it> > stack;
1118 for(set_it i=predictionSet.ranges[set_id].begin();
1119 i!=predictionSet.ranges[set_id].end();++i)
1120 stack.push_back(std::pair<int,set_it>(2,i));
1122 int num_decisions=0;
1123 while(!stack.empty())
1125 set_it range=stack.back().second;
1126 int index=stack.back().first;
1130 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1133 trees_[k].parameters_,
1134 index).prob_begin();
1135 for(
int i=range->start;i!=range->end;++i)
1138 for(
int l=0; l<ext_param_.class_count_; ++l)
1140 prob(predictionSet.indices[set_id][i], l) +=
static_cast<T2
>(weights[l]);
1142 totalWeights[predictionSet.indices[set_id][i]] +=
static_cast<T1
>(weights[l]);
1149 if(trees_[k].topology_[index]!=i_ThresholdNode)
1151 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1153 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1154 if(range->min_boundaries[node.column()]>=node.threshold())
1157 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1160 if(range->max_boundaries[node.column()]<node.threshold())
1163 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1167 SampleRange<T1> new_range=*range;
1168 new_range.min_boundaries[node.column()]=FLT_MAX;
1169 range->max_boundaries[node.column()]=-FLT_MAX;
1170 new_range.start=new_range.end=range->end;
1172 while(i!=range->end)
1175 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1177 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1178 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1181 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1186 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1187 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1192 if(range->start==range->end)
1194 predictionSet.ranges[set_id].erase(range);
1198 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1201 if(new_range.start!=new_range.end)
1203 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1204 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1208 predictionSet.cumulativePredTime[k]=num_decisions;
1210 for(
unsigned int i=0;i<totalWeights.size();++i)
1214 for(
int l=0; l<ext_param_.class_count_; ++l)
1217 prob(i, l) /= totalWeights[i];
1219 assert(test==totalWeights[i]);
1220 assert(totalWeights[i]>0.0);
1224 template <
class LabelType,
class PreprocessorTag>
1225 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1228 MultiArrayView<2, T, C2> & prob,
1229 Stop_t & stop_)
const
1235 "RandomForestn::predictProbabilities():"
1236 " Feature matrix and probability matrix size mismatch.");
1240 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1241 "RandomForestn::predictProbabilities():"
1242 " Too few columns in feature matrix.");
1244 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1245 "RandomForestn::predictProbabilities():"
1246 " Probability matrix must have as many columns as there are classes.");
1248 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1249 Default_Stop_t default_stop(options_);
1250 typename RF_CHOOSER(Stop_t)::type & stop
1251 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1253 stop.set_external_parameters(ext_param_, tree_count());
1254 prob.init(NumericTraits<T>::zero());
1264 for(
int row=0; row <
rowCount(features); ++row)
1266 MultiArrayView<2, U, StridedArrayTag> currentRow(
rowVector(features, row));
1270 if(detail::contains_nan(currentRow))
1276 ArrayVector<double>::const_iterator weights;
1279 double totalWeight = 0.0;
1282 for(
int k=0; k<options_.tree_count_; ++k)
1285 weights = trees_[k ].predict(currentRow);
1288 int weighted = options_.predict_weighted_;
1289 for(
int l=0; l<ext_param_.class_count_; ++l)
1291 double cur_w = weights[l] * (weighted * (*(weights-1))
1293 prob(row, l) +=
static_cast<T
>(cur_w);
1295 totalWeight += cur_w;
1297 if(stop.after_prediction(weights,
1307 for(
int l=0; l< ext_param_.class_count_; ++l)
1309 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1315 template <
class LabelType,
class PreprocessorTag>
1316 template <
class U,
class C1,
class T,
class C2>
1317 void RandomForest<LabelType, PreprocessorTag>
1318 ::predictRaw(MultiArrayView<2, U, C1>
const & features,
1319 MultiArrayView<2, T, C2> & prob)
const
1325 "RandomForestn::predictProbabilities():"
1326 " Feature matrix and probability matrix size mismatch.");
1330 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1331 "RandomForestn::predictProbabilities():"
1332 " Too few columns in feature matrix.");
1334 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1335 "RandomForestn::predictProbabilities():"
1336 " Probability matrix must have as many columns as there are classes.");
1338 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1339 prob.init(NumericTraits<T>::zero());
1349 for(
int row=0; row <
rowCount(features); ++row)
1351 ArrayVector<double>::const_iterator weights;
1354 double totalWeight = 0.0;
1357 for(
int k=0; k<options_.tree_count_; ++k)
1360 weights = trees_[k ].predict(
rowVector(features, row));
1363 int weighted = options_.predict_weighted_;
1364 for(
int l=0; l<ext_param_.class_count_; ++l)
1366 double cur_w = weights[l] * (weighted * (*(weights-1))
1368 prob(row, l) +=
static_cast<T
>(cur_w);
1370 totalWeight += cur_w;
1374 prob/= options_.tree_count_;
1382 #include "random_forest/rf_algorithm.hxx"
1383 #endif // VIGRA_RANDOM_FOREST_HXX