ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/MitPhysics/Utils/interface/GBRTree.h
Revision: 1.1
Committed: Wed Sep 28 00:08:48 2011 UTC (13 years, 7 months ago) by bendavid
Content type: text/plain
Branch: MAIN
CVS Tags: Mit_025e, Mit_025d, Mit_025c, Mit_025b, Mit_025a, Mit_025, Mit_025pre2
Log Message:
optimized bdt implementation for regression, so that disk and memory requirements are manageable

File Contents

# Content
1
2 #ifndef ROOT_GBRTree
3 #define ROOT_GBRTree
4
5 //////////////////////////////////////////////////////////////////////////
6 // //
7 // GBRTree //
8 // //
9 // A fast minimal implementation of Gradient-Boosted Regression Trees //
10 // which has been especially optimized for size on disk and in memory. //
11 // //
12 // Designed to be built from TMVA-trained trees, but could also be //
13 // generalized to otherwise-trained trees, classification, //
14 // or other boosting methods in the future //
15 // //
16 // Josh Bendavid - MIT //
17 //////////////////////////////////////////////////////////////////////////
18
19 // The decision tree is implemented here as a set of two arrays, one for
20 // intermediate nodes, containing the variable index and cut value, as well
21 // as the indices of the 'left' and 'right' daughter nodes. Indices > 0
22 // indicate further intermediate nodes, whereas indices <= 0 indicate terminal
23 // nodes, stored simply as an array of regression responses at position -index.
24
25
26 #include <vector>
27 #include <map>
28 #include "Rtypes.h"
29
30
31 namespace TMVA {
32 class DecisionTree;
33 class DecisionTreeNode;
34 }
35
36 class GBRTree {
37
38 public:
39
40 GBRTree();
41 GBRTree(const TMVA::DecisionTree *tree);
42
43 virtual ~GBRTree();
44
45 Double_t GetResponse(const Float_t* vector) const;
46
47 protected:
48
49 UInt_t CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
50 UInt_t CountTerminalNodes(const TMVA::DecisionTreeNode *node);
51
52 void AddNode(const TMVA::DecisionTreeNode *node);
53
54 Int_t fNIntermediateNodes;
55 Int_t fNTerminalNodes;
56
57 UChar_t *fCutIndices;//[fNIntermediateNodes]
58 Float_t *fCutVals;//[fNIntermediateNodes]
59 Int_t *fLeftIndices;//[fNIntermediateNodes]
60 Int_t *fRightIndices;//[fNIntermediateNodes]
61 Float_t *fResponses;//[fNTerminalNodes]
62
63
64 ClassDef(GBRTree,1) // Node for the Decision Tree
65 };
66
67 //_______________________________________________________________________
68 inline Double_t GBRTree::GetResponse(const Float_t* vector) const {
69
70 Int_t index = 0;
71
72 UChar_t cutindex = fCutIndices[0];
73 Float_t cutval = fCutVals[0];
74
75 while (true) {
76 if (vector[cutindex] > cutval) {
77 index = fRightIndices[index];
78 }
79 else {
80 index = fLeftIndices[index];
81 }
82
83 if (index>0) {
84 cutindex = fCutIndices[index];
85 cutval = fCutVals[index];
86 }
87 else {
88 return fResponses[-index];
89 }
90
91 }
92
93
94 }
95
96 #endif