ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/MitPhysics/Utils/src/GBRTree.cxx
Revision: 1.1
Committed: Wed Sep 28 00:08:48 2011 UTC (13 years, 7 months ago) by bendavid
Content type: text/plain
Branch: MAIN
CVS Tags: Mit_025e, Mit_025d, Mit_025c, Mit_025b, Mit_025a, Mit_025, Mit_025pre2
Log Message:
optimized bdt implementation for regression, so that disk and memory requirements are manageable

File Contents

# User Rev Content
1 bendavid 1.1
2    
3     #include "../interface/GBRTree.h"
4    
5     using namespace std;
6     #include "TMVA/DecisionTreeNode.h"
7     #include "TMVA/DecisionTree.h"
8    
9    
10     ClassImp(GBRTree)
11    
12    
13     //_______________________________________________________________________
14     GBRTree::GBRTree() :
15     fNIntermediateNodes(0),
16     fNTerminalNodes(0),
17     fCutIndices(0),
18     fCutVals(0),
19     fLeftIndices(0),
20     fRightIndices(0),
21     fResponses(0)
22     {
23    
24     }
25    
26     //_______________________________________________________________________
27     GBRTree::GBRTree(const TMVA::DecisionTree *tree) :
28     fNIntermediateNodes(0),
29     fNTerminalNodes(0)
30     {
31    
32     //printf("boostweights size = %i, forest size = %i\n",bdt->GetBoostWeights().size(),bdt->GetForest().size());
33     Int_t nIntermediate = CountIntermediateNodes((TMVA::DecisionTreeNode*)tree->GetRoot());
34     Int_t nTerminal = CountTerminalNodes((TMVA::DecisionTreeNode*)tree->GetRoot());
35    
36     //special case, root node is terminal
37     if (nIntermediate==0) nIntermediate = 1;
38    
39     fCutIndices = new UChar_t[nIntermediate];
40     fCutVals = new Float_t[nIntermediate];
41     fLeftIndices = new Int_t[nIntermediate];
42     fRightIndices = new Int_t[nIntermediate];
43     fResponses = new Float_t[nTerminal];
44    
45     AddNode((TMVA::DecisionTreeNode*)tree->GetRoot());
46    
47     //special case, root node is terminal, create fake intermediate node at root
48     if (fNIntermediateNodes==0) {
49     fCutIndices[0] = 0;
50     fCutVals[0] = 0.;
51     fLeftIndices[0] = 0;
52     fRightIndices[0] = 0;
53     ++fNIntermediateNodes;
54     }
55    
56    
57    
58    
59    
60     }
61    
62     //_______________________________________________________________________
63     GBRTree::~GBRTree() {
64     delete [] fCutIndices;
65     delete [] fCutVals;
66     delete [] fLeftIndices;
67     delete [] fRightIndices;
68     delete [] fResponses;
69     }
70    
71     //_______________________________________________________________________
72     UInt_t GBRTree::CountIntermediateNodes(const TMVA::DecisionTreeNode *node) {
73    
74     if (!node->GetLeft() || !node->GetRight() || node->IsTerminal()) {
75     return 0;
76     }
77     else {
78     return 1 + CountIntermediateNodes((TMVA::DecisionTreeNode*)node->GetLeft()) + CountIntermediateNodes((TMVA::DecisionTreeNode*)node->GetRight());
79     }
80    
81     }
82    
83     //_______________________________________________________________________
84     UInt_t GBRTree::CountTerminalNodes(const TMVA::DecisionTreeNode *node) {
85    
86     if (!node->GetLeft() || !node->GetRight() || node->IsTerminal()) {
87     return 1;
88     }
89     else {
90     return 0 + CountTerminalNodes((TMVA::DecisionTreeNode*)node->GetLeft()) + CountTerminalNodes((TMVA::DecisionTreeNode*)node->GetRight());
91     }
92    
93     }
94    
95    
96     //_______________________________________________________________________
97     void GBRTree::AddNode(const TMVA::DecisionTreeNode *node) {
98    
99     if (!node->GetLeft() || !node->GetRight() || node->IsTerminal()) {
100     fResponses[fNTerminalNodes] = node->GetResponse();
101     ++fNTerminalNodes;
102     return;
103     }
104     else {
105     Int_t thisindex = fNIntermediateNodes;
106     ++fNIntermediateNodes;
107    
108     fCutIndices[thisindex] = node->GetSelector();
109     fCutVals[thisindex] = node->GetCutValue();
110    
111    
112    
113     TMVA::DecisionTreeNode *left;
114     TMVA::DecisionTreeNode *right;
115     if (node->GetCutType()) {
116     left = (TMVA::DecisionTreeNode*)node->GetLeft();
117     right = (TMVA::DecisionTreeNode*)node->GetRight();
118     }
119     else {
120     left = (TMVA::DecisionTreeNode*)node->GetRight();
121     right = (TMVA::DecisionTreeNode*)node->GetLeft();
122     }
123    
124     if (!left->GetLeft() || !left->GetRight() || left->IsTerminal()) {
125     fLeftIndices[thisindex] = -fNTerminalNodes;
126     }
127     else {
128     fLeftIndices[thisindex] = fNIntermediateNodes;
129     }
130     AddNode(left);
131    
132     if (!right->GetLeft() || !right->GetRight() || right->IsTerminal()) {
133     fRightIndices[thisindex] = -fNTerminalNodes;
134     }
135     else {
136     fRightIndices[thisindex] = fNIntermediateNodes;
137     }
138     AddNode(right);
139    
140     }
141    
142     }