ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/LJMet/MultivariateAnalysis/macros/network.C
Revision: 1.1
Committed: Thu Nov 20 22:34:50 2008 UTC (16 years, 5 months ago) by kukartse
Content type: text/plain
Branch: MAIN
CVS Tags: V00-03-01, ZMorph_BASE_20100408, gak040610_morphing, V00-02-02, gak011410, gak010310, ejterm2010_25nov2009, V00-02-01, V00-02-00, gak112409, CMSSW_22X_branch_base, segala101609, V00-01-15, V00-01-14, V00-01-13, V00-01-12, V00-01-11, V00-01-10, gak031009, gak030509, gak022309, gak021209, gak040209, gak012809, V00-01-09, V00-01-08, V00-01-07, V00-01-06, V00-01-05, V00-01-04, V00-00-07, V00-00-06, V00-00-05, V00-00-04, V00-01-03, V00-00-02, V00-00-01, HEAD
Branch point for: ZMorph-V00-03-01, CMSSW_22X_branch
Log Message:
created /macros with TMVA scripts in it

File Contents

# User Rev Content
1 kukartse 1.1 #include "tmvaglob.C"
2    
3     // this macro prints out a neural network generated by MethodMLP graphically
4     // @author: Matt Jachowski, jachowski@stanford.edu
5    
6     // input: - Input file (result from TMVA),
7     // - use of TMVA plotting TStyle
8     void network( TString fin = "TMVA.root", Bool_t useTMVAStyle = kTRUE )
9     {
10     // set style and remove existing canvas'
11     TMVAGlob::Initialize( useTMVAStyle );
12    
13     // checks if file with name "fin" is already open, and if not opens one
14     TFile* file = TMVAGlob::OpenFile( fin );
15    
16     TKey * mkey = TMVAGlob::FindMethod("MLP"); //(TDirectory*)gDirectory->Get("Method_MLP");
17     if (mkey==0) {
18     cout << "Could not locate directory MLP in file " << fin << endl;
19     return;
20     }
21     TDirectory *dir = (TDirectory *)mkey->ReadObj();
22     dir->cd();
23     TList titles;
24     UInt_t ni = TMVAGlob::GetListOfTitles( dir, titles );
25     if (ni==0) {
26     cout << "No titles found for Method_MLP" << endl;
27     return;
28     }
29     TIter nextTitle(&titles);
30     TKey *titkey;
31     TDirectory *titDir;
32     while ((titkey = TMVAGlob::NextKey(nextTitle,"TDirectory"))) {
33     titDir = (TDirectory *)titkey->ReadObj();
34     cout << "Drawing title: " << titDir->GetName() << endl;
35     draw_network(titDir);
36     }
37     }
38    
39     void draw_network(TDirectory* d)
40     {
41     Bool_t __PRINT_LOGO__ = kTRUE;
42    
43     // create canvas
44     TStyle *TMVAStyle = gROOT->GetStyle("Plain"); // our style is based on Plain
45     TMVAStyle->SetCanvasColor(37 + 100);
46    
47     TCanvas* c = new TCanvas( "c", "Neural Network Layout", 100, 0, 1000, 650 );
48    
49     TIter next = d->GetListOfKeys();
50     TKey *key;
51     TString hName = "weights_hist";
52     Int_t numHists = 0;
53    
54     // loop over all histograms with hName in name
55     while (key = (TKey*)next()) {
56     TClass *cl = gROOT->GetClass(key->GetClassName());
57     if (!cl->InheritsFrom("TH2F")) continue;
58     TH2F *h = (TH2F*)key->ReadObj();
59     if (TString(h->GetName()).Contains( hName ))
60     numHists++;
61     }
62    
63     // loop over all histograms with hName in name again
64     next.Reset();
65     Double_t maxWeight = 0;
66    
67     // find max weight
68     while (key = (TKey*)next()) {
69    
70     //cout << "Title: " << key->GetTitle() << endl;
71     TClass *cl = gROOT->GetClass(key->GetClassName());
72     if (!cl->InheritsFrom("TH2F")) continue;
73    
74     TH2F* h = (TH2F*)key->ReadObj();
75     if (TString(h->GetName()).Contains( hName )){
76    
77     Int_t n1 = h->GetNbinsX();
78     Int_t n2 = h->GetNbinsY();
79     for (Int_t i = 0; i < n1; i++) {
80     for (Int_t j = 0; j < n2; j++) {
81     Double_t weight = TMath::Abs(h->GetBinContent(i+1, j+1));
82     if (maxWeight < weight) maxWeight = weight;
83     }
84     }
85     }
86     }
87    
88     // draw network
89     next.Reset();
90     Int_t count = 0;
91     while (key = (TKey*)next()) {
92    
93     TClass *cl = gROOT->GetClass(key->GetClassName());
94     if (!cl->InheritsFrom("TH2F")) continue;
95    
96     TH2F* h = (TH2F*)key->ReadObj();
97     if (TString(h->GetName()).Contains( hName )){
98     draw_layer(c, h, count++, numHists+1, maxWeight);
99     }
100     }
101    
102     draw_layer_labels(numHists+1);
103    
104     // ============================================================
105     if (__PRINT_LOGO__) TMVAGlob::plot_logo();
106     // ============================================================
107    
108     c->Update();
109    
110     TString fname = "plots/network";
111     TMVAGlob::imgconv( c, fname );
112     }
113    
114     void draw_layer_labels(Int_t nLayers)
115     {
116     const Double_t LABEL_HEIGHT = 0.03;
117     const Double_t LABEL_WIDTH = 0.20;
118     Double_t effWidth = 0.8*(1.0-LABEL_WIDTH)/nLayers;
119     Double_t height = 0.8*LABEL_HEIGHT;
120     Double_t margY = LABEL_HEIGHT - height;
121    
122     for (Int_t i = 0; i < nLayers; i++) {
123     TString label = Form("Layer %i", i);
124     Double_t cx = i*(1.0-LABEL_WIDTH)/nLayers+1.0/(2.0*nLayers)+LABEL_WIDTH;
125     Double_t x1 = cx-0.8*effWidth/2.0;
126     Double_t x2 = cx+0.8*effWidth/2.0;
127     Double_t y1 = margY;
128     Double_t y2 = margY + height;
129    
130     TPaveLabel *p = new TPaveLabel(x1, y1, x2, y2, label+"", "br");
131     p->SetFillColor(gStyle->GetTitleFillColor());
132     p->SetFillStyle(1001);
133     p->Draw();
134     }
135     }
136    
137     void draw_input_labels(Int_t nInputs, Double_t* cy,
138     Double_t rad, Double_t layerWidth)
139     {
140     const Double_t LABEL_HEIGHT = 0.03;
141     const Double_t LABEL_WIDTH = 0.20;
142     Double_t width = LABEL_WIDTH + (layerWidth-4*rad);
143     Double_t margX = 0.01;
144     Double_t effHeight = 0.8*LABEL_HEIGHT;
145    
146     TString *varNames = get_var_names(nInputs);
147     TString input;
148    
149     for (Int_t i = 0; i < nInputs; i++) {
150     if (i != nInputs-1) input = varNames[i];
151     else input = "bias";
152     Double_t x1 = margX;
153     Double_t x2 = margX + width;
154     Double_t y1 = cy[i] - effHeight;
155     Double_t y2 = cy[i] + effHeight;
156    
157     TPaveLabel *p = new TPaveLabel(x1, y1, x2, y2, input+"", "br");
158     p->SetFillColor(gStyle->GetTitleFillColor());
159     p->SetFillStyle(1001);
160     p->Draw();
161     if (i == nInputs-1) p->SetTextColor(9);
162     }
163    
164     delete[] varNames;
165     }
166    
167     TString* get_var_names(Int_t nVars)
168     {
169     TString fname = "weights/MVAnalysis_MLP.weights.txt";
170     ifstream fin( fname );
171     if (!fin.good( )) { // file not found --> Error
172     cout << "Error opening " << fname << endl;
173     exit(1);
174     }
175    
176     Int_t idummy;
177     Float_t fdummy;
178     TString dummy = "";
179    
180     // file header with name
181     while (!dummy.Contains("#VAR")) fin >> dummy;
182     fin >> dummy >> dummy >> dummy; // the rest of header line
183    
184     // number of variables
185     fin >> dummy >> idummy;
186     // at this point, we should have idummy == nVars
187    
188     // variable mins and maxes
189     TString* vars = new TString[nVars];
190     for (Int_t i = 0; i < idummy; i++) fin >> vars[i] >> dummy >> dummy >> dummy;
191    
192     fin.close();
193    
194     return vars;
195     }
196    
197     void draw_activation(TCanvas* c, Double_t cx, Double_t cy,
198     Double_t radx, Double_t rady, Int_t whichActivation)
199     {
200     TImage *activation = NULL;
201    
202     switch (whichActivation) {
203     case 0:
204     activation = TImage::Open("sigmoid-small.png");
205     break;
206     case 1:
207     activation = TImage::Open("line-small.png");
208     break;
209     default:
210     cout << "Activation index " << whichActivation << " is not known." << endl;
211     cout << "You messed up or you need to modify network.C to introduce a new "
212     << "activation function (and image) corresponding to this index" << endl;
213     }
214    
215     if (activation == NULL) {
216     cout << "Could not create an image... exit" << endl;
217     return;
218     }
219    
220     activation->SetConstRatio(kFALSE);
221    
222     radx *= 0.7;
223     rady *= 0.7;
224     TString name = Form("activation%f%f", cx, cy);
225     TPad* p = new TPad(name+"", name+"", cx-radx, cy-rady, cx+radx, cy+rady);
226    
227     p->Draw();
228     p->cd();
229    
230     activation->Draw();
231     c->cd();
232     }
233    
234     void draw_layer(TCanvas* c, TH2F* h, Int_t iHist,
235     Int_t nLayers, Double_t maxWeight)
236     {
237     const Double_t MAX_NEURONS_NICE = 12;
238     const Double_t LABEL_HEIGHT = 0.03;
239     const Double_t LABEL_WIDTH = 0.20;
240     Double_t ratio = ((Double_t)(c->GetWindowHeight())) / c->GetWindowWidth();
241     Double_t rad, cx1, *cy1, cx2, *cy2;
242    
243     // this is the smallest radius that will still display the activation images
244     rad = 0.04*650/c->GetWindowHeight();
245    
246     Int_t nNeurons1 = h->GetNbinsX();
247     cx1 = iHist*(1.0-LABEL_WIDTH)/nLayers + 1.0/(2.0*nLayers) + LABEL_WIDTH;
248     cy1 = new Double_t[nNeurons1];
249    
250     Int_t nNeurons2 = h->GetNbinsY();
251     cx2 = (iHist+1)*(1.0-LABEL_WIDTH)/nLayers + 1.0/(2.0*nLayers) + LABEL_WIDTH;
252     cy2 = new Double_t[nNeurons2];
253    
254     Double_t effRad1 = rad;
255     if (nNeurons1 > MAX_NEURONS_NICE)
256     effRad1 = 0.8*(1.0-LABEL_HEIGHT)/(2.0*nNeurons1);
257    
258    
259     for (Int_t i = 0; i < nNeurons1; i++) {
260     cy1[nNeurons1-i-1] = i*(1.0-LABEL_HEIGHT)/nNeurons1 +
261     1.0/(2.0*nNeurons1) + LABEL_HEIGHT;
262    
263     if (iHist == 0) {
264    
265     TEllipse *ellipse
266     = new TEllipse(cx1, cy1[nNeurons1-i-1],
267     effRad1*ratio, effRad1, 0, 360, 0);
268     ellipse->SetFillColor(19+150);
269     ellipse->SetFillStyle(1001);
270     ellipse->Draw();
271    
272     if (i == 0) ellipse->SetLineColor(9);
273    
274     if (nNeurons1 > MAX_NEURONS_NICE) continue;
275    
276     Int_t whichActivation = 0;
277     if (iHist==0 || iHist==nLayers-1 || i==0) whichActivation = 1;
278     draw_activation(c, cx1, cy1[nNeurons1-i-1],
279     rad*ratio, rad, whichActivation);
280     }
281     }
282    
283     if (iHist == 0) draw_input_labels(nNeurons1, cy1, rad, (1.0-LABEL_WIDTH)/nLayers);
284    
285     Double_t effRad2 = rad;
286     if (nNeurons2 > MAX_NEURONS_NICE)
287     effRad2 = 0.8*(1.0-LABEL_HEIGHT)/(2.0*nNeurons2);
288    
289     for (Int_t i = 0; i < nNeurons2; i++) {
290     cy2[nNeurons2-i-1] = i*(1.0-LABEL_HEIGHT)/nNeurons2 + 1.0/(2.0*nNeurons2) + LABEL_HEIGHT;
291    
292     TEllipse *ellipse =
293     new TEllipse(cx2, cy2[nNeurons2-i-1], effRad2*ratio, effRad2, 0, 360, 0);
294     ellipse->SetFillColor(19+150);
295     ellipse->SetFillStyle(1001);
296     ellipse->Draw();
297    
298     if (i == 0 && nNeurons2 > 1) ellipse->SetLineColor(9);
299    
300     if (nNeurons2 > MAX_NEURONS_NICE) continue;
301    
302     Int_t whichActivation = 0;
303     if (iHist+1==0 || iHist+1==nLayers-1 || i==0) whichActivation = 1;
304     draw_activation(c, cx2, cy2[nNeurons2-i-1], rad*ratio, rad, whichActivation);
305     }
306    
307     for (Int_t i = 0; i < nNeurons1; i++) {
308     for (Int_t j = 0; j < nNeurons2; j++) {
309     draw_synapse(cx1, cy1[i], cx2, cy2[j], effRad1*ratio, effRad2*ratio,
310     h->GetBinContent(i+1, j+1)/maxWeight);
311     }
312     }
313    
314     delete[] cy1;
315     delete[] cy2;
316     }
317    
318     void draw_synapse(Double_t cx1, Double_t cy1, Double_t cx2, Double_t cy2,
319     Double_t rad1, Double_t rad2, Double_t weightNormed)
320     {
321     const Double_t TIP_SIZE = 0.01;
322     const Double_t MAX_WEIGHT = 8;
323     const Double_t MAX_COLOR = 100; // red
324     const Double_t MIN_COLOR = 60; // blue
325    
326     if (weightNormed == 0) return;
327    
328     // gStyle->SetPalette(100, NULL);
329    
330     TArrow *arrow = new TArrow(cx1+rad1, cy1, cx2-rad2, cy2, TIP_SIZE, ">");
331     arrow->SetFillColor(1);
332     arrow->SetFillStyle(1001);
333     arrow->SetLineWidth((Int_t)(TMath::Abs(weightNormed)*MAX_WEIGHT+0.5));
334     arrow->SetLineColor((Int_t)((weightNormed+1.0)/2.0*(MAX_COLOR-MIN_COLOR)+MIN_COLOR+0.5));
335     arrow->Draw();
336     }