ViewVC Help
View File | Revision Log | Show Annotations | Root Listing
root/cvsroot/UserCode/VHbb/python/myutils/RegressionTrainer.py
Revision: 1.2
Committed: Tue Mar 12 19:45:14 2013 UTC (12 years, 2 months ago) by nmohr
Content type: text/x-python
Branch: MAIN
CVS Tags: lhcp_UnblindFix, hcp_Unblind, lhcp_11April, LHCP_PreAppFixAfterFreeze, LHCP_PreAppFreeze, HEAD
Changes since 1.1: +1 -1 lines
Log Message:
Silence output

File Contents

# User Rev Content
1 nmohr 1.1 import sys,re,ROOT
2     ROOT.gROOT.SetBatch(True)
3     from sample_parser import ParseInfo
4     from TreeCache import TreeCache
5    
class RegressionTrainer():
    """Train a TMVA BDT regression and write the matching application config.

    All training settings come from the ``TrainRegression`` section of the
    config passed to ``__init__``; the signal samples named there are looked up
    with ``ParseInfo`` and their trees are served by a ``TreeCache`` that is
    primed with the base ``cut``.
    """

    # Matches per-jet branch names (hJet_<something>); used to rewrite each
    # training variable so that every hJet_* branch reads only entry [0].
    _JET_BRANCH = re.compile(r'hJet_\w+')

    def __init__(self, config):
        """Read the regression settings and prepare the tree cache.

        :param config: ConfigParser-like object providing the ``VHbbNameSpace``,
            ``TrainRegression`` and ``Directories`` sections.  It is kept and is
            modified later by ``train()``.
        """
        # Load the project's ROOT helper library (path from the config) before
        # any tree/cut expressions are evaluated.
        vhbb_name_space = config.get('VHbbNameSpace', 'library')
        ROOT.gSystem.Load(vhbb_name_space)

        # NOTE(review): 'weight' is read but never applied in train(); confirm.
        self.__weight = config.get("TrainRegression", "weight")
        self.__vars = config.get("TrainRegression", "vars").split()
        self.__target = config.get("TrainRegression", "target")
        self.__cut = config.get("TrainRegression", "cut")
        self.__title = config.get("TrainRegression", "name")
        self.__signals = config.get("TrainRegression", "signals")
        self.__regOptions = config.get("TrainRegression", "options")
        path = config.get('Directories', 'PREPout')
        samplesinfo = config.get('Directories', 'samplesinfo')
        self.__info = ParseInfo(samplesinfo, path)
        self.__samples = self.__info.get_samples(self.__signals)
        self.__tc = TreeCache([self.__cut], self.__samples, path, config)
        self.__trainCut = config.get("TrainRegression", "trainCut")
        self.__testCut = config.get("TrainRegression", "testCut")
        self.__config = config

    @staticmethod
    def _combine_cuts(base, extra):
        """Return a ROOT selection string requiring both *base* and *extra*.

        Each operand is parenthesised: ``&`` binds tighter than ``||``, so an
        unparenthesised ``a || b & c`` would apply *c* only to *b*.
        """
        return '(%s) & (%s)' % (base, extra)

    @classmethod
    def _apply_expression(cls, var):
        """Return *var* with every ``hJet_*`` branch name indexed at ``[0]``."""
        return cls._JET_BRANCH.sub(r'\g<0>[0]', var)

    def train(self, appConfigPath='8TeVconfig/appReg'):
        """Train, test and evaluate the regression BDT, then write the app config.

        For every signal sample the cached tree is selected twice: with
        ``cut`` and ``trainCut`` for training, and with ``cut`` and ``testCut``
        for testing.  TMVA output is written to ``training_Reg_<name>.root``.

        Afterwards the ``Regression`` section of the stored config is filled with
        ``regWeight``, ``regDict`` (variable -> application expression) and
        ``regVars``.  Every other section is then removed from that same config
        object (the one given to ``__init__``), which is written to
        *appConfigPath* and echoed to stdout.

        :param appConfigPath: destination of the application config; defaults
            to the previously hard-coded ``8TeVconfig/appReg``.
        """
        trainSelection = self._combine_cuts(self.__cut, self.__trainCut)
        testSelection = self._combine_cuts(self.__cut, self.__testCut)
        signals = []
        signalsTest = []
        for job in self.__samples:
            print('\tREADING IN %s AS SIG' % job.name)
            signals.append(self.__tc.get_tree(job, trainSelection))
            signalsTest.append(self.__tc.get_tree(job, testSelection))

        # Global weight passed with every input tree.
        sWeight = 1.
        fnameOutput = 'training_Reg_%s.root' % (self.__title)
        output = ROOT.TFile.Open(fnameOutput, "RECREATE")

        factory = ROOT.TMVA.Factory('MVA', output, '!V:!Silent:!Color:!DrawProgressBar:Transformations=I:AnalysisType=Regression')
        # NOTE(review): the configured per-event weight (self.__weight) is not
        # applied.  For regression, TMVA's Factory.SetWeightExpression(expr,
        # "Regression") looks like the intended call -- confirm before enabling.
        for signal, signalTest in zip(signals, signalsTest):
            factory.AddRegressionTree(signal, sWeight, ROOT.TMVA.Types.kTraining)
            factory.AddRegressionTree(signalTest, sWeight, ROOT.TMVA.Types.kTesting)

        # Record, for each training variable, the expression used at application
        # time (hJet_* branches restricted to entry [0]).
        self.__apply = []
        for var in self.__vars:
            factory.AddVariable(var, 'D')
            self.__apply.append(self._apply_expression(var))

        factory.AddTarget(self.__target)
        factory.BookMethod(ROOT.TMVA.Types.kBDT, 'BDT_REG_%s' % (self.__title), self.__regOptions)
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
        output.Write()
        # Close the file so that everything TMVA wrote is flushed to disk.
        output.Close()

        regDict = dict(zip(self.__vars, self.__apply))
        if not self.__config.has_section('Regression'):
            self.__config.add_section('Regression')
        # NOTE(review): this path presumably relies on the TMVA weight file being
        # copied to ../data by a later step -- confirm.
        self.__config.set('Regression', 'regWeight', '../data/MVA_BDT_REG_%s.weights.xml' % self.__title)
        self.__config.set('Regression', 'regDict', '%s' % regDict)
        self.__config.set('Regression', 'regVars', '%s' % self.__vars)
        # Iterate over a copy because sections are removed inside the loop.
        for section in list(self.__config.sections()):
            if section != 'Regression':
                self.__config.remove_section(section)
        with open(appConfigPath, 'w') as configfile:
            self.__config.write(configfile)
        with open(appConfigPath, 'r') as configfile:
            for line in configfile:
                print(line.strip())
71