ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/MitPhysics/Utils/src/GBRTree.cxx
Revision: 1.2
Committed: Tue Apr 24 15:37:14 2012 UTC (13 years ago) by bendavid
Content type: text/plain
Branch: MAIN
CVS Tags: HEAD
Changes since 1.1: +0 -0 lines
State: FILE REMOVED
Error occurred while calculating annotation data.
Log Message:
Migrate to 52x version of GBRForest and Tree

File Contents

# Content
1
2
#include "../interface/GBRTree.h"

#include <limits>
#include <stdexcept>

#include "TMVA/DecisionTreeNode.h"
#include "TMVA/DecisionTree.h"

using namespace std;
8
9
10 ClassImp(GBRTree)
11
12
13 //_______________________________________________________________________
14 GBRTree::GBRTree() :
15 fNIntermediateNodes(0),
16 fNTerminalNodes(0),
17 fCutIndices(0),
18 fCutVals(0),
19 fLeftIndices(0),
20 fRightIndices(0),
21 fResponses(0)
22 {
23
24 }
25
26 //_______________________________________________________________________
27 GBRTree::GBRTree(const TMVA::DecisionTree *tree) :
28 fNIntermediateNodes(0),
29 fNTerminalNodes(0)
30 {
31
32 //printf("boostweights size = %i, forest size = %i\n",bdt->GetBoostWeights().size(),bdt->GetForest().size());
33 Int_t nIntermediate = CountIntermediateNodes((TMVA::DecisionTreeNode*)tree->GetRoot());
34 Int_t nTerminal = CountTerminalNodes((TMVA::DecisionTreeNode*)tree->GetRoot());
35
36 //special case, root node is terminal
37 if (nIntermediate==0) nIntermediate = 1;
38
39 fCutIndices = new UChar_t[nIntermediate];
40 fCutVals = new Float_t[nIntermediate];
41 fLeftIndices = new Int_t[nIntermediate];
42 fRightIndices = new Int_t[nIntermediate];
43 fResponses = new Float_t[nTerminal];
44
45 AddNode((TMVA::DecisionTreeNode*)tree->GetRoot());
46
47 //special case, root node is terminal, create fake intermediate node at root
48 if (fNIntermediateNodes==0) {
49 fCutIndices[0] = 0;
50 fCutVals[0] = 0.;
51 fLeftIndices[0] = 0;
52 fRightIndices[0] = 0;
53 ++fNIntermediateNodes;
54 }
55
56
57
58
59
60 }
61
62 //_______________________________________________________________________
63 GBRTree::~GBRTree() {
64 delete [] fCutIndices;
65 delete [] fCutVals;
66 delete [] fLeftIndices;
67 delete [] fRightIndices;
68 delete [] fResponses;
69 }
70
71 //_______________________________________________________________________
72 UInt_t GBRTree::CountIntermediateNodes(const TMVA::DecisionTreeNode *node) {
73
74 if (!node->GetLeft() || !node->GetRight() || node->IsTerminal()) {
75 return 0;
76 }
77 else {
78 return 1 + CountIntermediateNodes((TMVA::DecisionTreeNode*)node->GetLeft()) + CountIntermediateNodes((TMVA::DecisionTreeNode*)node->GetRight());
79 }
80
81 }
82
83 //_______________________________________________________________________
84 UInt_t GBRTree::CountTerminalNodes(const TMVA::DecisionTreeNode *node) {
85
86 if (!node->GetLeft() || !node->GetRight() || node->IsTerminal()) {
87 return 1;
88 }
89 else {
90 return 0 + CountTerminalNodes((TMVA::DecisionTreeNode*)node->GetLeft()) + CountTerminalNodes((TMVA::DecisionTreeNode*)node->GetRight());
91 }
92
93 }
94
95
96 //_______________________________________________________________________
97 void GBRTree::AddNode(const TMVA::DecisionTreeNode *node) {
98
99 if (!node->GetLeft() || !node->GetRight() || node->IsTerminal()) {
100 fResponses[fNTerminalNodes] = node->GetResponse();
101 ++fNTerminalNodes;
102 return;
103 }
104 else {
105 Int_t thisindex = fNIntermediateNodes;
106 ++fNIntermediateNodes;
107
108 fCutIndices[thisindex] = node->GetSelector();
109 fCutVals[thisindex] = node->GetCutValue();
110
111
112
113 TMVA::DecisionTreeNode *left;
114 TMVA::DecisionTreeNode *right;
115 if (node->GetCutType()) {
116 left = (TMVA::DecisionTreeNode*)node->GetLeft();
117 right = (TMVA::DecisionTreeNode*)node->GetRight();
118 }
119 else {
120 left = (TMVA::DecisionTreeNode*)node->GetRight();
121 right = (TMVA::DecisionTreeNode*)node->GetLeft();
122 }
123
124 if (!left->GetLeft() || !left->GetRight() || left->IsTerminal()) {
125 fLeftIndices[thisindex] = -fNTerminalNodes;
126 }
127 else {
128 fLeftIndices[thisindex] = fNIntermediateNodes;
129 }
130 AddNode(left);
131
132 if (!right->GetLeft() || !right->GetRight() || right->IsTerminal()) {
133 fRightIndices[thisindex] = -fNTerminalNodes;
134 }
135 else {
136 fRightIndices[thisindex] = fNIntermediateNodes;
137 }
138 AddNode(right);
139
140 }
141
142 }