ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/VHbb/python/myutils/RegressionTrainer.py
Revision: 1.2
Committed: Tue Mar 12 19:45:14 2013 UTC (12 years, 2 months ago) by nmohr
Content type: text/x-python
Branch: MAIN
CVS Tags: lhcp_UnblindFix, hcp_Unblind, lhcp_11April, LHCP_PreAppFixAfterFreeze, LHCP_PreAppFreeze, HEAD
Changes since 1.1: +1 -1 lines
Log Message:
Silence output

File Contents

# Content
1 import sys,re,ROOT
2 ROOT.gROOT.SetBatch(True)
3 from sample_parser import ParseInfo
4 from TreeCache import TreeCache
5
class RegressionTrainer(object):
    """Train a TMVA BDT regression and write the matching application config.

    All settings come from ``config`` (a ConfigParser-like object):

    * ``VHbbNameSpace/library`` -- shared library loaded into ROOT up front.
    * ``TrainRegression/{weight,vars,target,cut,name,signals,options,
      trainCut,testCut}`` -- the regression definition.
    * ``Directories/{PREPout,samplesinfo}`` -- where the input samples live.
    """

    # Names starting with ``hJet_`` get an explicit ``[0]`` index in the
    # application expressions (presumably per-jet array branches -- confirm).
    _JET_VAR_PATTERN = re.compile(r'hJet_\w+')

    def __init__(self, config):
        vhbb_name_space = config.get('VHbbNameSpace','library')
        ROOT.gSystem.Load(vhbb_name_space)

        self.__weight = config.get("TrainRegression","weight")
        self.__vars = config.get("TrainRegression","vars").split()
        self.__target = config.get("TrainRegression","target")
        self.__cut = config.get("TrainRegression","cut")
        self.__title = config.get("TrainRegression","name")
        self.__signals = config.get("TrainRegression","signals")
        self.__regOptions = config.get("TrainRegression","options")
        path = config.get('Directories','PREPout')
        samplesinfo=config.get('Directories','samplesinfo')
        self.__info = ParseInfo(samplesinfo,path)
        self.__samples = self.__info.get_samples(self.__signals)
        self.__tc = TreeCache([self.__cut],self.__samples,path,config)
        self.__trainCut = config.get("TrainRegression","trainCut")
        self.__testCut = config.get("TrainRegression","testCut")
        self.__config = config
        # Application expressions, parallel to self.__vars; filled by train().
        self.__apply = []

    @staticmethod
    def _combineCuts(*cuts):
        """Return the logical AND of the given TTree selection strings.

        Each cut is parenthesised so that a top-level ``||`` inside one cut
        cannot swallow the others. ``&&`` (logical AND) is used instead of
        ``&`` (bitwise AND).
        """
        return ' && '.join('(%s)' % cut for cut in cuts)

    @classmethod
    def _applyExpression(cls, var):
        """Return ``var`` with every ``hJet_*`` name suffixed by ``[0]``."""
        return cls._JET_VAR_PATTERN.sub(r'\g<0>[0]', var)

    def train(self, appRegPath='8TeVconfig/appReg'):
        """Train the BDT regression and write the application config.

        Training output goes to ``training_Reg_<name>.root``. Afterwards the
        config is reduced to its ``Regression`` section and written to
        ``appRegPath`` (see ``_writeApplicationConfig``).

        :param appRegPath: destination of the application config file.
        :raises IOError: if the TMVA output file cannot be created.
        """
        trainSelection = self._combineCuts(self.__cut, self.__trainCut)
        testSelection = self._combineCuts(self.__cut, self.__testCut)
        signals = []
        signalsTest = []
        for job in self.__samples:
            print('\tREADING IN %s AS SIG' % job.name)
            signals.append(self.__tc.get_tree(job, trainSelection))
            signalsTest.append(self.__tc.get_tree(job, testSelection))

        sWeight = 1.
        fnameOutput = 'training_Reg_%s.root' % (self.__title)
        output = ROOT.TFile.Open(fnameOutput, "RECREATE")
        if not output or output.IsZombie():
            raise IOError('Could not create TMVA output file %s' % fnameOutput)
        try:
            factory = ROOT.TMVA.Factory('MVA', output, '!V:!Silent:!Color:!DrawProgressBar:Transformations=I:AnalysisType=Regression')
            # NOTE(review): TrainRegression/weight is read but not applied;
            # the SetSignalWeightExpression call was disabled. For regression,
            # Factory.SetWeightExpression(expr, "Regression") is presumably
            # the right call -- confirm before enabling.
            for trainTree, testTree in zip(signals, signalsTest):
                factory.AddRegressionTree(trainTree, sWeight, ROOT.TMVA.Types.kTraining)
                factory.AddRegressionTree(testTree, sWeight, ROOT.TMVA.Types.kTesting)
            self.__apply = []
            for var in self.__vars:
                factory.AddVariable(var, 'D')
                self.__apply.append(self._applyExpression(var))
            factory.AddTarget(self.__target)
            factory.BookMethod(ROOT.TMVA.Types.kBDT, 'BDT_REG_%s' % (self.__title), self.__regOptions)
            factory.TrainAllMethods()
            factory.TestAllMethods()
            factory.EvaluateAllMethods()
            output.Write()
        finally:
            # Always release the output file, even if training fails.
            output.Close()
        self._writeApplicationConfig(appRegPath)

    def _writeApplicationConfig(self, appRegPath):
        """Store the regression settings and write the config to ``appRegPath``.

        This sets ``regWeight``, ``regDict`` and ``regVars`` in the
        ``Regression`` section, creating the section if needed. It then
        removes every other section, writes the config file and echoes it
        to stdout.

        Note: this mutates the config object passed to the constructor.
        """
        config = self.__config
        if not config.has_section('Regression'):
            config.add_section('Regression')
        regDict = dict(zip(self.__vars, self.__apply))
        config.set('Regression', 'regWeight', '../data/MVA_BDT_REG_%s.weights.xml' % self.__title)
        config.set('Regression', 'regDict', '%s' % regDict)
        config.set('Regression', 'regVars', '%s' % self.__vars)
        # Iterate over a copy: sections are removed inside the loop.
        for section in list(config.sections()):
            if section != 'Regression':
                config.remove_section(section)
        with open(appRegPath, 'w') as configfile:
            config.write(configfile)
        with open(appRegPath, 'r') as configfile:
            for line in configfile:
                print(line.strip())
71