1 |
#include "TFile.h"
|
2 |
#include "TNtuple.h"
|
3 |
#include "TMath.h"
|
4 |
|
5 |
#include "TMVA/Factory.h"
|
6 |
#include "TMVA/Tools.h"
|
7 |
#include "TMVA/Config.h"
|
8 |
|
9 |
#include <unistd.h>
|
10 |
#include <iostream>
|
11 |
#include <fstream>
|
12 |
#include <string.h>
|
13 |
#include <map>
|
14 |
#include <cassert>
|
15 |
|
16 |
#include "KinematicsStruct.h"
|
17 |
#include "HistHeaders.h"
|
18 |
#include "filestuff.h"
|
19 |
|
20 |
#include "ZZMatrixElement/MELA/interface/Mela.h"
|
21 |
#include "ZZMatrixElement/MELA/interface/PseudoMELA.h"
|
22 |
#include "ZZMatrixElement/MELA/src/computeAngles.h"
|
23 |
|
24 |
using namespace std;
|
25 |
|
26 |
void setTrainingVal(TString name, double *val, filestuff *fs);
|
27 |
void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi);
|
28 |
void setNevents(int nEvtMax, int nSigTrainMax, int nSigTestMax, int nBkgTrainMax, int nBkgTestMax, TString &nSig, TString &nBkg);
|
29 |
|
30 |
int main(int argc, char** argv) {
|
31 |
|
32 |
if( argc < 7 ) {
|
33 |
cerr << "usage: makeBDTWeights.exe <config> <outputrootfile> <weightname> <varstring> <mH> [wL] [wH] [mH_lo] [mH_hi]" << endl;
|
34 |
cerr << "\t... where [wL,wH] are the requested mass window and [mH_lo,mH_hi] are used for interpolation" << endl;
|
35 |
return 1;
|
36 |
}
|
37 |
cout << "argc: " << argc << endl;
|
38 |
for(int iarg=0; iarg<argc; iarg++)
|
39 |
cout << "argv[" << iarg << "]: " << argv[iarg] << endl;
|
40 |
|
41 |
char * config = argv[1];
|
42 |
cout << "config: " << config << endl;
|
43 |
|
44 |
// TString ntupledir;
|
45 |
vector<TString> classes,fileNames,types,indirs,dsets,wgtStrs;
|
46 |
vector<double> fileFracs,fileMasses;
|
47 |
vector<TFile*> inFiles;
|
48 |
vector<TTree*> inTrees;
|
49 |
vector<filestuff*> fsv;
|
50 |
bool multiclass(false);
|
51 |
bool multisigs(false);
|
52 |
bool noFakeTest(false);
|
53 |
bool useMela(false);
|
54 |
TString NTrees(""),NNodesMax(""),Shrinkage(""),nCuts(""),NormMode("");
|
55 |
int nMax=1000000;
|
56 |
ifstream ifs(config);
|
57 |
assert(ifs.is_open());
|
58 |
string line;
|
59 |
while(getline(ifs,line)) {
|
60 |
if(line[0]=='#') continue;
|
61 |
stringstream ss(line);
|
62 |
if(line[0]=='^') {
|
63 |
TString dummy;
|
64 |
// if(TString(line).Contains("^ntupledir")) ss >> dummy >> ntupledir;
|
65 |
if(TString(line).Contains("^multiclass")){ multiclass = true; cout << "multiclass" << multiclass << endl; }
|
66 |
if(TString(line).Contains("^multisigs" )){ multisigs = true; cout << "multisigs" << multisigs << endl; }
|
67 |
if(TString(line).Contains("^noFakeTest")){ noFakeTest = true; cout << "noFakeTest" << noFakeTest << endl; }
|
68 |
if(TString(line).Contains("^NTrees" )){ ss >> dummy >> NTrees; cout << "NTrees" << NTrees << endl; }
|
69 |
if(TString(line).Contains("^NNodesMax" )){ ss >> dummy >> NNodesMax; cout << "NNodesMax" << NNodesMax << endl; }
|
70 |
if(TString(line).Contains("^Shrinkage" )){ ss >> dummy >> Shrinkage; cout << "Shrinkage" << Shrinkage << endl; }
|
71 |
if(TString(line).Contains("^nCuts" )){ ss >> dummy >> nCuts; cout << "nCuts" << nCuts << endl; }
|
72 |
if(TString(line).Contains("^NormMode" )){ ss >> dummy >> NormMode; cout << "NormMode" << NormMode << endl; }
|
73 |
if(TString(line).Contains("^nMax" )){ ss >> dummy >> nMax; cout << "nMax" << nMax << endl; }
|
74 |
if(TString(line).Contains("^useMela" )){ useMela = true; cout << "useMela" << useMela << endl; }
|
75 |
continue;
|
76 |
}
|
77 |
|
78 |
TString cls,indir,type,dset,wgtStr;
|
79 |
double frac,fileMass;
|
80 |
bool isdata;
|
81 |
ss >> cls >> type >> frac >> wgtStr >> indir >> dset >> fileMass >> isdata;
|
82 |
// assert(cls=="sig" || cls=="sig2" || cls=="sig3" || cls=="bkg");
|
83 |
classes.push_back(cls);
|
84 |
types.push_back(type);
|
85 |
indirs.push_back(indir);
|
86 |
dsets.push_back(dset);
|
87 |
TString fname(indir+"/"+dset+"/merged.root");
|
88 |
fsv.push_back(new filestuff(dset, fname, dset, isdata, 2012));
|
89 |
fileNames.push_back(/*ntupledir+"/"+*/fname);
|
90 |
fileFracs.push_back(frac);
|
91 |
wgtStrs.push_back(wgtStr);
|
92 |
fileMasses.push_back(fileMass);
|
93 |
cout << " pushing back: " << cls << " " << setw(12) << frac << " " << fname << " " << fileMass << endl;
|
94 |
inFiles.push_back(TFile::Open(fileNames.back()));
|
95 |
if(!inFiles.back()->IsOpen()) {
|
96 |
cout << fileNames.back() << " not open!" << endl;
|
97 |
assert(0);
|
98 |
}
|
99 |
inTrees.push_back((TTree*)inFiles.back()->Get("zznt"));
|
100 |
assert(inTrees.back());
|
101 |
}
|
102 |
ifs.close();
|
103 |
|
104 |
char * outputrootfile = argv[2];
|
105 |
char * weightname = argv[3];
|
106 |
char * varstring = argv[4];
|
107 |
float mH = strtof(argv[5],NULL);
|
108 |
float wH=900, wL=100;
|
109 |
if( argc > 6 ) {
|
110 |
wL = strtof(argv[6], NULL);
|
111 |
wH = strtof(argv[7], NULL);
|
112 |
}
|
113 |
float mH_lo=0,mH_hi=0;
|
114 |
float frac_lo=1,frac_hi=1;
|
115 |
map<double,double> mH_fracs;
|
116 |
mH_fracs[-1] = 1; // this means zz
|
117 |
assert( argc > 8);
|
118 |
mH_lo = strtof(argv[8],NULL);
|
119 |
mH_hi = strtof(argv[9],NULL);
|
120 |
if(mH_lo!=0 && mH_hi!=0) {
|
121 |
// set the fraction to use for each input file, with more events from the closer mass point
|
122 |
// NOTE: we only use these fractions if the fraction from the config file is -1
|
123 |
setFractions(mH,mH_lo,mH_hi,frac_lo,frac_hi);
|
124 |
mH_fracs[mH_lo] = frac_lo;
|
125 |
mH_fracs[mH_hi] = frac_hi;
|
126 |
} else {
|
127 |
mH_fracs[mH] = 1;
|
128 |
}
|
129 |
|
130 |
cerr
|
131 |
<< "name: " << weightname << endl
|
132 |
<< "output: " << outputrootfile << endl
|
133 |
<< "variables: " << varstring << endl
|
134 |
<< "mH: " << mH << endl
|
135 |
<< "mass window:" << endl
|
136 |
<< " wL: " << wL << endl
|
137 |
<< " wH: " << wH << endl
|
138 |
<< "masses from which to interpolate:" << endl
|
139 |
<< " mH_lo: " << mH_lo << endl
|
140 |
<< " mH_hi: " << mH_hi << endl
|
141 |
<< "weights for interpolation:" << endl
|
142 |
<< " frac_lo: " << frac_lo << endl
|
143 |
<< " frac_hi: " << frac_hi << endl
|
144 |
<< endl;
|
145 |
|
146 |
// get the list of variables we're going to use
|
147 |
vector<double> varVals;
|
148 |
TString varStr(varstring);
|
149 |
varStr.ReplaceAll(":"," ");
|
150 |
vector<TString> vars,specs;
|
151 |
stringstream ss(varStr.Data());
|
152 |
while(!ss.eof()) {
|
153 |
TString var;
|
154 |
ss >> var;
|
155 |
vars.push_back(var);
|
156 |
cout << "VAR: " << var << endl;
|
157 |
varVals.push_back(double(0));
|
158 |
}
|
159 |
specs.push_back("m4l"); varVals.push_back(double(0));
|
160 |
specs.push_back("run"); varVals.push_back(double(0));
|
161 |
specs.push_back("lumi"); varVals.push_back(double(0));
|
162 |
specs.push_back("evt"); varVals.push_back(double(0));
|
163 |
assert(varVals.size() == (vars.size() + specs.size()));
|
164 |
|
165 |
TFile *outputFile = new TFile(outputrootfile, "RECREATE");
|
166 |
TMVA::Tools::Instance();
|
167 |
TString facStr("!V:!Silent:!DrawProgressBar");
|
168 |
if(multiclass) facStr = facStr+":AnalysisType=multiclass";
|
169 |
TMVA::Factory* factory = new TMVA::Factory(weightname, outputFile, facStr);
|
170 |
for(unsigned ivar=0; ivar<vars.size(); ivar++) factory->AddVariable( vars[ivar], 'F');
|
171 |
for(unsigned ispec=0; ispec<specs.size(); ispec++) factory->AddSpectator(specs[ispec], 'F');
|
172 |
|
173 |
char buf[256];
|
174 |
sprintf( buf, "(m4l>%f && m4l<%f)", (float)wL, (float)wH );
|
175 |
TString cutStr(buf);
|
176 |
cout << "cutStr: " << cutStr << endl;
|
177 |
|
178 |
bool useWgts(true);
|
179 |
vector<vector<TString> > files;
|
180 |
vector<double> passFracs,nPassCutv,baseWgts;
|
181 |
cout << setw(13) << "class" << setw(6) << "type" << setw(12) << " nPass " << setw(8) << "nTot" << setw(10) << " frac" << setw(12) << "baseWgt" << endl;
|
182 |
for(unsigned ifile=0; ifile<fileNames.size(); ifile++) {
|
183 |
if(fileFracs[ifile] == -1) fileFracs[ifile] = mH_fracs[fileMasses[ifile]];
|
184 |
double nPass(0),nTot(0),passFrac(0);
|
185 |
if(useWgts) {
|
186 |
TH1D hPass("hPass","",100,0,1000);
|
187 |
TH1D hTot("hTot","",100,0,1000);
|
188 |
TString wgtStr(wgtStrs[ifile].Contains("wInterf") ? "1" : wgtStrs[ifile]); // tree doesn't have a wInterf leaf
|
189 |
inTrees[ifile]->Draw("m4l>>hPass",wgtStr+"*"+cutStr);
|
190 |
inTrees[ifile]->Draw("m4l>>hTot",wgtStr);
|
191 |
passFrac = hPass.Integral() / hTot.Integral();
|
192 |
nPass = hPass.Integral();
|
193 |
nTot = hTot.Integral();
|
194 |
} else {
|
195 |
nPass = double(inTrees[ifile]->GetEntries(cutStr));
|
196 |
passFrac = nPass / inTrees[ifile]->GetEntries();
|
197 |
nTot = inTrees[ifile]->GetEntries();
|
198 |
}
|
199 |
passFracs.push_back(passFrac);
|
200 |
nPassCutv.push_back(nPass);
|
201 |
double baseWgt(fileFracs[ifile] / nPass);
|
202 |
// if(fsv[ifile]->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest)
|
203 |
// baseWgt /= 2; // if we're putting all the fakes into training, divide weights by two
|
204 |
baseWgts.push_back(baseWgt);
|
205 |
cout << setw(13) << classes[ifile] << "(" << setw(5) << types[ifile] << "): " << setw(10) << nPass << " / " << setw(12) << nTot
|
206 |
<< setw(6) << setprecision(5) << fileFracs[ifile] << setw(15) << baseWgt
|
207 |
<< setw(30) << dsets[ifile] << " " << indirs[ifile]
|
208 |
<< endl;
|
209 |
}
|
210 |
|
211 |
Mela *mela(0);
|
212 |
if(useMela) mela = new Mela(false, 8);
|
213 |
|
214 |
for(unsigned ifs=0; ifs<fsv.size(); ifs++) {
|
215 |
filestuff *fs = fsv[ifs];
|
216 |
cout << "starting: " << fs->dataset_ << endl;
|
217 |
double tot(0);
|
218 |
for(unsigned ientry=0; ientry<fs->getentries("zznt"); ientry++) {
|
219 |
fs->getentry(ientry,"kinematics","zznt");
|
220 |
if(fs->kine->m4l < wL || fs->kine->m4l > wH) continue;
|
221 |
fs->getentry(ientry,"","zznt");
|
222 |
if(ientry<100)
|
223 |
cout << fs->ji->nJets << endl;
|
224 |
|
225 |
for(unsigned ivar=0; ivar<vars.size(); ivar++)
|
226 |
setTrainingVal(vars[ivar], &varVals[ivar], fs);
|
227 |
for(unsigned ispec=vars.size(); ispec<varVals.size(); ispec++)
|
228 |
setTrainingVal(specs[ispec-vars.size()], &varVals[ispec], fs);
|
229 |
|
230 |
// train on even events
|
231 |
// if noFakeTest train on *all* the lljj events since we don't have many to start with
|
232 |
double wgt(baseWgts[ifs]);
|
233 |
if(wgtStrs[ifs] == "w") wgt *= fs->weights->w;
|
234 |
else if(wgtStrs[ifs] == "1") wgt *= 1;
|
235 |
else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs,mela);
|
236 |
else assert(0);
|
237 |
if(wgt<.000000001 || wgt>5) { cout << "bad wgt?" << wgt << endl;}
|
238 |
|
239 |
if((ientry%2)==0 || (fs->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest))
|
240 |
factory->AddEvent(classes[ifs], TMVA::Types::kTraining, varVals, wgt);
|
241 |
if((ientry%2)!=0)
|
242 |
factory->AddEvent(classes[ifs], TMVA::Types::kTesting, varVals, wgt);
|
243 |
|
244 |
tot += wgt;
|
245 |
}
|
246 |
cout << " " << tot << endl;
|
247 |
}
|
248 |
|
249 |
if(NTrees=="") NTrees = "10000";
|
250 |
if(NNodesMax=="") NNodesMax = "10";
|
251 |
if(Shrinkage=="") Shrinkage = "0.05";
|
252 |
if(nCuts=="") nCuts = "20";
|
253 |
if(NormMode=="") NormMode = "EqualNumEvents";
|
254 |
assert(NormMode=="NumEvents" || NormMode=="EqualNumEvents");
|
255 |
factory->PrepareTrainingAndTestTree( "", "SplitMode=Random:V:NormMode="+NormMode);
|
256 |
factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:V:NTrees="+NTrees
|
257 |
+":NNodesMax="+NNodesMax
|
258 |
+":BoostType=Grad:Shrinkage="+Shrinkage
|
259 |
+":UseBaggedGrad:GradBaggingFraction=0.50:nCuts="+nCuts
|
260 |
// +":nTrain_Signal=100:nTest_Signal=100:nTrain_Background=100:nTest_Background=100:nTrain_Signal2=100:nTest_Signal2=100"
|
261 |
+":IgnoreNegWeights");
|
262 |
|
263 |
factory->TrainAllMethods();
|
264 |
factory->TestAllMethods();
|
265 |
factory->EvaluateAllMethods();
|
266 |
|
267 |
std::cerr << "closing ... " << std::endl;
|
268 |
outputFile->Close();
|
269 |
}
|
270 |
//----------------------------------------------------------------------------------------
|
271 |
// Computes linear-interpolation fractions for the two mass points bracketing mH.
//
//   mH            target mass
//   mH_lo, mH_hi  the two mass points to interpolate between
//   frac_lo       (out) fraction for the mH_lo sample: (mH_hi - mH) / (mH_hi - mH_lo)
//   frac_hi       (out) fraction for the mH_hi sample: (mH - mH_lo) / (mH_hi - mH_lo)
//
// The closer mass point receives the larger fraction, and the two fractions sum to 1.
// When mH_lo == mH_hi there is only one mass point and nothing to interpolate; both
// fractions are set to 1 rather than dividing by zero and handing NaN to the caller.
void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi)
{
  const double delta = mH_hi - mH_lo;
  if(delta == 0) {
    frac_lo = 1;
    frac_hi = 1;
    return;
  }
  frac_lo = (mH_hi - mH)/delta;
  frac_hi = (mH - mH_lo)/delta;
}
|
277 |
//----------------------------------------------------------------------------------------
|
278 |
void setNevents(int nEvtMax, int nSigTrainMax, int nSigTestMax, int nBkgTrainMax, int nBkgTestMax, TString &nSig, TString &nBkg)
|
279 |
{
|
280 |
// we want nEvtMax, but if there's fewer than this in any of the trees, we reduce our expectations
|
281 |
if(fabs(nSigTrainMax-nSigTestMax)/double(nSigTrainMax) > 0.1) {
|
282 |
assert(0);
|
283 |
}
|
284 |
if(fabs(nBkgTrainMax-nBkgTestMax)/double(nBkgTrainMax) > 0.1) {
|
285 |
assert(0);
|
286 |
}
|
287 |
|
288 |
nEvtMax = min(nEvtMax,nSigTrainMax);
|
289 |
nEvtMax = min(nEvtMax,nSigTestMax);
|
290 |
nEvtMax = min(nEvtMax,nBkgTrainMax);
|
291 |
nEvtMax = min(nEvtMax,nBkgTestMax);
|
292 |
stringstream ss;
|
293 |
ss << nEvtMax;
|
294 |
ss >> nSig;
|
295 |
nBkg = nSig;
|
296 |
// ss << min(nSigTrainMax,nSigTestMax) << " " << min(nBkgTrainMax,nBkgTestMax);
|
297 |
// ss >> nSig >> nBkg;
|
298 |
}
|
299 |
//----------------------------------------------------------------------------------------
|
300 |
// void setTrainingVals(vector<TString> vars, vector<double> &varVals, vector<TString> specs, vector<double> &specVals, filestuff *fs)
|
301 |
// Writes the current-event value of the quantity called `name` into *val.
//
//   name  variable or spectator name as given on the command line
//   val   destination for the value
//   fs    per-file reader whose angles/kine/vk/info members hold the current event
//
// Each recognised name stores its value and returns at once. An unrecognised name is
// reported on stdout and trips an assert.
void setTrainingVal(TString name, double *val, filestuff *fs)
{
  // decay angles
  if(name=="costheta1")    { *val = fs->angles->costheta1;    return; }
  if(name=="costheta2")    { *val = fs->angles->costheta2;    return; }
  if(name=="costhetastar") { *val = fs->angles->costhetastar; return; }
  if(name=="Phi")          { *val = fs->angles->Phi;          return; }
  if(name=="Phi1")         { *val = fs->angles->Phi1;         return; }

  // Z masses
  if(name=="mZ1") { *val = fs->kine->mZ1; return; }
  if(name=="mZ2") { *val = fs->kine->mZ2; return; }

  // pt variables: the plain names ZZpt, Z1pt, Z2pt deliver the quantity divided by m4l,
  // and ZZdotZ1/ZZdotZ2 are divided by m4l times the corresponding Z mass
  if(name=="ZZpt")            { *val = fs->kine->ZZpt/fs->kine->m4l;                      return; }
  if(name=="ZZdotZ1")         { *val = fs->kine->ZZdotZ1/(fs->kine->m4l*fs->kine->mZ1);   return; }
  if(name=="ZZdotZ2")         { *val = fs->kine->ZZdotZ2/(fs->kine->m4l*fs->kine->mZ2);   return; }
  if(name=="ZZptCosDphiZ1pt") { *val = fs->kine->ZZptZ1ptCosDphi;                         return; }
  if(name=="ZZptCosDphiZ2pt") { *val = fs->kine->ZZptZ2ptCosDphi;                         return; }
  if(name=="Z1pt")            { *val = fs->kine->Z1pt/fs->kine->m4l;                      return; }
  if(name=="Z2pt")            { *val = fs->kine->Z2pt/fs->kine->m4l;                      return; }
  if(name=="ZZy")             { *val = fs->kine->ZZy;                                     return; }

  // jet variables, fetched by name through fs->vk
  if(name=="nJets" )       { *val = fs->vk->getval("nJets");       return; }
  if(name=="mjj" )         { *val = fs->vk->getval("mjj");         return; }
  if(name=="dEta" )        { *val = fs->vk->getval("dEta");        return; }
  if(name=="etaProd" )     { *val = fs->vk->getval("etaProd");     return; }
  if(name=="dphiJ1HiPtZ" ) { *val = fs->vk->getval("dphiJ1HiPtZ"); return; }
  if(name=="dphiJ1LoPtZ" ) { *val = fs->vk->getval("dphiJ1LoPtZ"); return; }
  if(name=="dEtaJ1HiPtZ" ) { *val = fs->vk->getval("dEtaJ1HiPtZ"); return; }
  if(name=="dEtaJ1LoPtZ" ) { *val = fs->vk->getval("dEtaJ1LoPtZ"); return; }
  if(name=="J1dotHiPtZ" )  { *val = fs->vk->getval("J1dotHiPtZ");  return; }
  if(name=="J1dotLoPtZ" )  { *val = fs->vk->getval("J1dotLoPtZ");  return; }
  if(name=="dphiJ2HiPtZ" ) { *val = fs->vk->getval("dphiJ2HiPtZ"); return; }
  if(name=="dphiJ2LoPtZ" ) { *val = fs->vk->getval("dphiJ2LoPtZ"); return; }
  if(name=="dEtaJ2HiPtZ" ) { *val = fs->vk->getval("dEtaJ2HiPtZ"); return; }
  if(name=="dEtaJ2LoPtZ" ) { *val = fs->vk->getval("dEtaJ2LoPtZ"); return; }
  if(name=="J2dotHiPtZ" )  { *val = fs->vk->getval("J2dotHiPtZ");  return; }
  if(name=="J2dotLoPtZ" )  { *val = fs->vk->getval("J2dotLoPtZ");  return; }

  // spectators
  if(name=="m4l")  { *val = fs->kine->m4l;  return; }
  if(name=="run")  { *val = fs->info->run;  return; }
  if(name=="lumi") { *val = fs->info->lumi; return; }
  if(name=="evt")  { *val = fs->info->evt;  return; }

  // nothing matched
  cout << name << " not found!" << endl;
  assert(0);
}
|