ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/benhoob/HWW/BDT.C
Revision: 1.1
Committed: Mon Feb 14 12:39:14 2011 UTC (14 years, 3 months ago) by benhoob
Content type: text/plain
Branch: MAIN
CVS Tags: HEAD
Log Message:
Initial commit

File Contents

# User Rev Content
1 benhoob 1.1 #include <iostream>
2     #include <iomanip>
3     #include <fstream>
4    
5     #include "tmvaglob.C"
6    
7     #include "RQ_OBJECT.h"
8    
9     #include "TROOT.h"
10     #include "TStyle.h"
11     #include "TPad.h"
12     #include "TCanvas.h"
13     #include "TLine.h"
14     #include "TFile.h"
15     #include "TColor.h"
16     #include "TPaveText.h"
17     #include "TObjString.h"
18     #include "TControlBar.h"
19    
20     #include "TGWindow.h"
21     #include "TGButton.h"
22     #include "TGLabel.h"
23     #include "TGNumberEntry.h"
24    
25     #include "TMVA/DecisionTree.h"
26     #include "TMVA/Tools.h"
27     #include "TXMLEngine.h"
28    
29     // Uncomment this only if the link problem is solved. The include statement tends
30     // to use the ROOT classes rather than the local TMVA release
31     // #include "TMVA/DecisionTree.h"
32     // #include "TMVA/DecisionTreeNode.h"
33    
34     // this macro displays a decision tree read in from the weight file
35    
36     static const Int_t kSigColorF = TColor::GetColor( "#2244a5" ); // novel blue
37     static const Int_t kBkgColorF = TColor::GetColor( "#dd0033" ); // novel red
38     static const Int_t kIntColorF = TColor::GetColor( "#33aa77" ); // novel green
39    
40     static const Int_t kSigColorT = 10;
41     static const Int_t kBkgColorT = 10;
42     static const Int_t kIntColorT = 10;
43    
44     enum PlotType { EffPurity = 0 };
45    
46     class StatDialogBDT {
47    
48     RQ_OBJECT("StatDialogBDT")
49    
50     public:
51    
52     StatDialogBDT( const TGWindow* p, TString wfile = "weights/TMVAClassification_BDT.weights.txt",
53     TString methName = "BDT", Int_t itree = 0 );
54     virtual ~StatDialogBDT() {
55     TMVA::DecisionTreeNode::fgIsTraining=false;
56     fThis = 0;
57     fMain->CloseWindow();
58     fMain->Cleanup();
59     if(gROOT->GetListOfCanvases()->FindObject(fCanvas))
60     delete fCanvas;
61     }
62    
63     // draw method
64     void DrawTree( Int_t itree );
65    
66     void RaiseDialog() { if (fMain) { fMain->RaiseWindow(); fMain->Layout(); fMain->MapWindow(); } }
67    
68     private:
69    
70     TGMainFrame *fMain;
71     Int_t fItree;
72     Int_t fNtrees;
73     TCanvas* fCanvas;
74    
75     TGNumberEntry* fInput;
76    
77     TGHorizontalFrame* fButtons;
78     TGTextButton* fDrawButton;
79     TGTextButton* fCloseButton;
80    
81     void UpdateCanvases();
82    
83     // draw methods
84     TMVA::DecisionTree* ReadTree( TString * &vars, Int_t itree );
85     void DrawNode( TMVA::DecisionTreeNode *n,
86     Double_t x, Double_t y, Double_t xscale, Double_t yscale, TString* vars );
87     void GetNtrees();
88    
89     TString fWfile;
90     TString fMethName;
91    
92     public:
93    
94     // static function for external deletion
95     static void Delete() { if (fThis != 0) { delete fThis; fThis = 0; } }
96    
97     // slots
98     void SetItree(); //*SIGNAL*
99     void Redraw(); //*SIGNAL*
100     void Close(); //*SIGNAL*
101    
102     private:
103    
104     static StatDialogBDT* fThis;
105    
106     };
107    
108     StatDialogBDT* StatDialogBDT::fThis = 0;
109    
110     void StatDialogBDT::SetItree()
111     {
112     fItree = Int_t(fInput->GetNumber());
113     }
114    
115     void StatDialogBDT::Redraw()
116     {
117     UpdateCanvases();
118     }
119    
120     void StatDialogBDT::Close()
121     {
122     delete this;
123     }
124    
125     StatDialogBDT::StatDialogBDT( const TGWindow* p, TString wfile, TString methName, Int_t itree )
126     : fMain( 0 ),
127     fItree(itree),
128     fNtrees(0),
129     fCanvas(0),
130     fInput(0),
131     fButtons(0),
132     fDrawButton(0),
133     fCloseButton(0),
134     fWfile( wfile ),
135     fMethName( methName )
136     {
137     UInt_t totalWidth = 500;
138     UInt_t totalHeight = 200;
139    
140     fThis = this;
141    
142     TMVA::DecisionTreeNode::fgIsTraining=true;
143    
144     // read number of decision trees from weight file
145     GetNtrees();
146    
147     // main frame
148     fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
149    
150     TGLabel *sigLab = new TGLabel( fMain, Form( "Decision tree [%i-%i]",0,fNtrees-1 ) );
151     fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
152    
153     fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
154     fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
155     fInput->Resize(100,24);
156     fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
157    
158     fButtons = new TGHorizontalFrame(fMain, totalWidth,30);
159    
160     fCloseButton = new TGTextButton(fButtons,"&Close");
161     fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));
162    
163     fDrawButton = new TGTextButton(fButtons,"&Draw");
164     fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));
165    
166     fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
167    
168     fMain->SetWindowName("Decision tree");
169     fMain->SetWMPosition(0,0);
170     fMain->MapSubwindows();
171     fMain->Resize(fMain->GetDefaultSize());
172     fMain->MapWindow();
173    
174     fInput->Connect("ValueSet(Long_t)","StatDialogBDT",this, "SetItree()");
175    
176     fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
177     fDrawButton->Connect("Clicked()", "StatDialogBDT", this, "Redraw()");
178    
179     fCloseButton->Connect("Clicked()", "StatDialogBDT", this, "Close()");
180     }
181    
182     void StatDialogBDT::UpdateCanvases()
183     {
184     DrawTree( fItree );
185     }
186    
187     void StatDialogBDT::GetNtrees()
188     {
189     if(!fWfile.EndsWith(".xml") ){
190     ifstream fin( fWfile );
191     if (!fin.good( )) { // file not found --> Error
192     cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
193     return;
194     }
195    
196     TString dummy = "";
197    
198     // read total number of trees, and check whether requested tree is in range
199     Int_t nc = 0;
200     while (!dummy.Contains("NTrees")) {
201     fin >> dummy;
202     nc++;
203     if (nc > 200) {
204     cout << endl;
205     cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
206     << fWfile << endl;
207     cout << "==> panic abort (please contact the TMVA authors)" << endl;
208     cout << endl;
209     exit(1);
210     }
211     }
212     fin >> dummy;
213     fNtrees = dummy.ReplaceAll("\"","").Atoi();
214     fin.close();
215     }
216     else{
217     void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
218     void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
219     void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
220     while(ch){
221     TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
222     if(nodeName=="Weights") {
223     TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
224     break;
225     }
226     ch = TMVA::gTools().xmlengine().GetNext(ch);
227     }
228     }
229     cout << "--- Found " << fNtrees << " decision trees in weight file" << endl;
230    
231     }
232    
233     //_______________________________________________________________________
234     void StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n,
235     Double_t x, Double_t y,
236     Double_t xscale, Double_t yscale, TString * vars)
237     {
238     // recursively puts an entries in the histogram for the node and its daughters
239     //
240     Float_t xsize=xscale*1.5;
241     Float_t ysize=yscale/3;
242     if (xsize>0.15) xsize=0.1; //xscale/2;
243     if (n->GetLeft() != NULL){
244     TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
245     a1->SetLineWidth(2);
246     a1->Draw();
247     DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
248     }
249     if (n->GetRight() != NULL){
250     TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
251     a1->SetLineWidth(2);
252     a1->Draw();
253     DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
254     }
255    
256     // TPaveText *t = new TPaveText(x-xscale/2,y-yscale/2,x+xscale/2,y+yscale/2, "NDC");
257     TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
258    
259     t->SetBorderSize(1);
260    
261     t->SetFillStyle(1);
262     if (n->GetNodeType() == 1) { t->SetFillColor( kSigColorF ); t->SetTextColor( kSigColorT ); }
263     else if (n->GetNodeType() == -1) { t->SetFillColor( kBkgColorF ); t->SetTextColor( kBkgColorT ); }
264     else if (n->GetNodeType() == 0) { t->SetFillColor( kIntColorF ); t->SetTextColor( kIntColorT ); }
265    
266     char buffer[25];
267     sprintf( buffer, "N=%f", n->GetNEvents() );
268     if (n->GetNEvents()>0) t->AddText(buffer);
269     sprintf( buffer, "S/(S+B)=%4.3f", n->GetPurity() );
270     t->AddText(buffer);
271    
272     if (n->GetNodeType() == 0){
273     if (n->GetCutType()){
274     t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
275     }else{
276     t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
277     }
278     }
279    
280     t->Draw();
281    
282     return;
283     }
284     TMVA::DecisionTree* StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
285     {
286     cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << endl;
287     TMVA::DecisionTree *d = new TMVA::DecisionTree();
288     if(!fWfile.EndsWith(".xml") ){
289     ifstream fin( fWfile );
290     if (!fin.good( )) { // file not found --> Error
291     cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
292     return 0;
293     }
294    
295     TString dummy = "";
296    
297     if (itree >= fNtrees) {
298     cout << "*** ERROR: requested decision tree: " << itree
299     << ", but number of trained trees only: " << fNtrees << endl;
300     return 0;
301     }
302    
303     // file header with name
304     while (!dummy.Contains("#VAR")) fin >> dummy;
305     fin >> dummy >> dummy >> dummy; // the rest of header line
306    
307     // number of variables
308     Int_t nVars;
309     fin >> dummy >> nVars;
310    
311     // variable mins and maxes
312     vars = new TString[nVars+1]; // last one is if "fisher cut criterium"
313     for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
314     vars[nVars]="FisherCrit";
315    
316     char buffer[20];
317     char line[256];
318     sprintf(buffer,"Tree %d",itree);
319    
320     while (!dummy.Contains(buffer)) {
321     fin.getline(line,256);
322     dummy = TString(line);
323     }
324    
325     d->Read(fin);
326    
327     fin.close();
328     }
329     else{
330     if (itree >= fNtrees) {
331     cout << "*** ERROR: requested decision tree: " << itree
332     << ", but number of trained trees only: " << fNtrees << endl;
333     return 0;
334     }
335     Int_t nVars;
336     void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
337     void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
338     void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
339     while(ch){
340     TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
341     if(nodeName=="Variables"){
342     TMVA::gTools().ReadAttr( ch, "NVar", nVars);
343     vars = new TString[nVars+1];
344     void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
345     for (Int_t i = 0; i < nVars; i++){
346     TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
347     varnode = TMVA::gTools().xmlengine().GetNext(varnode);
348     }
349     vars[nVars]="FisherCrit";
350     }
351     if(nodeName=="Weights") break;
352     ch = TMVA::gTools().xmlengine().GetNext(ch);
353     }
354     ch = TMVA::gTools().xmlengine().GetChild(ch);
355     for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
356     d->ReadXML(ch);
357     }
358     return d;
359     }
360    
361     //_______________________________________________________________________
362     void StatDialogBDT::DrawTree( Int_t itree )
363     {
364     TString *vars;
365     TMVA::DecisionTree* d = ReadTree( vars, itree );
366     if (d == 0) return;
367    
368     UInt_t depth = d->GetTotalTreeDepth();
369     Double_t ystep = 1.0/(depth + 1.0);
370    
371     cout << "--- Tree depth: " << depth << endl;
372    
373     TStyle* TMVAStyle = gROOT->GetStyle("Plain"); // our style is based on Plain
374     Int_t canvasColor = TMVAStyle->GetCanvasColor(); // backup
375    
376     TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
377     TString tbuffer = Form( "Decision Tree no.: %d", itree );
378     if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 );
379     else fCanvas->Clear();
380     fCanvas->Draw();
381    
382     DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
383    
384     // make the legend
385     Double_t yup=0.99;
386     Double_t ydown=yup-ystep/2.5;
387     Double_t dy= ystep/2.5 * 0.2;
388    
389     TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
390     whichTree->SetBorderSize(1);
391     whichTree->SetFillStyle(1);
392     whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
393     whichTree->AddText( tbuffer );
394     whichTree->Draw();
395    
396     TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
397     intermediate->SetBorderSize(1);
398     intermediate->SetFillStyle(1);
399     intermediate->SetFillColor( kIntColorF );
400     intermediate->AddText("Intermediate Nodes");
401     intermediate->SetTextColor( kIntColorT );
402     intermediate->Draw();
403    
404     ydown = ydown - ystep/2.5 -dy;
405     yup = yup - ystep/2.5 -dy;
406     TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
407     signalleaf->SetBorderSize(1);
408     signalleaf->SetFillStyle(1);
409     signalleaf->SetFillColor( kSigColorF );
410     signalleaf->AddText("Signal Leaf Nodes");
411     signalleaf->SetTextColor( kSigColorT );
412     signalleaf->Draw();
413    
414     ydown = ydown - ystep/2.5 -dy;
415     yup = yup - ystep/2.5 -dy;
416     TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
417     backgroundleaf->SetBorderSize(1);
418     backgroundleaf->SetFillStyle(1);
419     backgroundleaf->SetFillColor( kBkgColorF );
420    
421     backgroundleaf->AddText("Backgr. Leaf Nodes");
422     backgroundleaf->SetTextColor( kBkgColorT );
423     backgroundleaf->Draw();
424    
425     fCanvas->Update();
426     TString fname = Form("plots/%s_%i", fMethName.Data(), itree );
427     cout << "--- Creating image: " << fname << endl;
428     TMVAGlob::imgconv( fCanvas, fname );
429    
430     TMVAStyle->SetCanvasColor( canvasColor );
431     }
432    
433     // ========================================================================================
434    
435     static std::vector<TControlBar*> BDT_Global__cbar;
436    
437     // intermediate GUI
438     void BDT( const TString& fin = "TMVA.root" )
439     {
440     // --- read the available BDT weight files
441    
442     // destroy all open cavases
443     TMVAGlob::DestroyCanvases();
444    
445     // checks if file with name "fin" is already open, and if not opens one
446     TFile* file = TMVAGlob::OpenFile( fin );
447    
448     TDirectory* dir = file->GetDirectory( "Method_BDT" );
449     if (!dir) {
450     cout << "*** Error in macro \"BDT.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
451     return;
452     }
453    
454     // read all directories
455     TIter next( dir->GetListOfKeys() );
456     TKey *key(0);
457     std::vector<TString> methname;
458     std::vector<TString> path;
459     std::vector<TString> wfile;
460     while ((key = (TKey*)next())) {
461     TDirectory* mdir = dir->GetDirectory( key->GetName() );
462     if (!mdir) {
463     cout << "*** Error in macro \"BDT.C\": cannot find sub-directory: " << key->GetName()
464     << " in directory: " << dir->GetName() << endl;
465     return;
466     }
467    
468     // retrieve weight file name and path
469     TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
470     TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
471     if (!strPath || !strWFile) {
472     cout << "*** Error in macro \"BDT.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
473     cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
474     return;
475     }
476    
477     methname.push_back( key->GetName() );
478     path .push_back( strPath->GetString() );
479     wfile .push_back( strWFile->GetString() );
480     }
481    
482     // create the control bar
483     TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
484     BDT_Global__cbar.push_back(cbar);
485    
486     for (UInt_t im=0; im<path.size(); im++) {
487     TString fname = path[im];
488     if (fname[fname.Length()-1] != '/') fname += "/";
489     fname += wfile[im];
490     TString macro = Form( ".x BDT.C+\(0,\"%s\",\"%s\")", fname.Data(), methname[im].Data() );
491     cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
492     }
493    
494     // *** problems with this button in ROOT 5.19 ***
495     #if ROOT_VERSION_CODE < ROOT_VERSION(5,19,0)
496     cbar->AddButton( "Close", Form("BDT_DeleteTBar(%i)", BDT_Global__cbar.size()-1), "Close this control bar", "button" );
497     #endif
498     // **********************************************
499    
500     // set the style
501     cbar->SetTextColor("blue");
502    
503     // draw
504     cbar->Show();
505     }
506    
507     void BDT_DeleteTBar(int i)
508     {
509     // destroy all open canvases
510     StatDialogBDT::Delete();
511     TMVAGlob::DestroyCanvases();
512    
513     delete BDT_Global__cbar[i];
514     BDT_Global__cbar[i] = 0;
515     }
516    
517     // input: - No. of tree
518     // - the weight file from which the tree is read
519     void BDT( Int_t itree, TString wfile = "weights/TMVAnalysis_test_BDT.weights.txt", TString methName = "BDT", Bool_t useTMVAStyle = kTRUE )
520     {
521     // destroy possibly existing dialog windows and/or canvases
522     StatDialogBDT::Delete();
523     TMVAGlob::DestroyCanvases();
524    
525     // quick check if weight file exist
526     if(!wfile.EndsWith(".xml") ){
527     ifstream fin( wfile );
528     if (!fin.good( )) { // file not found --> Error
529     cout << "*** ERROR: Weight file: " << wfile << " does not exist" << endl;
530     return;
531     }
532     }
533     std::cout << "test1";
534     // set style and remove existing canvas'
535     TMVAGlob::Initialize( useTMVAStyle );
536    
537     StatDialogBDT* gGui = new StatDialogBDT( gClient->GetRoot(), wfile, methName, itree );
538    
539     gGui->DrawTree( itree );
540    
541     gGui->RaiseDialog();
542     }
543