ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/MitPhysics/Utils/interface/GBRTree.h
Revision: 1.1
Committed: Wed Sep 28 00:08:48 2011 UTC (13 years, 7 months ago) by bendavid
Content type: text/plain
Branch: MAIN
CVS Tags: Mit_025e, Mit_025d, Mit_025c, Mit_025b, Mit_025a, Mit_025, Mit_025pre2
Log Message:
optimized bdt implementation for regression, so that disk and memory requirements are manageable

File Contents

# User Rev Content
1 bendavid 1.1
2     #ifndef ROOT_GBRTree
3     #define ROOT_GBRTree
4    
5     //////////////////////////////////////////////////////////////////////////
6     // //
7     // GBRForest //
8     // //
9     // A fast minimal implementation of Gradient-Boosted Regression Trees //
10     // which has been especially optimized for size on disk and in memory. //
11     // //
12     // Designed to be built from TMVA-trained trees, but could also be //
13     // generalized to otherwise-trained trees, classification, //
14     // or other boosting methods in the future //
15     // //
16     // Josh Bendavid - MIT //
17     //////////////////////////////////////////////////////////////////////////
18    
19     // The decision tree is implemented here as two tables: one for the
20     // intermediate nodes, holding the cut variable index and cut value as well
21     // as the indices of the 'left' and 'right' daughter nodes, and one for the
22     // terminal nodes, stored simply as a vector of regression responses. A
23     // daughter index > 0 names a further intermediate node; an index <= 0 names terminal node -index.
24    
25    
#include <algorithm>
#include <map>
#include <utility>
#include <vector>
#include "Rtypes.h"
29    
30    
// Forward declarations only: this header refers to these TMVA types solely
// through pointer parameters, so their full definitions are not needed here.
namespace TMVA {
  class DecisionTreeNode;
  class DecisionTree;
}
35    
36     class GBRTree {
37    
38     public:
39    
40     GBRTree();
41     GBRTree(const TMVA::DecisionTree *tree);
42    
43     virtual ~GBRTree();
44    
45     Double_t GetResponse(const Float_t* vector) const;
46    
47     protected:
48    
49     UInt_t CountIntermediateNodes(const TMVA::DecisionTreeNode *node);
50     UInt_t CountTerminalNodes(const TMVA::DecisionTreeNode *node);
51    
52     void AddNode(const TMVA::DecisionTreeNode *node);
53    
54     Int_t fNIntermediateNodes;
55     Int_t fNTerminalNodes;
56    
57     UChar_t *fCutIndices;//[fNIntermediateNodes]
58     Float_t *fCutVals;//[fNIntermediateNodes]
59     Int_t *fLeftIndices;//[fNIntermediateNodes]
60     Int_t *fRightIndices;//[fNIntermediateNodes]
61     Float_t *fResponses;//[fNTerminalNodes]
62    
63    
64     ClassDef(GBRTree,1) // Node for the Decision Tree
65     };
66    
67     //_______________________________________________________________________
68     inline Double_t GBRTree::GetResponse(const Float_t* vector) const {
69    
70     Int_t index = 0;
71    
72     UChar_t cutindex = fCutIndices[0];
73     Float_t cutval = fCutVals[0];
74    
75     while (true) {
76     if (vector[cutindex] > cutval) {
77     index = fRightIndices[index];
78     }
79     else {
80     index = fLeftIndices[index];
81     }
82    
83     if (index>0) {
84     cutindex = fCutIndices[index];
85     cutval = fCutVals[index];
86     }
87     else {
88     return fResponses[-index];
89     }
90    
91     }
92    
93    
94     }
95    
96     #endif