1 |
#include "TFile.h"
|
2 |
#include "TNtuple.h"
|
3 |
#include "TMath.h"
|
4 |
|
5 |
#include "TMVA/Factory.h"
|
6 |
#include "TMVA/Tools.h"
|
7 |
#include "TMVA/Config.h"
|
8 |
|
9 |
#include <unistd.h>
|
10 |
#include <iostream>
|
11 |
#include <fstream>
|
12 |
#include <string.h>
|
13 |
#include <map>
|
14 |
#include <cassert>
|
15 |
|
16 |
#include "KinematicsStruct.h"
|
17 |
#include "HistHeaders.h"
|
18 |
#include "filestuff.h"
|
19 |
|
20 |
#include "ZZMatrixElement/MELA/interface/Mela.h"
|
21 |
#include "ZZMatrixElement/MELA/interface/PseudoMELA.h"
|
22 |
#include "ZZMatrixElement/MELA/src/computeAngles.h"
|
23 |
|
24 |
using namespace std;
|
25 |
|
26 |
void setTrainingVal(TString name, double *val, filestuff *fs);
|
27 |
void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi);
|
28 |
|
29 |
int main(int argc, char** argv) {
|
30 |
|
31 |
if( argc < 7 ) {
|
32 |
cerr << "usage: makeBDTWeights.exe <config> <outputrootfile> <weightname> <varstring> <mH> [wL] [wH] [mH_lo] [mH_hi]" << endl;
|
33 |
cerr << "\t... where [wL,wH] are the requested mass window and [mH_lo,mH_hi] are used for interpolation" << endl;
|
34 |
return 1;
|
35 |
}
|
36 |
cout << "argc: " << argc << endl;
|
37 |
for(int iarg=0; iarg<argc; iarg++)
|
38 |
cout << "argv[" << iarg << "]: " << argv[iarg] << endl;
|
39 |
|
40 |
char * config = argv[1];
|
41 |
cout << "config: " << config << endl;
|
42 |
|
43 |
// TString ntupledir;
|
44 |
vector<TString> classes,fileNames,types,indirs,dsets,wgtStrs;
|
45 |
vector<double> fileFracs,fileMasses;
|
46 |
vector<TFile*> inFiles;
|
47 |
vector<TTree*> inTrees;
|
48 |
vector<filestuff*> fsv;
|
49 |
bool multiclass(false);
|
50 |
bool multisigs(false);
|
51 |
bool noFakeTest(false);
|
52 |
bool useMela(false);
|
53 |
TString NTrees(""),NNodesMax(""),Shrinkage(""),nCuts(""),NormMode("");
|
54 |
int nMax=1000000;
|
55 |
ifstream ifs(config);
|
56 |
assert(ifs.is_open());
|
57 |
string line;
|
58 |
while(getline(ifs,line)) {
|
59 |
if(line[0]=='#') continue;
|
60 |
stringstream ss(line);
|
61 |
if(line[0]=='^') {
|
62 |
TString dummy;
|
63 |
// if(TString(line).Contains("^ntupledir")) ss >> dummy >> ntupledir;
|
64 |
if(TString(line).Contains("^multiclass")){ multiclass = true; cout << "multiclass" << multiclass << endl; }
|
65 |
if(TString(line).Contains("^multisigs" )){ multisigs = true; cout << "multisigs" << multisigs << endl; }
|
66 |
if(TString(line).Contains("^noFakeTest")){ noFakeTest = true; cout << "noFakeTest" << noFakeTest << endl; }
|
67 |
if(TString(line).Contains("^NTrees" )){ ss >> dummy >> NTrees; cout << "NTrees" << NTrees << endl; }
|
68 |
if(TString(line).Contains("^NNodesMax" )){ ss >> dummy >> NNodesMax; cout << "NNodesMax" << NNodesMax << endl; }
|
69 |
if(TString(line).Contains("^Shrinkage" )){ ss >> dummy >> Shrinkage; cout << "Shrinkage" << Shrinkage << endl; }
|
70 |
if(TString(line).Contains("^nCuts" )){ ss >> dummy >> nCuts; cout << "nCuts" << nCuts << endl; }
|
71 |
if(TString(line).Contains("^NormMode" )){ ss >> dummy >> NormMode; cout << "NormMode" << NormMode << endl; }
|
72 |
if(TString(line).Contains("^nMax" )){ ss >> dummy >> nMax; cout << "nMax" << nMax << endl; }
|
73 |
if(TString(line).Contains("^useMela" )){ useMela = true; cout << "useMela" << useMela << endl; }
|
74 |
continue;
|
75 |
}
|
76 |
|
77 |
TString cls,indir,type,dset,wgtStr;
|
78 |
double frac,fileMass;
|
79 |
bool isdata;
|
80 |
ss >> cls >> type >> frac >> wgtStr >> indir >> dset >> fileMass >> isdata;
|
81 |
// assert(cls=="sig" || cls=="sig2" || cls=="sig3" || cls=="bkg");
|
82 |
classes.push_back(cls);
|
83 |
types.push_back(type);
|
84 |
indirs.push_back(indir);
|
85 |
dsets.push_back(dset);
|
86 |
TString fname(indir+"/"+dset+"/merged.root");
|
87 |
fsv.push_back(new filestuff(dset, fname, dset, isdata, 2012));
|
88 |
fileNames.push_back(/*ntupledir+"/"+*/fname);
|
89 |
fileFracs.push_back(frac);
|
90 |
wgtStrs.push_back(wgtStr);
|
91 |
fileMasses.push_back(fileMass);
|
92 |
cout << " pushing back: " << cls << " " << setw(12) << frac << " " << fname << " " << fileMass << endl;
|
93 |
inFiles.push_back(TFile::Open(fileNames.back()));
|
94 |
if(!inFiles.back()->IsOpen()) {
|
95 |
cout << fileNames.back() << " not open!" << endl;
|
96 |
assert(0);
|
97 |
}
|
98 |
inTrees.push_back((TTree*)inFiles.back()->Get("zznt"));
|
99 |
assert(inTrees.back());
|
100 |
}
|
101 |
ifs.close();
|
102 |
|
103 |
char * outputrootfile = argv[2];
|
104 |
char * weightname = argv[3];
|
105 |
char * varstring = argv[4];
|
106 |
float mH = strtof(argv[5],NULL);
|
107 |
float wH=900, wL=100;
|
108 |
if( argc > 6 ) {
|
109 |
wL = strtof(argv[6], NULL);
|
110 |
wH = strtof(argv[7], NULL);
|
111 |
}
|
112 |
float mH_lo=0,mH_hi=0;
|
113 |
float frac_lo=1,frac_hi=1;
|
114 |
map<double,double> mH_fracs;
|
115 |
mH_fracs[-1] = 1; // this means zz
|
116 |
assert( argc > 8);
|
117 |
mH_lo = strtof(argv[8],NULL);
|
118 |
mH_hi = strtof(argv[9],NULL);
|
119 |
if(mH_lo!=0 && mH_hi!=0) {
|
120 |
// set the fraction to use for each input file, with more events from the closer mass point
|
121 |
// NOTE: we only use these fractions if the fraction from the config file is -1
|
122 |
setFractions(mH,mH_lo,mH_hi,frac_lo,frac_hi);
|
123 |
mH_fracs[mH_lo] = frac_lo;
|
124 |
mH_fracs[mH_hi] = frac_hi;
|
125 |
} else {
|
126 |
mH_fracs[mH] = 1;
|
127 |
}
|
128 |
|
129 |
cerr
|
130 |
<< "name: " << weightname << endl
|
131 |
<< "output: " << outputrootfile << endl
|
132 |
<< "variables: " << varstring << endl
|
133 |
<< "mH: " << mH << endl
|
134 |
<< "mass window:" << endl
|
135 |
<< " wL: " << wL << endl
|
136 |
<< " wH: " << wH << endl
|
137 |
<< "masses from which to interpolate:" << endl
|
138 |
<< " mH_lo: " << mH_lo << endl
|
139 |
<< " mH_hi: " << mH_hi << endl
|
140 |
<< "weights for interpolation:" << endl
|
141 |
<< " frac_lo: " << frac_lo << endl
|
142 |
<< " frac_hi: " << frac_hi << endl
|
143 |
<< endl;
|
144 |
|
145 |
// get the list of variables we're going to use
|
146 |
vector<double> varVals;
|
147 |
TString varStr(varstring);
|
148 |
varStr.ReplaceAll(":"," ");
|
149 |
vector<TString> vars,specs;
|
150 |
stringstream ss(varStr.Data());
|
151 |
while(!ss.eof()) {
|
152 |
TString var;
|
153 |
ss >> var;
|
154 |
vars.push_back(var);
|
155 |
cout << "VAR: " << var << endl;
|
156 |
varVals.push_back(double(0));
|
157 |
}
|
158 |
specs.push_back("m4l"); varVals.push_back(double(0));
|
159 |
specs.push_back("run"); varVals.push_back(double(0));
|
160 |
specs.push_back("lumi"); varVals.push_back(double(0));
|
161 |
specs.push_back("evt"); varVals.push_back(double(0));
|
162 |
assert(varVals.size() == (vars.size() + specs.size()));
|
163 |
|
164 |
TFile *outputFile = new TFile(outputrootfile, "RECREATE");
|
165 |
TMVA::Tools::Instance();
|
166 |
TString facStr("!V:!Silent:!DrawProgressBar");
|
167 |
if(multiclass) facStr = facStr+":AnalysisType=multiclass";
|
168 |
TMVA::Factory* factory = new TMVA::Factory(weightname, outputFile, facStr);
|
169 |
for(unsigned ivar=0; ivar<vars.size(); ivar++) factory->AddVariable( vars[ivar], 'F');
|
170 |
for(unsigned ispec=0; ispec<specs.size(); ispec++) factory->AddSpectator(specs[ispec], 'F');
|
171 |
|
172 |
char buf[256];
|
173 |
sprintf( buf, "(m4l>%f && m4l<%f)", (float)wL, (float)wH );
|
174 |
TString cutStr(buf);
|
175 |
cout << "cutStr: " << cutStr << endl;
|
176 |
|
177 |
bool useWgts(true);
|
178 |
vector<vector<TString> > files;
|
179 |
vector<double> passFracs,nPassCutv,baseWgts;
|
180 |
cout << setw(13) << "class" << setw(6) << "type" << setw(12) << " nPass " << setw(8) << "nTot" << setw(10) << " frac" << setw(12) << "baseWgt" << endl;
|
181 |
for(unsigned ifile=0; ifile<fileNames.size(); ifile++) {
|
182 |
if(fileFracs[ifile] == -1) fileFracs[ifile] = mH_fracs[fileMasses[ifile]];
|
183 |
double nPass(0),nTot(0),passFrac(0);
|
184 |
if(useWgts) {
|
185 |
TH1D hPass("hPass","",100,0,1000);
|
186 |
TH1D hTot("hTot","",100,0,1000);
|
187 |
TString wgtStr(wgtStrs[ifile].Contains("wInterf") ? "1" : wgtStrs[ifile]); // tree doesn't have a wInterf leaf
|
188 |
inTrees[ifile]->Draw("m4l>>hPass",wgtStr+"*"+cutStr);
|
189 |
inTrees[ifile]->Draw("m4l>>hTot",wgtStr);
|
190 |
passFrac = hPass.Integral() / hTot.Integral();
|
191 |
nPass = hPass.Integral();
|
192 |
nTot = hTot.Integral();
|
193 |
} else {
|
194 |
nPass = double(inTrees[ifile]->GetEntries(cutStr));
|
195 |
passFrac = nPass / inTrees[ifile]->GetEntries();
|
196 |
nTot = inTrees[ifile]->GetEntries();
|
197 |
}
|
198 |
passFracs.push_back(passFrac);
|
199 |
nPassCutv.push_back(nPass);
|
200 |
double baseWgt(fileFracs[ifile] / nPass);
|
201 |
// if(fsv[ifile]->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest)
|
202 |
// baseWgt /= 2; // if we're putting all the fakes into training, divide weights by two
|
203 |
baseWgts.push_back(baseWgt);
|
204 |
cout << setw(13) << classes[ifile] << "(" << setw(5) << types[ifile] << "): " << setw(10) << nPass << " / " << setw(12) << nTot
|
205 |
<< setw(6) << setprecision(5) << fileFracs[ifile] << setw(15) << baseWgt
|
206 |
<< setw(30) << dsets[ifile] << " " << indirs[ifile]
|
207 |
<< endl;
|
208 |
}
|
209 |
|
210 |
Mela *mela(0);
|
211 |
if(useMela) mela = new Mela(false, 8);
|
212 |
|
213 |
for(unsigned ifs=0; ifs<fsv.size(); ifs++) {
|
214 |
filestuff *fs = fsv[ifs];
|
215 |
cout << "starting: " << fs->dataset_ << endl;
|
216 |
double tot(0);
|
217 |
for(unsigned ientry=0; ientry<fs->getentries("zznt"); ientry++) {
|
218 |
fs->getentry(ientry,"kinematics","zznt");
|
219 |
if(fs->kine->m4l < wL || fs->kine->m4l > wH) continue;
|
220 |
fs->getentry(ientry,"","zznt");
|
221 |
// if(ientry<100)
|
222 |
// cout << fs->ji->nJets << endl; // wasn't filled correctly for some reason
|
223 |
|
224 |
for(unsigned ivar=0; ivar<vars.size(); ivar++)
|
225 |
setTrainingVal(vars[ivar], &varVals[ivar], fs);
|
226 |
for(unsigned ispec=vars.size(); ispec<varVals.size(); ispec++)
|
227 |
setTrainingVal(specs[ispec-vars.size()], &varVals[ispec], fs);
|
228 |
|
229 |
// train on even events
|
230 |
// if noFakeTest train on *all* the lljj events since we don't have many to start with
|
231 |
double wgt(baseWgts[ifs]);
|
232 |
if(wgtStrs[ifs] == "w") wgt *= fs->weights->w;
|
233 |
else if(wgtStrs[ifs] == "1") wgt *= 1;
|
234 |
// else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs, mela, fs->dataset_.Contains("-XXXdkr"), mH);
|
235 |
else if(wgtStrs[ifs] == "wInterf") wgt *= getInterferenceWeight(fs, mela);
|
236 |
else assert(0);
|
237 |
if(wgt<.000000001 || wgt>5) { cout << "bad wgt?" << wgt << endl;}
|
238 |
// if(fs->dataset_.Contains("-dkr") // && ientry<100
|
239 |
// )
|
240 |
// // cout << "test wgt: " << getInterferenceWeight(fs, mela, fs->dataset_.Contains("-XXXdkr"), mH) << endl;
|
241 |
// // cout << "test wgt: " << getInterferenceWeight(fs, mela) << endl;
|
242 |
|
243 |
if((ientry%2)==0 || (fs->dataset_.Contains("fakes",TString::kIgnoreCase) && noFakeTest))
|
244 |
factory->AddEvent(classes[ifs], TMVA::Types::kTraining, varVals, wgt);
|
245 |
if((ientry%2)!=0)
|
246 |
factory->AddEvent(classes[ifs], TMVA::Types::kTesting, varVals, wgt);
|
247 |
|
248 |
tot += wgt;
|
249 |
}
|
250 |
cout << " " << tot << endl;
|
251 |
}
|
252 |
|
253 |
if(NTrees=="") NTrees = "10000";
|
254 |
if(NNodesMax=="") NNodesMax = "10";
|
255 |
if(Shrinkage=="") Shrinkage = "0.05";
|
256 |
if(nCuts=="") nCuts = "20";
|
257 |
if(NormMode=="") NormMode = "EqualNumEvents";
|
258 |
assert(NormMode=="NumEvents" || NormMode=="EqualNumEvents");
|
259 |
factory->PrepareTrainingAndTestTree( "", "SplitMode=Random:V:NormMode="+NormMode);
|
260 |
factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:V:NTrees="+NTrees
|
261 |
+":NNodesMax="+NNodesMax
|
262 |
+":BoostType=Grad:Shrinkage="+Shrinkage
|
263 |
+":UseBaggedGrad:GradBaggingFraction=0.50:nCuts="+nCuts
|
264 |
// +":nTrain_Signal=100:nTest_Signal=100:nTrain_Background=100:nTest_Background=100:nTrain_Signal2=100:nTest_Signal2=100"
|
265 |
+":IgnoreNegWeights");
|
266 |
|
267 |
factory->TrainAllMethods();
|
268 |
factory->TestAllMethods();
|
269 |
factory->EvaluateAllMethods();
|
270 |
|
271 |
std::cerr << "closing ... " << std::endl;
|
272 |
outputFile->Close();
|
273 |
}
|
274 |
//----------------------------------------------------------------------------------------
|
275 |
/**
 * Linear-interpolation fractions for a mass mH lying between two reference
 * mass points.
 *
 *   frac_lo = (mH_hi - mH) / (mH_hi - mH_lo)
 *   frac_hi = (mH - mH_lo) / (mH_hi - mH_lo)
 *
 * so the reference point closer to mH receives the larger fraction and the
 * two fractions add up to 1.
 *
 * @param mH       mass to interpolate to
 * @param mH_lo    lower reference mass
 * @param mH_hi    upper reference mass
 * @param frac_lo  [out] fraction assigned to mH_lo
 * @param frac_hi  [out] fraction assigned to mH_hi
 *
 * If both reference masses are identical there is nothing to interpolate
 * between; both fractions are then set to 1 instead of dividing by zero.
 */
