ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/MitHzz4l/Angles/src/trainAngles.cc
Revision: 1.6
Committed: Thu Feb 7 08:58:30 2013 UTC (12 years, 3 months ago) by dkralph
Content type: text/plain
Branch: MAIN
CVS Tags: HEAD
Changes since 1.5: +8 -3 lines
Log Message:
*** empty log message ***

File Contents

# User Rev Content
1 dkralph 1.1 #include "TFile.h"
2     #include "TNtuple.h"
3     #include "TMath.h"
4    
5     #include "TMVA/Factory.h"
6     #include "TMVA/Tools.h"
7     #include "TMVA/Config.h"
8    
9     #include <unistd.h>
10     #include <iostream>
11     #include <fstream>
12     #include <string.h>
13     #include <map>
14     #include <cassert>
15    
16 dkralph 1.2 #include "KinematicsStruct.h"
17     #include "HistHeaders.h"
18 dkralph 1.4 #include "filestuff.h"
19    
20     #include "ZZMatrixElement/MELA/interface/Mela.h"
21     #include "ZZMatrixElement/MELA/interface/PseudoMELA.h"
22     #include "ZZMatrixElement/MELA/src/computeAngles.h"
23 dkralph 1.2
24 dkralph 1.1 using namespace std;
25    
26 dkralph 1.4 void setTrainingVal(TString name, double *val, filestuff *fs);
27 dkralph 1.1 void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi);
28    
29     int main(int argc, char** argv) {
30    
31     if( argc < 7 ) {
32     cerr << "usage: makeBDTWeights.exe <config> <outputrootfile> <weightname> <varstring> <mH> [wL] [wH] [mH_lo] [mH_hi]" << endl;
33     cerr << "\t... where [wL,wH] are the requested mass window and [mH_lo,mH_hi] are used for interpolation" << endl;
34     return 1;
35     }
36     cout << "argc: " << argc << endl;
37     for(int iarg=0; iarg<argc; iarg++)
38     cout << "argv[" << iarg << "]: " << argv[iarg] << endl;
39    
40     char * config = argv[1];
41     cout << "config: " << config << endl;
42    
43     // TString ntupledir;
44 dkralph 1.4 vector<TString> classes,fileNames,types,indirs,dsets,wgtStrs;
45 dkralph 1.1 vector<double> fileFracs,fileMasses;
46     vector<TFile*> inFiles;
47     vector<TTree*> inTrees;
48 dkralph 1.4 vector<filestuff*> fsv;
49     bool multiclass(false);
50     bool multisigs(false);
51     bool noFakeTest(false);
52     bool useMela(false);
53 dkralph 1.3 TString NTrees(""),NNodesMax(""),Shrinkage(""),nCuts(""),NormMode("");
54 dkralph 1.2 int nMax=1000000;
55 dkralph 1.1 ifstream ifs(config);
56     assert(ifs.is_open());
57     string line;
58     while(getline(ifs,line)) {
59     if(line[0]=='#') continue;
60     stringstream ss(line);
61     if(line[0]=='^') {
62     TString dummy;
63     // if(TString(line).Contains("^ntupledir")) ss >> dummy >> ntupledir;
64 dkralph 1.4 if(TString(line).Contains("^multiclass")){ multiclass = true; cout << "multiclass" << multiclass << endl; }
65     if(TString(line).Contains("^multisigs" )){ multisigs = true; cout << "multisigs" << multisigs << endl; }
66     if(TString(line).Contains("^noFakeTest")){ noFakeTest = true; cout << "noFakeTest" << noFakeTest << endl; }
67     if(TString(line).Contains("^NTrees" )){ ss >> dummy >> NTrees; cout << "NTrees" << NTrees << endl; }
68     if(TString(line).Contains("^NNodesMax" )){ ss >> dummy >> NNodesMax; cout << "NNodesMax" << NNodesMax << endl; }
69     if(TString(line).Contains("^Shrinkage" )){ ss >> dummy >> Shrinkage; cout << "Shrinkage" << Shrinkage << endl; }
70     if(TString(line).Contains("^nCuts" )){ ss >> dummy >> nCuts; cout << "nCuts" << nCuts << endl; }
71     if(TString(line).Contains("^NormMode" )){ ss >> dummy >> NormMode; cout << "NormMode" << NormMode << endl; }
72     if(TString(line).Contains("^nMax" )){ ss >> dummy >> nMax; cout << "nMax" << nMax << endl; }
73     if(TString(line).Contains("^useMela" )){ useMela = true; cout << "useMela" << useMela << endl; }
74 dkralph 1.1 continue;
75     }
76    
77 dkralph 1.4 TString cls,indir,type,dset,wgtStr;
78 dkralph 1.1 double frac,fileMass;
79 dkralph 1.4 bool isdata;
80     ss >> cls >> type >> frac >> wgtStr >> indir >> dset >> fileMass >> isdata;
81 dkralph 1.3 // assert(cls=="sig" || cls=="sig2" || cls=="sig3" || cls=="bkg");
82 dkralph 1.1 classes.push_back(cls);
83 dkralph 1.3 types.push_back(type);
84 dkralph 1.4 indirs.push_back(indir);
85     dsets.push_back(dset);
86     TString fname(indir+"/"+dset+"/merged.root");
87     fsv.push_back(new filestuff(dset, fname, dset, isdata, 2012));
88 dkralph 1.1 fileNames.push_back(/*ntupledir+"/"+*/fname);
89     fileFracs.push_back(frac);
90 dkralph 1.4 wgtStrs.push_back(wgtStr);
91 dkralph 1.1 fileMasses.push_back(fileMass);
92     cout << " pushing back: " << cls << " " << setw(12) << frac << " " << fname << " " << fileMass << endl;
93     inFiles.push_back(TFile::Open(fileNames.back()));
94     if(!inFiles.back()->IsOpen()) {
95     cout << fileNames.back() << " not open!" << endl;
96     assert(0);
97     }
98     inTrees.push_back((TTree*)inFiles.back()->Get("zznt"));
99     assert(inTrees.back());
100     }
101     ifs.close();
102    
103     char * outputrootfile = argv[2];
104     char * weightname = argv[3];
105     char * varstring = argv[4];
106     float mH = strtof(argv[5],NULL);
107     float wH=900, wL=100;
108     if( argc > 6 ) {
109     wL = strtof(argv[6], NULL);
110     wH = strtof(argv[7], NULL);
111     }
112     float mH_lo=0,mH_hi=0;
113     float frac_lo=1,frac_hi=1;
114     map<double,double> mH_fracs;
115     mH_fracs[-1] = 1; // this means zz
116     assert( argc > 8);
117     mH_lo = strtof(argv[8],NULL);
118     mH_hi = strtof(argv[9],NULL);
119     if(mH_lo!=0 && mH_hi!=0) {
120     // set the fraction to use for each input file, with more events from the closer mass point
121     // NOTE: we only use these fractions if the fraction from the config file is -1
122     setFractions(mH,mH_lo,mH_hi,frac_lo,frac_hi);
123     mH_fracs[mH_lo] = frac_lo;
124     mH_fracs[mH_hi] = frac_hi;
125     } else {
126     mH_fracs[mH] = 1;
127     }
128    
129     cerr
130     << "name: " << weightname << endl
131     << "output: " << outputrootfile << endl
132     << "variables: " << varstring << endl
133     << "mH: " << mH << endl
134     << "mass window:" << endl
135     << " wL: " << wL << endl
136     << " wH: " << wH << endl
137     << "masses from which to interpolate:" << endl
138     << " mH_lo: " << mH_lo << endl
139     << " mH_hi: " << mH_hi << endl
140     << "weights for interpolation:" << endl
141     << " frac_lo: " << frac_lo << endl
142     << " frac_hi: " << frac_hi << endl
143     << endl;
144    
145 dkralph 1.4 // get the list of variables we're going to use
146     vector<double> varVals;
147     TString varStr(varstring);
148     varStr.ReplaceAll(":"," ");
149     vector<TString> vars,specs;
150     stringstream ss(varStr.Data());
151     while(!ss.eof()) {
152     TString var;
153     ss >> var;
154     vars.push_back(var);
155     cout << "VAR: " << var << endl;
156     varVals.push_back(double(0));
157 dkralph 1.1 }
158 dkralph 1.4 specs.push_back("m4l"); varVals.push_back(double(0));
159     specs.push_back("run"); varVals.push_back(double(0));
160     specs.push_back("lumi"); varVals.push_back(double(0));
161     specs.push_back("evt"); varVals.push_back(double(0));
162     assert(varVals.size() == (vars.size() + specs.size()));
163 dkralph 1.1
164 dkralph 1.4 TFile *outputFile = new TFile(outputrootfile, "RECREATE");
165 dkralph 1.1 TMVA::Tools::Instance();
166     TString facStr("!V:!Silent:!DrawProgressBar");
167     if(multiclass) facStr = facStr+":AnalysisType=multiclass";
168 dkralph 1.4 TMVA::Factory* factory = new TMVA::Factory(weightname, outputFile, facStr);
169     for(unsigned ivar=0; ivar<vars.size(); ivar++) factory->AddVariable( vars[ivar], 'F');
170     for(unsigned ispec=0; ispec<specs.size(); ispec++) factory->AddSpectator(specs[ispec], 'F');
171    
172 dkralph 1.1 char buf[256];
173 dkralph 1.4 sprintf( buf, "(m4l>%f && m4l<%f)", (float)wL, (float)wH );
174 dkralph 1.1 TString cutStr(buf);
175     cout << "cutStr: " << cutStr << endl;
176    
177 dkralph 1.4 bool useWgts(true);
178     vector<vector<TString> > files;
179     vector<double> passFracs,nPassCutv,baseWgts;
180     cout << setw(13) << "class" << setw(6) << "type" << setw(12) << " nPass " << setw(8) << "nTot" << setw(10) << " frac" << setw(12) << "baseWgt" << endl;
181 dkralph 1.1 for(unsigned ifile=0; ifile<fileNames.size(); ifile++) {
182 dkralph 1.4 if(fileFracs[ifile] == -1) fileFracs[ifile] = mH_fracs[fileMasses[ifile]];
183     double nPass(0),nTot(0),passFrac(0);
184     if(useWgts) {
185     TH1D hPass("hPass","",100,0,1000);
186     TH1D hTot("hTot","",100,0,1000);
187     TString wgtStr(wgtStrs[ifile].Contains("wInterf") ? "1" : wgtStrs[ifile]); // tree doesn't have a wInterf leaf
188     inTrees[ifile]->Draw("m4l>>hPass",wgtStr+"*"+cutStr);
189     inTrees[ifile]->Draw("m4l>>hTot",wgtStr);
190     passFrac = hPass.Integral() / hTot.Integral();
191     nPass = hPass.Integral();
192     nTot = hTot.Integral();
193     } else {
194     nPass = double(inTrees[ifile]->GetEntries(cutStr));
195     passFrac = nPass / inTrees[ifile]->GetEntries();
196     nTot = inTrees[ifile]->GetEntries();
197     }
198     passFracs.push_back(passFrac);
199     nPassCutv.push_back(nPass);
200     double baseWgt(fileFracs[ifile] / nPass);
201     // if(fsv[ifile]->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest)
202     // baseWgt /= 2; // if we're putting all the fakes into training, divide weights by two
203     baseWgts.push_back(baseWgt);
204     cout << setw(13) << classes[ifile] << "(" << setw(5) << types[ifile] << "): " << setw(10) << nPass << " / " << setw(12) << nTot
205     << setw(6) << setprecision(5) << fileFracs[ifile] << setw(15) << baseWgt
206     << setw(30) << dsets[ifile] << " " << indirs[ifile]
207     << endl;
208 dkralph 1.3 }
209    
210 dkralph 1.4 Mela *mela(0);
211     if(useMela) mela = new Mela(false, 8);
212 dkralph 1.3
213 dkralph 1.4 for(unsigned ifs=0; ifs<fsv.size(); ifs++) {
214     filestuff *fs = fsv[ifs];
215     cout << "starting: " << fs->dataset_ << endl;
216     double tot(0);
217     for(unsigned ientry=0; ientry<fs->getentries("zznt"); ientry++) {
218     fs->getentry(ientry,"kinematics","zznt");
219     if(fs->kine->m4l < wL || fs->kine->m4l > wH) continue;
220     fs->getentry(ientry,"","zznt");
221 dkralph 1.6 // if(ientry<100)
222     // cout << fs->ji->nJets << endl; // wasn't filled correctly for some reason
223 dkralph 1.4
224     for(unsigned ivar=0; ivar<vars.size(); ivar++)
225     setTrainingVal(vars[ivar], &varVals[ivar], fs);
226     for(unsigned ispec=vars.size(); ispec<varVals.size(); ispec++)
227     setTrainingVal(specs[ispec-vars.size()], &varVals[ispec], fs);
228    
229     // train on even events
230     // if noFakeTest train on *all* the lljj events since we don't have many to start with
231     double wgt(baseWgts[ifs]);
232     if(wgtStrs[ifs] == "w") wgt *= fs->weights->w;
233     else if(wgtStrs[ifs] == "1") wgt *= 1;
234 dkralph 1.6 // else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs, mela, fs->dataset_.Contains("-XXXdkr"), mH);
235     else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs, mela);
236 dkralph 1.4 else assert(0);
237     if(wgt<.000000001 || wgt>5) { cout << "bad wgt?" << wgt << endl;}
238 dkralph 1.6 // if(fs->dataset_.Contains("-dkr") // && ientry<100
239     // )
240     // // cout << "test wgt: " << getInterferenceWeight(fs, mela, fs->dataset_.Contains("-XXXdkr"), mH) << endl;
241     // // cout << "test wgt: " << getInterferenceWeight(fs, mela) << endl;
242 dkralph 1.4
243     if((ientry%2)==0 || (fs->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest))
244     factory->AddEvent(classes[ifs], TMVA::Types::kTraining, varVals, wgt);
245     if((ientry%2)!=0)
246     factory->AddEvent(classes[ifs], TMVA::Types::kTesting, varVals, wgt);
247 dkralph 1.3
248 dkralph 1.4 tot += wgt;
249     }
250     cout << " " << tot << endl;
251 dkralph 1.3 }
252 dkralph 1.1
253     if(NTrees=="") NTrees = "10000";
254     if(NNodesMax=="") NNodesMax = "10";
255     if(Shrinkage=="") Shrinkage = "0.05";
256     if(nCuts=="") nCuts = "20";
257 dkralph 1.4 if(NormMode=="") NormMode = "EqualNumEvents";
258     assert(NormMode=="NumEvents" || NormMode=="EqualNumEvents");
259 dkralph 1.3 factory->PrepareTrainingAndTestTree( "", "SplitMode=Random:V:NormMode="+NormMode);
260 dkralph 1.1 factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:V:NTrees="+NTrees
261     +":NNodesMax="+NNodesMax
262     +":BoostType=Grad:Shrinkage="+Shrinkage
263     +":UseBaggedGrad:GradBaggingFraction=0.50:nCuts="+nCuts
264 dkralph 1.3 // +":nTrain_Signal=100:nTest_Signal=100:nTrain_Background=100:nTest_Background=100:nTrain_Signal2=100:nTest_Signal2=100"
265 dkralph 1.1 +":IgnoreNegWeights");
266    
267     factory->TrainAllMethods();
268     factory->TestAllMethods();
269     factory->EvaluateAllMethods();
270    
271     std::cerr << "closing ... " << std::endl;
272     outputFile->Close();
273     }
274     //----------------------------------------------------------------------------------------
// Linear-interpolation weights for a mass mH between two mass points:
//   frac_lo = (mH_hi - mH)/(mH_hi - mH_lo),  frac_hi = (mH - mH_lo)/(mH_hi - mH_lo)
// The closer point gets the larger weight, and the two sum to one. mH outside
// [mH_lo, mH_hi] is not clamped, so one fraction becomes negative.
// If mH_lo == mH_hi the interval is degenerate and the division would give inf/NaN.
// Both fractions are then set to 1, so the single mass point gets full weight.
void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi)
{
  double delta = mH_hi - mH_lo;
  if(delta == 0) {
    frac_lo = 1;
    frac_hi = 1;
    return;
  }
  frac_lo = (mH_hi - mH)/delta;
  frac_hi = (mH - mH_lo)/delta;
}
281     //----------------------------------------------------------------------------------------
282 dkralph 1.4 void setTrainingVal(TString name, double *val, filestuff *fs)
283     {
284     if(name=="costheta1") *val = fs->angles->costheta1;
285     else if(name=="costheta2") *val = fs->angles->costheta2;
286     else if(name=="costhetastar") *val = fs->angles->costhetastar;
287     else if(name=="Phi") *val = fs->angles->Phi;
288     else if(name=="Phi1") *val = fs->angles->Phi1;
289     else if(name=="mZ1") *val = fs->kine->mZ1;
290     else if(name=="mZ2") *val = fs->kine->mZ2;
291     // pt variables
292     else if(name=="ZZpt") *val = fs->kine->ZZpt/fs->kine->m4l;
293     else if(name=="ZZdotZ1") *val = fs->kine->ZZdotZ1/(fs->kine->m4l*fs->kine->mZ1);
294     else if(name=="ZZdotZ2") *val = fs->kine->ZZdotZ2/(fs->kine->m4l*fs->kine->mZ2);
295     else if(name=="ZZptCosDphiZ1pt") *val = fs->kine->ZZptZ1ptCosDphi;
296     else if(name=="ZZptCosDphiZ2pt") *val = fs->kine->ZZptZ2ptCosDphi;
297     else if(name=="Z1pt") *val = fs->kine->Z1pt/fs->kine->m4l;
298     else if(name=="Z2pt") *val = fs->kine->Z2pt/fs->kine->m4l;
299     else if(name=="ZZy") *val = fs->kine->ZZy;
300     // jet variables
301     else if(name=="nJets" ) *val = fs->vk->getval("nJets");
302     else if(name=="mjj" ) *val = fs->vk->getval("mjj");
303     else if(name=="dEta" ) *val = fs->vk->getval("dEta");
304     else if(name=="etaProd" ) *val = fs->vk->getval("etaProd");
305     else if(name=="dphiJ1HiPtZ" ) *val = fs->vk->getval("dphiJ1HiPtZ");
306     else if(name=="dphiJ1LoPtZ" ) *val = fs->vk->getval("dphiJ1LoPtZ");
307     else if(name=="dEtaJ1HiPtZ" ) *val = fs->vk->getval("dEtaJ1HiPtZ");
308     else if(name=="dEtaJ1LoPtZ" ) *val = fs->vk->getval("dEtaJ1LoPtZ");
309     else if(name=="J1dotHiPtZ" ) *val = fs->vk->getval("J1dotHiPtZ");
310     else if(name=="J1dotLoPtZ" ) *val = fs->vk->getval("J1dotLoPtZ");
311     else if(name=="dphiJ2HiPtZ" ) *val = fs->vk->getval("dphiJ2HiPtZ");
312     else if(name=="dphiJ2LoPtZ" ) *val = fs->vk->getval("dphiJ2LoPtZ");
313     else if(name=="dEtaJ2HiPtZ" ) *val = fs->vk->getval("dEtaJ2HiPtZ");
314     else if(name=="dEtaJ2LoPtZ" ) *val = fs->vk->getval("dEtaJ2LoPtZ");
315     else if(name=="J2dotHiPtZ" ) *val = fs->vk->getval("J2dotHiPtZ");
316     else if(name=="J2dotLoPtZ" ) *val = fs->vk->getval("J2dotLoPtZ");
317     // spectators
318     else if(name=="m4l") *val = fs->kine->m4l;
319     else if(name=="run") *val = fs->info->run;
320     else if(name=="lumi") *val = fs->info->lumi;
321     else if(name=="evt") *val = fs->info->evt;
322     else { cout << name << " not found!" << endl; assert(0); }
323     }