ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/benhoob/HWW/BDT.C
Revision: 1.1
Committed: Mon Feb 14 12:39:14 2011 UTC (14 years, 3 months ago) by benhoob
Content type: text/plain
Branch: MAIN
CVS Tags: HEAD
Error occurred while calculating annotation data.
Log Message:
Initial commit

File Contents

# Content
1 #include <iostream>
2 #include <iomanip>
3 #include <fstream>
4
5 #include "tmvaglob.C"
6
7 #include "RQ_OBJECT.h"
8
9 #include "TROOT.h"
10 #include "TStyle.h"
11 #include "TPad.h"
12 #include "TCanvas.h"
13 #include "TLine.h"
14 #include "TFile.h"
15 #include "TColor.h"
16 #include "TPaveText.h"
17 #include "TObjString.h"
18 #include "TControlBar.h"
19
20 #include "TGWindow.h"
21 #include "TGButton.h"
22 #include "TGLabel.h"
23 #include "TGNumberEntry.h"
24
25 #include "TMVA/DecisionTree.h"
26 #include "TMVA/Tools.h"
27 #include "TXMLEngine.h"
28
29 // Uncomment this only if the link problem is solved. The include statement tends
30 // to use the ROOT classes rather than the local TMVA release
31 // #include "TMVA/DecisionTree.h"
32 // #include "TMVA/DecisionTreeNode.h"
33
34 // this macro displays a decision tree read in from the weight file
35
36 static const Int_t kSigColorF = TColor::GetColor( "#2244a5" ); // novel blue
37 static const Int_t kBkgColorF = TColor::GetColor( "#dd0033" ); // novel red
38 static const Int_t kIntColorF = TColor::GetColor( "#33aa77" ); // novel green
39
40 static const Int_t kSigColorT = 10;
41 static const Int_t kBkgColorT = 10;
42 static const Int_t kIntColorT = 10;
43
44 enum PlotType { EffPurity = 0 };
45
46 class StatDialogBDT {
47
48 RQ_OBJECT("StatDialogBDT")
49
50 public:
51
52 StatDialogBDT( const TGWindow* p, TString wfile = "weights/TMVAClassification_BDT.weights.txt",
53 TString methName = "BDT", Int_t itree = 0 );
54 virtual ~StatDialogBDT() {
55 TMVA::DecisionTreeNode::fgIsTraining=false;
56 fThis = 0;
57 fMain->CloseWindow();
58 fMain->Cleanup();
59 if(gROOT->GetListOfCanvases()->FindObject(fCanvas))
60 delete fCanvas;
61 }
62
63 // draw method
64 void DrawTree( Int_t itree );
65
66 void RaiseDialog() { if (fMain) { fMain->RaiseWindow(); fMain->Layout(); fMain->MapWindow(); } }
67
68 private:
69
70 TGMainFrame *fMain;
71 Int_t fItree;
72 Int_t fNtrees;
73 TCanvas* fCanvas;
74
75 TGNumberEntry* fInput;
76
77 TGHorizontalFrame* fButtons;
78 TGTextButton* fDrawButton;
79 TGTextButton* fCloseButton;
80
81 void UpdateCanvases();
82
83 // draw methods
84 TMVA::DecisionTree* ReadTree( TString * &vars, Int_t itree );
85 void DrawNode( TMVA::DecisionTreeNode *n,
86 Double_t x, Double_t y, Double_t xscale, Double_t yscale, TString* vars );
87 void GetNtrees();
88
89 TString fWfile;
90 TString fMethName;
91
92 public:
93
94 // static function for external deletion
95 static void Delete() { if (fThis != 0) { delete fThis; fThis = 0; } }
96
97 // slots
98 void SetItree(); //*SIGNAL*
99 void Redraw(); //*SIGNAL*
100 void Close(); //*SIGNAL*
101
102 private:
103
104 static StatDialogBDT* fThis;
105
106 };
107
108 StatDialogBDT* StatDialogBDT::fThis = 0;
109
110 void StatDialogBDT::SetItree()
111 {
112 fItree = Int_t(fInput->GetNumber());
113 }
114
115 void StatDialogBDT::Redraw()
116 {
117 UpdateCanvases();
118 }
119
120 void StatDialogBDT::Close()
121 {
122 delete this;
123 }
124
125 StatDialogBDT::StatDialogBDT( const TGWindow* p, TString wfile, TString methName, Int_t itree )
126 : fMain( 0 ),
127 fItree(itree),
128 fNtrees(0),
129 fCanvas(0),
130 fInput(0),
131 fButtons(0),
132 fDrawButton(0),
133 fCloseButton(0),
134 fWfile( wfile ),
135 fMethName( methName )
136 {
137 UInt_t totalWidth = 500;
138 UInt_t totalHeight = 200;
139
140 fThis = this;
141
142 TMVA::DecisionTreeNode::fgIsTraining=true;
143
144 // read number of decision trees from weight file
145 GetNtrees();
146
147 // main frame
148 fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
149
150 TGLabel *sigLab = new TGLabel( fMain, Form( "Decision tree [%i-%i]",0,fNtrees-1 ) );
151 fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
152
153 fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
154 fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
155 fInput->Resize(100,24);
156 fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
157
158 fButtons = new TGHorizontalFrame(fMain, totalWidth,30);
159
160 fCloseButton = new TGTextButton(fButtons,"&Close");
161 fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));
162
163 fDrawButton = new TGTextButton(fButtons,"&Draw");
164 fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));
165
166 fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
167
168 fMain->SetWindowName("Decision tree");
169 fMain->SetWMPosition(0,0);
170 fMain->MapSubwindows();
171 fMain->Resize(fMain->GetDefaultSize());
172 fMain->MapWindow();
173
174 fInput->Connect("ValueSet(Long_t)","StatDialogBDT",this, "SetItree()");
175
176 fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
177 fDrawButton->Connect("Clicked()", "StatDialogBDT", this, "Redraw()");
178
179 fCloseButton->Connect("Clicked()", "StatDialogBDT", this, "Close()");
180 }
181
182 void StatDialogBDT::UpdateCanvases()
183 {
184 DrawTree( fItree );
185 }
186
187 void StatDialogBDT::GetNtrees()
188 {
189 if(!fWfile.EndsWith(".xml") ){
190 ifstream fin( fWfile );
191 if (!fin.good( )) { // file not found --> Error
192 cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
193 return;
194 }
195
196 TString dummy = "";
197
198 // read total number of trees, and check whether requested tree is in range
199 Int_t nc = 0;
200 while (!dummy.Contains("NTrees")) {
201 fin >> dummy;
202 nc++;
203 if (nc > 200) {
204 cout << endl;
205 cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
206 << fWfile << endl;
207 cout << "==> panic abort (please contact the TMVA authors)" << endl;
208 cout << endl;
209 exit(1);
210 }
211 }
212 fin >> dummy;
213 fNtrees = dummy.ReplaceAll("\"","").Atoi();
214 fin.close();
215 }
216 else{
217 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
218 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
219 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
220 while(ch){
221 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
222 if(nodeName=="Weights") {
223 TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
224 break;
225 }
226 ch = TMVA::gTools().xmlengine().GetNext(ch);
227 }
228 }
229 cout << "--- Found " << fNtrees << " decision trees in weight file" << endl;
230
231 }
232
233 //_______________________________________________________________________
234 void StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n,
235 Double_t x, Double_t y,
236 Double_t xscale, Double_t yscale, TString * vars)
237 {
238 // recursively puts an entries in the histogram for the node and its daughters
239 //
240 Float_t xsize=xscale*1.5;
241 Float_t ysize=yscale/3;
242 if (xsize>0.15) xsize=0.1; //xscale/2;
243 if (n->GetLeft() != NULL){
244 TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
245 a1->SetLineWidth(2);
246 a1->Draw();
247 DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
248 }
249 if (n->GetRight() != NULL){
250 TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
251 a1->SetLineWidth(2);
252 a1->Draw();
253 DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
254 }
255
256 // TPaveText *t = new TPaveText(x-xscale/2,y-yscale/2,x+xscale/2,y+yscale/2, "NDC");
257 TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
258
259 t->SetBorderSize(1);
260
261 t->SetFillStyle(1);
262 if (n->GetNodeType() == 1) { t->SetFillColor( kSigColorF ); t->SetTextColor( kSigColorT ); }
263 else if (n->GetNodeType() == -1) { t->SetFillColor( kBkgColorF ); t->SetTextColor( kBkgColorT ); }
264 else if (n->GetNodeType() == 0) { t->SetFillColor( kIntColorF ); t->SetTextColor( kIntColorT ); }
265
266 char buffer[25];
267 sprintf( buffer, "N=%f", n->GetNEvents() );
268 if (n->GetNEvents()>0) t->AddText(buffer);
269 sprintf( buffer, "S/(S+B)=%4.3f", n->GetPurity() );
270 t->AddText(buffer);
271
272 if (n->GetNodeType() == 0){
273 if (n->GetCutType()){
274 t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
275 }else{
276 t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
277 }
278 }
279
280 t->Draw();
281
282 return;
283 }
284 TMVA::DecisionTree* StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
285 {
286 cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << endl;
287 TMVA::DecisionTree *d = new TMVA::DecisionTree();
288 if(!fWfile.EndsWith(".xml") ){
289 ifstream fin( fWfile );
290 if (!fin.good( )) { // file not found --> Error
291 cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
292 return 0;
293 }
294
295 TString dummy = "";
296
297 if (itree >= fNtrees) {
298 cout << "*** ERROR: requested decision tree: " << itree
299 << ", but number of trained trees only: " << fNtrees << endl;
300 return 0;
301 }
302
303 // file header with name
304 while (!dummy.Contains("#VAR")) fin >> dummy;
305 fin >> dummy >> dummy >> dummy; // the rest of header line
306
307 // number of variables
308 Int_t nVars;
309 fin >> dummy >> nVars;
310
311 // variable mins and maxes
312 vars = new TString[nVars+1]; // last one is if "fisher cut criterium"
313 for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
314 vars[nVars]="FisherCrit";
315
316 char buffer[20];
317 char line[256];
318 sprintf(buffer,"Tree %d",itree);
319
320 while (!dummy.Contains(buffer)) {
321 fin.getline(line,256);
322 dummy = TString(line);
323 }
324
325 d->Read(fin);
326
327 fin.close();
328 }
329 else{
330 if (itree >= fNtrees) {
331 cout << "*** ERROR: requested decision tree: " << itree
332 << ", but number of trained trees only: " << fNtrees << endl;
333 return 0;
334 }
335 Int_t nVars;
336 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
337 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
338 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
339 while(ch){
340 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
341 if(nodeName=="Variables"){
342 TMVA::gTools().ReadAttr( ch, "NVar", nVars);
343 vars = new TString[nVars+1];
344 void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
345 for (Int_t i = 0; i < nVars; i++){
346 TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
347 varnode = TMVA::gTools().xmlengine().GetNext(varnode);
348 }
349 vars[nVars]="FisherCrit";
350 }
351 if(nodeName=="Weights") break;
352 ch = TMVA::gTools().xmlengine().GetNext(ch);
353 }
354 ch = TMVA::gTools().xmlengine().GetChild(ch);
355 for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
356 d->ReadXML(ch);
357 }
358 return d;
359 }
360
361 //_______________________________________________________________________
362 void StatDialogBDT::DrawTree( Int_t itree )
363 {
364 TString *vars;
365 TMVA::DecisionTree* d = ReadTree( vars, itree );
366 if (d == 0) return;
367
368 UInt_t depth = d->GetTotalTreeDepth();
369 Double_t ystep = 1.0/(depth + 1.0);
370
371 cout << "--- Tree depth: " << depth << endl;
372
373 TStyle* TMVAStyle = gROOT->GetStyle("Plain"); // our style is based on Plain
374 Int_t canvasColor = TMVAStyle->GetCanvasColor(); // backup
375
376 TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
377 TString tbuffer = Form( "Decision Tree no.: %d", itree );
378 if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 );
379 else fCanvas->Clear();
380 fCanvas->Draw();
381
382 DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
383
384 // make the legend
385 Double_t yup=0.99;
386 Double_t ydown=yup-ystep/2.5;
387 Double_t dy= ystep/2.5 * 0.2;
388
389 TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
390 whichTree->SetBorderSize(1);
391 whichTree->SetFillStyle(1);
392 whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
393 whichTree->AddText( tbuffer );
394 whichTree->Draw();
395
396 TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
397 intermediate->SetBorderSize(1);
398 intermediate->SetFillStyle(1);
399 intermediate->SetFillColor( kIntColorF );
400 intermediate->AddText("Intermediate Nodes");
401 intermediate->SetTextColor( kIntColorT );
402 intermediate->Draw();
403
404 ydown = ydown - ystep/2.5 -dy;
405 yup = yup - ystep/2.5 -dy;
406 TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
407 signalleaf->SetBorderSize(1);
408 signalleaf->SetFillStyle(1);
409 signalleaf->SetFillColor( kSigColorF );
410 signalleaf->AddText("Signal Leaf Nodes");
411 signalleaf->SetTextColor( kSigColorT );
412 signalleaf->Draw();
413
414 ydown = ydown - ystep/2.5 -dy;
415 yup = yup - ystep/2.5 -dy;
416 TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
417 backgroundleaf->SetBorderSize(1);
418 backgroundleaf->SetFillStyle(1);
419 backgroundleaf->SetFillColor( kBkgColorF );
420
421 backgroundleaf->AddText("Backgr. Leaf Nodes");
422 backgroundleaf->SetTextColor( kBkgColorT );
423 backgroundleaf->Draw();
424
425 fCanvas->Update();
426 TString fname = Form("plots/%s_%i", fMethName.Data(), itree );
427 cout << "--- Creating image: " << fname << endl;
428 TMVAGlob::imgconv( fCanvas, fname );
429
430 TMVAStyle->SetCanvasColor( canvasColor );
431 }
432
433 // ========================================================================================
434
435 static std::vector<TControlBar*> BDT_Global__cbar;
436
437 // intermediate GUI
438 void BDT( const TString& fin = "TMVA.root" )
439 {
440 // --- read the available BDT weight files
441
442 // destroy all open cavases
443 TMVAGlob::DestroyCanvases();
444
445 // checks if file with name "fin" is already open, and if not opens one
446 TFile* file = TMVAGlob::OpenFile( fin );
447
448 TDirectory* dir = file->GetDirectory( "Method_BDT" );
449 if (!dir) {
450 cout << "*** Error in macro \"BDT.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
451 return;
452 }
453
454 // read all directories
455 TIter next( dir->GetListOfKeys() );
456 TKey *key(0);
457 std::vector<TString> methname;
458 std::vector<TString> path;
459 std::vector<TString> wfile;
460 while ((key = (TKey*)next())) {
461 TDirectory* mdir = dir->GetDirectory( key->GetName() );
462 if (!mdir) {
463 cout << "*** Error in macro \"BDT.C\": cannot find sub-directory: " << key->GetName()
464 << " in directory: " << dir->GetName() << endl;
465 return;
466 }
467
468 // retrieve weight file name and path
469 TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
470 TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
471 if (!strPath || !strWFile) {
472 cout << "*** Error in macro \"BDT.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
473 cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
474 return;
475 }
476
477 methname.push_back( key->GetName() );
478 path .push_back( strPath->GetString() );
479 wfile .push_back( strWFile->GetString() );
480 }
481
482 // create the control bar
483 TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
484 BDT_Global__cbar.push_back(cbar);
485
486 for (UInt_t im=0; im<path.size(); im++) {
487 TString fname = path[im];
488 if (fname[fname.Length()-1] != '/') fname += "/";
489 fname += wfile[im];
490 TString macro = Form( ".x BDT.C+\(0,\"%s\",\"%s\")", fname.Data(), methname[im].Data() );
491 cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
492 }
493
494 // *** problems with this button in ROOT 5.19 ***
495 #if ROOT_VERSION_CODE < ROOT_VERSION(5,19,0)
496 cbar->AddButton( "Close", Form("BDT_DeleteTBar(%i)", BDT_Global__cbar.size()-1), "Close this control bar", "button" );
497 #endif
498 // **********************************************
499
500 // set the style
501 cbar->SetTextColor("blue");
502
503 // draw
504 cbar->Show();
505 }
506
507 void BDT_DeleteTBar(int i)
508 {
509 // destroy all open canvases
510 StatDialogBDT::Delete();
511 TMVAGlob::DestroyCanvases();
512
513 delete BDT_Global__cbar[i];
514 BDT_Global__cbar[i] = 0;
515 }
516
517 // input: - No. of tree
518 // - the weight file from which the tree is read
519 void BDT( Int_t itree, TString wfile = "weights/TMVAnalysis_test_BDT.weights.txt", TString methName = "BDT", Bool_t useTMVAStyle = kTRUE )
520 {
521 // destroy possibly existing dialog windows and/or canvases
522 StatDialogBDT::Delete();
523 TMVAGlob::DestroyCanvases();
524
525 // quick check if weight file exist
526 if(!wfile.EndsWith(".xml") ){
527 ifstream fin( wfile );
528 if (!fin.good( )) { // file not found --> Error
529 cout << "*** ERROR: Weight file: " << wfile << " does not exist" << endl;
530 return;
531 }
532 }
533 std::cout << "test1";
534 // set style and remove existing canvas'
535 TMVAGlob::Initialize( useTMVAStyle );
536
537 StatDialogBDT* gGui = new StatDialogBDT( gClient->GetRoot(), wfile, methName, itree );
538
539 gGui->DrawTree( itree );
540
541 gGui->RaiseDialog();
542 }
543