void setFractions(float mH, float mH_lo, float mH_hi, float &frac_lo, float &frac_hi)
{
  const double delta = mH_hi - mH_lo;
  if(delta == 0) {
    // degenerate interval: a single mass point, which gets the full weight
    frac_lo = 1;
    frac_hi = 1;
    return;
  }
  frac_lo = (mH_hi - mH)/delta;
  frac_hi = (mH - mH_lo)/delta;
}
|
281 |
//----------------------------------------------------------------------------------------
|
282 |
/**
 * Store in *val the quantity called `name` for the entry currently loaded
 * in `fs`.
 *
 * The value is read from fs->angles, fs->kine, fs->vk or fs->info depending
 * on the name; some of the kinematic quantities are divided by m4l (and by
 * mZ1 / mZ2 for the dot products) before being stored.
 *
 * @param name  variable or spectator name
 * @param val   destination for the value
 * @param fs    input whose current entry is read
 *
 * An unrecognised name is reported on stdout and trips assert(0); *val is
 * left untouched in that case.
 */
void setTrainingVal(TString name, double *val, filestuff *fs)
{
  // angles
  if(name=="costheta1")    { *val = fs->angles->costheta1;    return; }
  if(name=="costheta2")    { *val = fs->angles->costheta2;    return; }
  if(name=="costhetastar") { *val = fs->angles->costhetastar; return; }
  if(name=="Phi")          { *val = fs->angles->Phi;          return; }
  if(name=="Phi1")         { *val = fs->angles->Phi1;         return; }

  // Z masses
  if(name=="mZ1") { *val = fs->kine->mZ1; return; }
  if(name=="mZ2") { *val = fs->kine->mZ2; return; }

  // pt variables
  if(name=="ZZpt")            { *val = fs->kine->ZZpt/fs->kine->m4l;                       return; }
  if(name=="ZZdotZ1")         { *val = fs->kine->ZZdotZ1/(fs->kine->m4l*fs->kine->mZ1);    return; }
  if(name=="ZZdotZ2")         { *val = fs->kine->ZZdotZ2/(fs->kine->m4l*fs->kine->mZ2);    return; }
  if(name=="ZZptCosDphiZ1pt") { *val = fs->kine->ZZptZ1ptCosDphi;                          return; }
  if(name=="ZZptCosDphiZ2pt") { *val = fs->kine->ZZptZ2ptCosDphi;                          return; }
  if(name=="Z1pt")            { *val = fs->kine->Z1pt/fs->kine->m4l;                       return; }
  if(name=="Z2pt")            { *val = fs->kine->Z2pt/fs->kine->m4l;                       return; }
  if(name=="ZZy")             { *val = fs->kine->ZZy;                                      return; }

  // jet variables: each one is fetched from fs->vk under its own name
  static const char* const jetVarNames[] = {
    "nJets", "mjj", "dEta", "etaProd",
    "dphiJ1HiPtZ", "dphiJ1LoPtZ", "dEtaJ1HiPtZ", "dEtaJ1LoPtZ",
    "J1dotHiPtZ", "J1dotLoPtZ",
    "dphiJ2HiPtZ", "dphiJ2LoPtZ", "dEtaJ2HiPtZ", "dEtaJ2LoPtZ",
    "J2dotHiPtZ", "J2dotLoPtZ"
  };
  const unsigned nJetVars = sizeof(jetVarNames)/sizeof(jetVarNames[0]);
  for(unsigned ijet=0; ijet<nJetVars; ijet++) {
    if(name==jetVarNames[ijet]) {
      *val = fs->vk->getval(jetVarNames[ijet]);
      return;
    }
  }

  // spectators
  if(name=="m4l")  { *val = fs->kine->m4l;  return; }
  if(name=="run")  { *val = fs->info->run;  return; }
  if(name=="lumi") { *val = fs->info->lumi; return; }
  if(name=="evt")  { *val = fs->info->evt;  return; }

  // nothing matched
  cout << name << " not found!" << endl;
  assert(0);
}
|