ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/MitHzz4l/Angles/src/trainAngles.cc
Revision: 1.4
Committed: Tue Jan 15 09:40:41 2013 UTC (12 years, 4 months ago) by dkralph
Content type: text/plain
Branch: MAIN
Changes since 1.3: +168 -242 lines
Log Message:
*** empty log message ***

File Contents

# Content
1 #include "TFile.h"
2 #include "TNtuple.h"
3 #include "TMath.h"
4
5 #include "TMVA/Factory.h"
6 #include "TMVA/Tools.h"
7 #include "TMVA/Config.h"
8
9 #include <unistd.h>
10 #include <iostream>
11 #include <fstream>
12 #include <string.h>
13 #include <map>
14 #include <cassert>
15
16 #include "KinematicsStruct.h"
17 #include "HistHeaders.h"
18 #include "filestuff.h"
19
20 #include "ZZMatrixElement/MELA/interface/Mela.h"
21 #include "ZZMatrixElement/MELA/interface/PseudoMELA.h"
22 #include "ZZMatrixElement/MELA/src/computeAngles.h"
23
24 using namespace std;
25
26 void setTrainingVal(TString name, double *val, filestuff *fs);
27 void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi);
28 void setNevents(int nEvtMax, int nSigTrainMax, int nSigTestMax, int nBkgTrainMax, int nBkgTestMax, TString &nSig, TString &nBkg);
29
30 int main(int argc, char** argv) {
31
32 if( argc < 7 ) {
33 cerr << "usage: makeBDTWeights.exe <config> <outputrootfile> <weightname> <varstring> <mH> [wL] [wH] [mH_lo] [mH_hi]" << endl;
34 cerr << "\t... where [wL,wH] are the requested mass window and [mH_lo,mH_hi] are used for interpolation" << endl;
35 return 1;
36 }
37 cout << "argc: " << argc << endl;
38 for(int iarg=0; iarg<argc; iarg++)
39 cout << "argv[" << iarg << "]: " << argv[iarg] << endl;
40
41 char * config = argv[1];
42 cout << "config: " << config << endl;
43
44 // TString ntupledir;
45 vector<TString> classes,fileNames,types,indirs,dsets,wgtStrs;
46 vector<double> fileFracs,fileMasses;
47 vector<TFile*> inFiles;
48 vector<TTree*> inTrees;
49 vector<filestuff*> fsv;
50 bool multiclass(false);
51 bool multisigs(false);
52 bool noFakeTest(false);
53 bool useMela(false);
54 TString NTrees(""),NNodesMax(""),Shrinkage(""),nCuts(""),NormMode("");
55 int nMax=1000000;
56 ifstream ifs(config);
57 assert(ifs.is_open());
58 string line;
59 while(getline(ifs,line)) {
60 if(line[0]=='#') continue;
61 stringstream ss(line);
62 if(line[0]=='^') {
63 TString dummy;
64 // if(TString(line).Contains("^ntupledir")) ss >> dummy >> ntupledir;
65 if(TString(line).Contains("^multiclass")){ multiclass = true; cout << "multiclass" << multiclass << endl; }
66 if(TString(line).Contains("^multisigs" )){ multisigs = true; cout << "multisigs" << multisigs << endl; }
67 if(TString(line).Contains("^noFakeTest")){ noFakeTest = true; cout << "noFakeTest" << noFakeTest << endl; }
68 if(TString(line).Contains("^NTrees" )){ ss >> dummy >> NTrees; cout << "NTrees" << NTrees << endl; }
69 if(TString(line).Contains("^NNodesMax" )){ ss >> dummy >> NNodesMax; cout << "NNodesMax" << NNodesMax << endl; }
70 if(TString(line).Contains("^Shrinkage" )){ ss >> dummy >> Shrinkage; cout << "Shrinkage" << Shrinkage << endl; }
71 if(TString(line).Contains("^nCuts" )){ ss >> dummy >> nCuts; cout << "nCuts" << nCuts << endl; }
72 if(TString(line).Contains("^NormMode" )){ ss >> dummy >> NormMode; cout << "NormMode" << NormMode << endl; }
73 if(TString(line).Contains("^nMax" )){ ss >> dummy >> nMax; cout << "nMax" << nMax << endl; }
74 if(TString(line).Contains("^useMela" )){ useMela = true; cout << "useMela" << useMela << endl; }
75 continue;
76 }
77
78 TString cls,indir,type,dset,wgtStr;
79 double frac,fileMass;
80 bool isdata;
81 ss >> cls >> type >> frac >> wgtStr >> indir >> dset >> fileMass >> isdata;
82 // assert(cls=="sig" || cls=="sig2" || cls=="sig3" || cls=="bkg");
83 classes.push_back(cls);
84 types.push_back(type);
85 indirs.push_back(indir);
86 dsets.push_back(dset);
87 TString fname(indir+"/"+dset+"/merged.root");
88 fsv.push_back(new filestuff(dset, fname, dset, isdata, 2012));
89 fileNames.push_back(/*ntupledir+"/"+*/fname);
90 fileFracs.push_back(frac);
91 wgtStrs.push_back(wgtStr);
92 fileMasses.push_back(fileMass);
93 cout << " pushing back: " << cls << " " << setw(12) << frac << " " << fname << " " << fileMass << endl;
94 inFiles.push_back(TFile::Open(fileNames.back()));
95 if(!inFiles.back()->IsOpen()) {
96 cout << fileNames.back() << " not open!" << endl;
97 assert(0);
98 }
99 inTrees.push_back((TTree*)inFiles.back()->Get("zznt"));
100 assert(inTrees.back());
101 }
102 ifs.close();
103
104 char * outputrootfile = argv[2];
105 char * weightname = argv[3];
106 char * varstring = argv[4];
107 float mH = strtof(argv[5],NULL);
108 float wH=900, wL=100;
109 if( argc > 6 ) {
110 wL = strtof(argv[6], NULL);
111 wH = strtof(argv[7], NULL);
112 }
113 float mH_lo=0,mH_hi=0;
114 float frac_lo=1,frac_hi=1;
115 map<double,double> mH_fracs;
116 mH_fracs[-1] = 1; // this means zz
117 assert( argc > 8);
118 mH_lo = strtof(argv[8],NULL);
119 mH_hi = strtof(argv[9],NULL);
120 if(mH_lo!=0 && mH_hi!=0) {
121 // set the fraction to use for each input file, with more events from the closer mass point
122 // NOTE: we only use these fractions if the fraction from the config file is -1
123 setFractions(mH,mH_lo,mH_hi,frac_lo,frac_hi);
124 mH_fracs[mH_lo] = frac_lo;
125 mH_fracs[mH_hi] = frac_hi;
126 } else {
127 mH_fracs[mH] = 1;
128 }
129
130 cerr
131 << "name: " << weightname << endl
132 << "output: " << outputrootfile << endl
133 << "variables: " << varstring << endl
134 << "mH: " << mH << endl
135 << "mass window:" << endl
136 << " wL: " << wL << endl
137 << " wH: " << wH << endl
138 << "masses from which to interpolate:" << endl
139 << " mH_lo: " << mH_lo << endl
140 << " mH_hi: " << mH_hi << endl
141 << "weights for interpolation:" << endl
142 << " frac_lo: " << frac_lo << endl
143 << " frac_hi: " << frac_hi << endl
144 << endl;
145
146 // get the list of variables we're going to use
147 vector<double> varVals;
148 TString varStr(varstring);
149 varStr.ReplaceAll(":"," ");
150 vector<TString> vars,specs;
151 stringstream ss(varStr.Data());
152 while(!ss.eof()) {
153 TString var;
154 ss >> var;
155 vars.push_back(var);
156 cout << "VAR: " << var << endl;
157 varVals.push_back(double(0));
158 }
159 specs.push_back("m4l"); varVals.push_back(double(0));
160 specs.push_back("run"); varVals.push_back(double(0));
161 specs.push_back("lumi"); varVals.push_back(double(0));
162 specs.push_back("evt"); varVals.push_back(double(0));
163 assert(varVals.size() == (vars.size() + specs.size()));
164
165 TFile *outputFile = new TFile(outputrootfile, "RECREATE");
166 TMVA::Tools::Instance();
167 TString facStr("!V:!Silent:!DrawProgressBar");
168 if(multiclass) facStr = facStr+":AnalysisType=multiclass";
169 TMVA::Factory* factory = new TMVA::Factory(weightname, outputFile, facStr);
170 for(unsigned ivar=0; ivar<vars.size(); ivar++) factory->AddVariable( vars[ivar], 'F');
171 for(unsigned ispec=0; ispec<specs.size(); ispec++) factory->AddSpectator(specs[ispec], 'F');
172
173 char buf[256];
174 sprintf( buf, "(m4l>%f && m4l<%f)", (float)wL, (float)wH );
175 TString cutStr(buf);
176 cout << "cutStr: " << cutStr << endl;
177
178 bool useWgts(true);
179 vector<vector<TString> > files;
180 vector<double> passFracs,nPassCutv,baseWgts;
181 cout << setw(13) << "class" << setw(6) << "type" << setw(12) << " nPass " << setw(8) << "nTot" << setw(10) << " frac" << setw(12) << "baseWgt" << endl;
182 for(unsigned ifile=0; ifile<fileNames.size(); ifile++) {
183 if(fileFracs[ifile] == -1) fileFracs[ifile] = mH_fracs[fileMasses[ifile]];
184 double nPass(0),nTot(0),passFrac(0);
185 if(useWgts) {
186 TH1D hPass("hPass","",100,0,1000);
187 TH1D hTot("hTot","",100,0,1000);
188 TString wgtStr(wgtStrs[ifile].Contains("wInterf") ? "1" : wgtStrs[ifile]); // tree doesn't have a wInterf leaf
189 inTrees[ifile]->Draw("m4l>>hPass",wgtStr+"*"+cutStr);
190 inTrees[ifile]->Draw("m4l>>hTot",wgtStr);
191 passFrac = hPass.Integral() / hTot.Integral();
192 nPass = hPass.Integral();
193 nTot = hTot.Integral();
194 } else {
195 nPass = double(inTrees[ifile]->GetEntries(cutStr));
196 passFrac = nPass / inTrees[ifile]->GetEntries();
197 nTot = inTrees[ifile]->GetEntries();
198 }
199 passFracs.push_back(passFrac);
200 nPassCutv.push_back(nPass);
201 double baseWgt(fileFracs[ifile] / nPass);
202 // if(fsv[ifile]->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest)
203 // baseWgt /= 2; // if we're putting all the fakes into training, divide weights by two
204 baseWgts.push_back(baseWgt);
205 cout << setw(13) << classes[ifile] << "(" << setw(5) << types[ifile] << "): " << setw(10) << nPass << " / " << setw(12) << nTot
206 << setw(6) << setprecision(5) << fileFracs[ifile] << setw(15) << baseWgt
207 << setw(30) << dsets[ifile] << " " << indirs[ifile]
208 << endl;
209 }
210
211 Mela *mela(0);
212 if(useMela) mela = new Mela(false, 8);
213
214 for(unsigned ifs=0; ifs<fsv.size(); ifs++) {
215 filestuff *fs = fsv[ifs];
216 cout << "starting: " << fs->dataset_ << endl;
217 double tot(0);
218 for(unsigned ientry=0; ientry<fs->getentries("zznt"); ientry++) {
219 fs->getentry(ientry,"kinematics","zznt");
220 if(fs->kine->m4l < wL || fs->kine->m4l > wH) continue;
221 fs->getentry(ientry,"","zznt");
222 if(ientry<100)
223 cout << fs->ji->nJets << endl;
224
225 for(unsigned ivar=0; ivar<vars.size(); ivar++)
226 setTrainingVal(vars[ivar], &varVals[ivar], fs);
227 for(unsigned ispec=vars.size(); ispec<varVals.size(); ispec++)
228 setTrainingVal(specs[ispec-vars.size()], &varVals[ispec], fs);
229
230 // train on even events
231 // if noFakeTest train on *all* the lljj events since we don't have many to start with
232 double wgt(baseWgts[ifs]);
233 if(wgtStrs[ifs] == "w") wgt *= fs->weights->w;
234 else if(wgtStrs[ifs] == "1") wgt *= 1;
235 else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs,mela);
236 else assert(0);
237 if(wgt<.000000001 || wgt>5) { cout << "bad wgt?" << wgt << endl;}
238
239 if((ientry%2)==0 || (fs->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest))
240 factory->AddEvent(classes[ifs], TMVA::Types::kTraining, varVals, wgt);
241 if((ientry%2)!=0)
242 factory->AddEvent(classes[ifs], TMVA::Types::kTesting, varVals, wgt);
243
244 tot += wgt;
245 }
246 cout << " " << tot << endl;
247 }
248
249 if(NTrees=="") NTrees = "10000";
250 if(NNodesMax=="") NNodesMax = "10";
251 if(Shrinkage=="") Shrinkage = "0.05";
252 if(nCuts=="") nCuts = "20";
253 if(NormMode=="") NormMode = "EqualNumEvents";
254 assert(NormMode=="NumEvents" || NormMode=="EqualNumEvents");
255 factory->PrepareTrainingAndTestTree( "", "SplitMode=Random:V:NormMode="+NormMode);
256 factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:V:NTrees="+NTrees
257 +":NNodesMax="+NNodesMax
258 +":BoostType=Grad:Shrinkage="+Shrinkage
259 +":UseBaggedGrad:GradBaggingFraction=0.50:nCuts="+nCuts
260 // +":nTrain_Signal=100:nTest_Signal=100:nTrain_Background=100:nTest_Background=100:nTrain_Signal2=100:nTest_Signal2=100"
261 +":IgnoreNegWeights");
262
263 factory->TrainAllMethods();
264 factory->TestAllMethods();
265 factory->EvaluateAllMethods();
266
267 std::cerr << "closing ... " << std::endl;
268 outputFile->Close();
269 }
270 //----------------------------------------------------------------------------------------
// Linear-interpolation weights for training at mass mH from samples generated at the
// two mass points mH_lo and mH_hi. The closer mass point gets the larger fraction, and
// frac_lo + frac_hi == 1 up to rounding. For mH outside [mH_lo, mH_hi] the formula
// extrapolates, so one of the two fractions becomes negative.
// When mH_lo == mH_hi there is nothing to interpolate, and the formula would divide by
// zero. In that case both fractions are set to 1, so the single mass point gets full
// weight.
void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi)
{
  const double delta = mH_hi - mH_lo;
  if(delta == 0) {
    frac_lo = 1;
    frac_hi = 1;
    return;
  }
  frac_lo = (mH_hi - mH)/delta;
  frac_hi = (mH - mH_lo)/delta;
}
277 //----------------------------------------------------------------------------------------
278 void setNevents(int nEvtMax, int nSigTrainMax, int nSigTestMax, int nBkgTrainMax, int nBkgTestMax, TString &nSig, TString &nBkg)
279 {
280 // we want nEvtMax, but if there's fewer than this in any of the trees, we reduce our expectations
281 if(fabs(nSigTrainMax-nSigTestMax)/double(nSigTrainMax) > 0.1) {
282 assert(0);
283 }
284 if(fabs(nBkgTrainMax-nBkgTestMax)/double(nBkgTrainMax) > 0.1) {
285 assert(0);
286 }
287
288 nEvtMax = min(nEvtMax,nSigTrainMax);
289 nEvtMax = min(nEvtMax,nSigTestMax);
290 nEvtMax = min(nEvtMax,nBkgTrainMax);
291 nEvtMax = min(nEvtMax,nBkgTestMax);
292 stringstream ss;
293 ss << nEvtMax;
294 ss >> nSig;
295 nBkg = nSig;
296 // ss << min(nSigTrainMax,nSigTestMax) << " " << min(nBkgTrainMax,nBkgTestMax);
297 // ss >> nSig >> nBkg;
298 }
299 //----------------------------------------------------------------------------------------
300 // void setTrainingVals(vector<TString> vars, vector<double> &varVals, vector<TString> specs, vector<double> &specVals, filestuff *fs)
301 void setTrainingVal(TString name, double *val, filestuff *fs)
302 {
303 if(name=="costheta1") *val = fs->angles->costheta1;
304 else if(name=="costheta2") *val = fs->angles->costheta2;
305 else if(name=="costhetastar") *val = fs->angles->costhetastar;
306 else if(name=="Phi") *val = fs->angles->Phi;
307 else if(name=="Phi1") *val = fs->angles->Phi1;
308 else if(name=="mZ1") *val = fs->kine->mZ1;
309 else if(name=="mZ2") *val = fs->kine->mZ2;
310 // pt variables
311 // else if(name=="ZZpt/m4l") *val = fs->kine->ZZpt/fs->kine->m4l;
312 // else if(name=="ZZdotZ1/(m4l*mZ1)") *val = fs->kine->ZZdotZ1/(fs->kine->m4l*fs->kine->mZ1);
313 // else if(name=="ZZdotZ2/(m4l*mZ2)") *val = fs->kine->ZZdotZ2/(fs->kine->m4l*fs->kine->mZ2);
314 else if(name=="ZZpt") *val = fs->kine->ZZpt/fs->kine->m4l;
315 else if(name=="ZZdotZ1") *val = fs->kine->ZZdotZ1/(fs->kine->m4l*fs->kine->mZ1);
316 else if(name=="ZZdotZ2") *val = fs->kine->ZZdotZ2/(fs->kine->m4l*fs->kine->mZ2);
317 else if(name=="ZZptCosDphiZ1pt") *val = fs->kine->ZZptZ1ptCosDphi;
318 else if(name=="ZZptCosDphiZ2pt") *val = fs->kine->ZZptZ2ptCosDphi;
319 // else if(name=="Z1pt/m4l") *val = fs->kine->Z1pt/fs->kine->m4l;
320 // else if(name=="Z2pt/m4l") *val = fs->kine->Z2pt/fs->kine->m4l;
321 else if(name=="Z1pt") *val = fs->kine->Z1pt/fs->kine->m4l;
322 else if(name=="Z2pt") *val = fs->kine->Z2pt/fs->kine->m4l;
323 else if(name=="ZZy") *val = fs->kine->ZZy;
324 // jet variables
325 else if(name=="nJets" ) *val = fs->vk->getval("nJets");
326 else if(name=="mjj" ) *val = fs->vk->getval("mjj");
327 else if(name=="dEta" ) *val = fs->vk->getval("dEta");
328 else if(name=="etaProd" ) *val = fs->vk->getval("etaProd");
329 else if(name=="dphiJ1HiPtZ" ) *val = fs->vk->getval("dphiJ1HiPtZ");
330 else if(name=="dphiJ1LoPtZ" ) *val = fs->vk->getval("dphiJ1LoPtZ");
331 else if(name=="dEtaJ1HiPtZ" ) *val = fs->vk->getval("dEtaJ1HiPtZ");
332 else if(name=="dEtaJ1LoPtZ" ) *val = fs->vk->getval("dEtaJ1LoPtZ");
333 else if(name=="J1dotHiPtZ" ) *val = fs->vk->getval("J1dotHiPtZ");
334 else if(name=="J1dotLoPtZ" ) *val = fs->vk->getval("J1dotLoPtZ");
335 else if(name=="dphiJ2HiPtZ" ) *val = fs->vk->getval("dphiJ2HiPtZ");
336 else if(name=="dphiJ2LoPtZ" ) *val = fs->vk->getval("dphiJ2LoPtZ");
337 else if(name=="dEtaJ2HiPtZ" ) *val = fs->vk->getval("dEtaJ2HiPtZ");
338 else if(name=="dEtaJ2LoPtZ" ) *val = fs->vk->getval("dEtaJ2LoPtZ");
339 else if(name=="J2dotHiPtZ" ) *val = fs->vk->getval("J2dotHiPtZ");
340 else if(name=="J2dotLoPtZ" ) *val = fs->vk->getval("J2dotLoPtZ");
341 // spectators
342 else if(name=="m4l") *val = fs->kine->m4l;
343 else if(name=="run") *val = fs->info->run;
344 else if(name=="lumi") *val = fs->info->lumi;
345 else if(name=="evt") *val = fs->info->evt;
346 else { cout << name << " not found!" << endl; assert(0); }
347 }