ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/MitHzz4l/Angles/src/trainAngles.cc
Revision: 1.6
Committed: Thu Feb 7 08:58:30 2013 UTC (12 years, 3 months ago) by dkralph
Content type: text/plain
Branch: MAIN
CVS Tags: HEAD
Changes since 1.5: +8 -3 lines
Log Message:
*** empty log message ***

File Contents

# Content
1 #include "TFile.h"
2 #include "TNtuple.h"
3 #include "TMath.h"
4
5 #include "TMVA/Factory.h"
6 #include "TMVA/Tools.h"
7 #include "TMVA/Config.h"
8
9 #include <unistd.h>
10 #include <iostream>
11 #include <fstream>
12 #include <string.h>
13 #include <map>
14 #include <cassert>
15
16 #include "KinematicsStruct.h"
17 #include "HistHeaders.h"
18 #include "filestuff.h"
19
20 #include "ZZMatrixElement/MELA/interface/Mela.h"
21 #include "ZZMatrixElement/MELA/interface/PseudoMELA.h"
22 #include "ZZMatrixElement/MELA/src/computeAngles.h"
23
24 using namespace std;
25
26 void setTrainingVal(TString name, double *val, filestuff *fs);
27 void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi);
28
29 int main(int argc, char** argv) {
30
31 if( argc < 7 ) {
32 cerr << "usage: makeBDTWeights.exe <config> <outputrootfile> <weightname> <varstring> <mH> [wL] [wH] [mH_lo] [mH_hi]" << endl;
33 cerr << "\t... where [wL,wH] are the requested mass window and [mH_lo,mH_hi] are used for interpolation" << endl;
34 return 1;
35 }
36 cout << "argc: " << argc << endl;
37 for(int iarg=0; iarg<argc; iarg++)
38 cout << "argv[" << iarg << "]: " << argv[iarg] << endl;
39
40 char * config = argv[1];
41 cout << "config: " << config << endl;
42
43 // TString ntupledir;
44 vector<TString> classes,fileNames,types,indirs,dsets,wgtStrs;
45 vector<double> fileFracs,fileMasses;
46 vector<TFile*> inFiles;
47 vector<TTree*> inTrees;
48 vector<filestuff*> fsv;
49 bool multiclass(false);
50 bool multisigs(false);
51 bool noFakeTest(false);
52 bool useMela(false);
53 TString NTrees(""),NNodesMax(""),Shrinkage(""),nCuts(""),NormMode("");
54 int nMax=1000000;
55 ifstream ifs(config);
56 assert(ifs.is_open());
57 string line;
58 while(getline(ifs,line)) {
59 if(line[0]=='#') continue;
60 stringstream ss(line);
61 if(line[0]=='^') {
62 TString dummy;
63 // if(TString(line).Contains("^ntupledir")) ss >> dummy >> ntupledir;
64 if(TString(line).Contains("^multiclass")){ multiclass = true; cout << "multiclass" << multiclass << endl; }
65 if(TString(line).Contains("^multisigs" )){ multisigs = true; cout << "multisigs" << multisigs << endl; }
66 if(TString(line).Contains("^noFakeTest")){ noFakeTest = true; cout << "noFakeTest" << noFakeTest << endl; }
67 if(TString(line).Contains("^NTrees" )){ ss >> dummy >> NTrees; cout << "NTrees" << NTrees << endl; }
68 if(TString(line).Contains("^NNodesMax" )){ ss >> dummy >> NNodesMax; cout << "NNodesMax" << NNodesMax << endl; }
69 if(TString(line).Contains("^Shrinkage" )){ ss >> dummy >> Shrinkage; cout << "Shrinkage" << Shrinkage << endl; }
70 if(TString(line).Contains("^nCuts" )){ ss >> dummy >> nCuts; cout << "nCuts" << nCuts << endl; }
71 if(TString(line).Contains("^NormMode" )){ ss >> dummy >> NormMode; cout << "NormMode" << NormMode << endl; }
72 if(TString(line).Contains("^nMax" )){ ss >> dummy >> nMax; cout << "nMax" << nMax << endl; }
73 if(TString(line).Contains("^useMela" )){ useMela = true; cout << "useMela" << useMela << endl; }
74 continue;
75 }
76
77 TString cls,indir,type,dset,wgtStr;
78 double frac,fileMass;
79 bool isdata;
80 ss >> cls >> type >> frac >> wgtStr >> indir >> dset >> fileMass >> isdata;
81 // assert(cls=="sig" || cls=="sig2" || cls=="sig3" || cls=="bkg");
82 classes.push_back(cls);
83 types.push_back(type);
84 indirs.push_back(indir);
85 dsets.push_back(dset);
86 TString fname(indir+"/"+dset+"/merged.root");
87 fsv.push_back(new filestuff(dset, fname, dset, isdata, 2012));
88 fileNames.push_back(/*ntupledir+"/"+*/fname);
89 fileFracs.push_back(frac);
90 wgtStrs.push_back(wgtStr);
91 fileMasses.push_back(fileMass);
92 cout << " pushing back: " << cls << " " << setw(12) << frac << " " << fname << " " << fileMass << endl;
93 inFiles.push_back(TFile::Open(fileNames.back()));
94 if(!inFiles.back()->IsOpen()) {
95 cout << fileNames.back() << " not open!" << endl;
96 assert(0);
97 }
98 inTrees.push_back((TTree*)inFiles.back()->Get("zznt"));
99 assert(inTrees.back());
100 }
101 ifs.close();
102
103 char * outputrootfile = argv[2];
104 char * weightname = argv[3];
105 char * varstring = argv[4];
106 float mH = strtof(argv[5],NULL);
107 float wH=900, wL=100;
108 if( argc > 6 ) {
109 wL = strtof(argv[6], NULL);
110 wH = strtof(argv[7], NULL);
111 }
112 float mH_lo=0,mH_hi=0;
113 float frac_lo=1,frac_hi=1;
114 map<double,double> mH_fracs;
115 mH_fracs[-1] = 1; // this means zz
116 assert( argc > 8);
117 mH_lo = strtof(argv[8],NULL);
118 mH_hi = strtof(argv[9],NULL);
119 if(mH_lo!=0 && mH_hi!=0) {
120 // set the fraction to use for each input file, with more events from the closer mass point
121 // NOTE: we only use these fractions if the fraction from the config file is -1
122 setFractions(mH,mH_lo,mH_hi,frac_lo,frac_hi);
123 mH_fracs[mH_lo] = frac_lo;
124 mH_fracs[mH_hi] = frac_hi;
125 } else {
126 mH_fracs[mH] = 1;
127 }
128
129 cerr
130 << "name: " << weightname << endl
131 << "output: " << outputrootfile << endl
132 << "variables: " << varstring << endl
133 << "mH: " << mH << endl
134 << "mass window:" << endl
135 << " wL: " << wL << endl
136 << " wH: " << wH << endl
137 << "masses from which to interpolate:" << endl
138 << " mH_lo: " << mH_lo << endl
139 << " mH_hi: " << mH_hi << endl
140 << "weights for interpolation:" << endl
141 << " frac_lo: " << frac_lo << endl
142 << " frac_hi: " << frac_hi << endl
143 << endl;
144
145 // get the list of variables we're going to use
146 vector<double> varVals;
147 TString varStr(varstring);
148 varStr.ReplaceAll(":"," ");
149 vector<TString> vars,specs;
150 stringstream ss(varStr.Data());
151 while(!ss.eof()) {
152 TString var;
153 ss >> var;
154 vars.push_back(var);
155 cout << "VAR: " << var << endl;
156 varVals.push_back(double(0));
157 }
158 specs.push_back("m4l"); varVals.push_back(double(0));
159 specs.push_back("run"); varVals.push_back(double(0));
160 specs.push_back("lumi"); varVals.push_back(double(0));
161 specs.push_back("evt"); varVals.push_back(double(0));
162 assert(varVals.size() == (vars.size() + specs.size()));
163
164 TFile *outputFile = new TFile(outputrootfile, "RECREATE");
165 TMVA::Tools::Instance();
166 TString facStr("!V:!Silent:!DrawProgressBar");
167 if(multiclass) facStr = facStr+":AnalysisType=multiclass";
168 TMVA::Factory* factory = new TMVA::Factory(weightname, outputFile, facStr);
169 for(unsigned ivar=0; ivar<vars.size(); ivar++) factory->AddVariable( vars[ivar], 'F');
170 for(unsigned ispec=0; ispec<specs.size(); ispec++) factory->AddSpectator(specs[ispec], 'F');
171
172 char buf[256];
173 sprintf( buf, "(m4l>%f && m4l<%f)", (float)wL, (float)wH );
174 TString cutStr(buf);
175 cout << "cutStr: " << cutStr << endl;
176
177 bool useWgts(true);
178 vector<vector<TString> > files;
179 vector<double> passFracs,nPassCutv,baseWgts;
180 cout << setw(13) << "class" << setw(6) << "type" << setw(12) << " nPass " << setw(8) << "nTot" << setw(10) << " frac" << setw(12) << "baseWgt" << endl;
181 for(unsigned ifile=0; ifile<fileNames.size(); ifile++) {
182 if(fileFracs[ifile] == -1) fileFracs[ifile] = mH_fracs[fileMasses[ifile]];
183 double nPass(0),nTot(0),passFrac(0);
184 if(useWgts) {
185 TH1D hPass("hPass","",100,0,1000);
186 TH1D hTot("hTot","",100,0,1000);
187 TString wgtStr(wgtStrs[ifile].Contains("wInterf") ? "1" : wgtStrs[ifile]); // tree doesn't have a wInterf leaf
188 inTrees[ifile]->Draw("m4l>>hPass",wgtStr+"*"+cutStr);
189 inTrees[ifile]->Draw("m4l>>hTot",wgtStr);
190 passFrac = hPass.Integral() / hTot.Integral();
191 nPass = hPass.Integral();
192 nTot = hTot.Integral();
193 } else {
194 nPass = double(inTrees[ifile]->GetEntries(cutStr));
195 passFrac = nPass / inTrees[ifile]->GetEntries();
196 nTot = inTrees[ifile]->GetEntries();
197 }
198 passFracs.push_back(passFrac);
199 nPassCutv.push_back(nPass);
200 double baseWgt(fileFracs[ifile] / nPass);
201 // if(fsv[ifile]->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest)
202 // baseWgt /= 2; // if we're putting all the fakes into training, divide weights by two
203 baseWgts.push_back(baseWgt);
204 cout << setw(13) << classes[ifile] << "(" << setw(5) << types[ifile] << "): " << setw(10) << nPass << " / " << setw(12) << nTot
205 << setw(6) << setprecision(5) << fileFracs[ifile] << setw(15) << baseWgt
206 << setw(30) << dsets[ifile] << " " << indirs[ifile]
207 << endl;
208 }
209
210 Mela *mela(0);
211 if(useMela) mela = new Mela(false, 8);
212
213 for(unsigned ifs=0; ifs<fsv.size(); ifs++) {
214 filestuff *fs = fsv[ifs];
215 cout << "starting: " << fs->dataset_ << endl;
216 double tot(0);
217 for(unsigned ientry=0; ientry<fs->getentries("zznt"); ientry++) {
218 fs->getentry(ientry,"kinematics","zznt");
219 if(fs->kine->m4l < wL || fs->kine->m4l > wH) continue;
220 fs->getentry(ientry,"","zznt");
221 // if(ientry<100)
222 // cout << fs->ji->nJets << endl; // wasn't filled correctly for some reason
223
224 for(unsigned ivar=0; ivar<vars.size(); ivar++)
225 setTrainingVal(vars[ivar], &varVals[ivar], fs);
226 for(unsigned ispec=vars.size(); ispec<varVals.size(); ispec++)
227 setTrainingVal(specs[ispec-vars.size()], &varVals[ispec], fs);
228
229 // train on even events
230 // if noFakeTest train on *all* the lljj events since we don't have many to start with
231 double wgt(baseWgts[ifs]);
232 if(wgtStrs[ifs] == "w") wgt *= fs->weights->w;
233 else if(wgtStrs[ifs] == "1") wgt *= 1;
234 // else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs, mela, fs->dataset_.Contains("-XXXdkr"), mH);
235 else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs, mela);
236 else assert(0);
237 if(wgt<.000000001 || wgt>5) { cout << "bad wgt?" << wgt << endl;}
238 // if(fs->dataset_.Contains("-dkr") // && ientry<100
239 // )
240 // // cout << "test wgt: " << getInterferenceWeight(fs, mela, fs->dataset_.Contains("-XXXdkr"), mH) << endl;
241 // // cout << "test wgt: " << getInterferenceWeight(fs, mela) << endl;
242
243 if((ientry%2)==0 || (fs->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest))
244 factory->AddEvent(classes[ifs], TMVA::Types::kTraining, varVals, wgt);
245 if((ientry%2)!=0)
246 factory->AddEvent(classes[ifs], TMVA::Types::kTesting, varVals, wgt);
247
248 tot += wgt;
249 }
250 cout << " " << tot << endl;
251 }
252
253 if(NTrees=="") NTrees = "10000";
254 if(NNodesMax=="") NNodesMax = "10";
255 if(Shrinkage=="") Shrinkage = "0.05";
256 if(nCuts=="") nCuts = "20";
257 if(NormMode=="") NormMode = "EqualNumEvents";
258 assert(NormMode=="NumEvents" || NormMode=="EqualNumEvents");
259 factory->PrepareTrainingAndTestTree( "", "SplitMode=Random:V:NormMode="+NormMode);
260 factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:V:NTrees="+NTrees
261 +":NNodesMax="+NNodesMax
262 +":BoostType=Grad:Shrinkage="+Shrinkage
263 +":UseBaggedGrad:GradBaggingFraction=0.50:nCuts="+nCuts
264 // +":nTrain_Signal=100:nTest_Signal=100:nTrain_Background=100:nTest_Background=100:nTrain_Signal2=100:nTest_Signal2=100"
265 +":IgnoreNegWeights");
266
267 factory->TrainAllMethods();
268 factory->TestAllMethods();
269 factory->EvaluateAllMethods();
270
271 std::cerr << "closing ... " << std::endl;
272 outputFile->Close();
273 }
274 //----------------------------------------------------------------------------------------
// Linear-interpolation weights of the two mass points bracketing mH.
//   frac_lo = (mH_hi - mH)/(mH_hi - mH_lo),  frac_hi = (mH - mH_lo)/(mH_hi - mH_lo)
// so frac_lo + frac_hi == 1, and the closer mass point gets the larger weight.
// For mH outside [mH_lo, mH_hi] this extrapolates, and one fraction becomes negative.
// If mH_lo == mH_hi there is nothing to interpolate between. Rather than dividing by zero
// (NaN/inf weights), both fractions are set to 1, i.e. full weight to the single mass point.
void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi)
{
  double delta = mH_hi - mH_lo;
  if(delta == 0) {
    frac_lo = 1;
    frac_hi = 1;
    return;
  }
  frac_lo = (mH_hi - mH)/delta;
  frac_hi = (mH - mH_lo)/delta;
}
281 //----------------------------------------------------------------------------------------
282 void setTrainingVal(TString name, double *val, filestuff *fs)
283 {
284 if(name=="costheta1") *val = fs->angles->costheta1;
285 else if(name=="costheta2") *val = fs->angles->costheta2;
286 else if(name=="costhetastar") *val = fs->angles->costhetastar;
287 else if(name=="Phi") *val = fs->angles->Phi;
288 else if(name=="Phi1") *val = fs->angles->Phi1;
289 else if(name=="mZ1") *val = fs->kine->mZ1;
290 else if(name=="mZ2") *val = fs->kine->mZ2;
291 // pt variables
292 else if(name=="ZZpt") *val = fs->kine->ZZpt/fs->kine->m4l;
293 else if(name=="ZZdotZ1") *val = fs->kine->ZZdotZ1/(fs->kine->m4l*fs->kine->mZ1);
294 else if(name=="ZZdotZ2") *val = fs->kine->ZZdotZ2/(fs->kine->m4l*fs->kine->mZ2);
295 else if(name=="ZZptCosDphiZ1pt") *val = fs->kine->ZZptZ1ptCosDphi;
296 else if(name=="ZZptCosDphiZ2pt") *val = fs->kine->ZZptZ2ptCosDphi;
297 else if(name=="Z1pt") *val = fs->kine->Z1pt/fs->kine->m4l;
298 else if(name=="Z2pt") *val = fs->kine->Z2pt/fs->kine->m4l;
299 else if(name=="ZZy") *val = fs->kine->ZZy;
300 // jet variables
301 else if(name=="nJets" ) *val = fs->vk->getval("nJets");
302 else if(name=="mjj" ) *val = fs->vk->getval("mjj");
303 else if(name=="dEta" ) *val = fs->vk->getval("dEta");
304 else if(name=="etaProd" ) *val = fs->vk->getval("etaProd");
305 else if(name=="dphiJ1HiPtZ" ) *val = fs->vk->getval("dphiJ1HiPtZ");
306 else if(name=="dphiJ1LoPtZ" ) *val = fs->vk->getval("dphiJ1LoPtZ");
307 else if(name=="dEtaJ1HiPtZ" ) *val = fs->vk->getval("dEtaJ1HiPtZ");
308 else if(name=="dEtaJ1LoPtZ" ) *val = fs->vk->getval("dEtaJ1LoPtZ");
309 else if(name=="J1dotHiPtZ" ) *val = fs->vk->getval("J1dotHiPtZ");
310 else if(name=="J1dotLoPtZ" ) *val = fs->vk->getval("J1dotLoPtZ");
311 else if(name=="dphiJ2HiPtZ" ) *val = fs->vk->getval("dphiJ2HiPtZ");
312 else if(name=="dphiJ2LoPtZ" ) *val = fs->vk->getval("dphiJ2LoPtZ");
313 else if(name=="dEtaJ2HiPtZ" ) *val = fs->vk->getval("dEtaJ2HiPtZ");
314 else if(name=="dEtaJ2LoPtZ" ) *val = fs->vk->getval("dEtaJ2LoPtZ");
315 else if(name=="J2dotHiPtZ" ) *val = fs->vk->getval("J2dotHiPtZ");
316 else if(name=="J2dotLoPtZ" ) *val = fs->vk->getval("J2dotLoPtZ");
317 // spectators
318 else if(name=="m4l") *val = fs->kine->m4l;
319 else if(name=="run") *val = fs->info->run;
320 else if(name=="lumi") *val = fs->info->lumi;
321 else if(name=="evt") *val = fs->info->evt;
322 else { cout << name << " not found!" << endl; assert(0); }
323 }