EIC Software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CharmJetClassification.C
Go to the documentation of this file, or view the newest version of CharmJetClassification.C in the sPHENIX GitHub repository.
1 #include "TROOT.h"
2 #include "TChain.h"
3 #include "TFile.h"
4 #include "TEfficiency.h"
5 #include "TH1.h"
6 #include "TGraphErrors.h"
7 #include "TCut.h"
8 #include "TCanvas.h"
9 #include "TStyle.h"
10 #include "TLegend.h"
11 #include "TMath.h"
12 #include "TLine.h"
13 #include "TLatex.h"
14 #include "TRatioPlot.h"
15 
16 #include "TMVA/Factory.h"
17 #include "TMVA/DataLoader.h"
18 #include "TMVA/Tools.h"
19 
20 #include <glob.h>
21 #include <iostream>
22 #include <iomanip>
23 #include <vector>
24 
25 #include "PlotFunctions.h"
26 
27 
28 void CharmJetClassification(TString dir, TString input, TString filePattern = "*/out.root")
29 {
30  // Global options
31  gStyle->SetOptStat(0);
32 
33  // Create the TCanvas
34  TCanvas *pad = new TCanvas("pad",
35  "",
36  800,
37  600);
38  TLegend *legend = nullptr;
39  TH1F *htemplate = nullptr;
40 
41  auto default_data = new TChain("tree");
42  default_data->SetTitle(input.Data());
43  auto files = fileVector(Form("%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
44 
45  for (auto file : files)
46  {
47  default_data->Add(file.c_str());
48  }
49 
50  // Create the signal and background trees
51 
52  auto signal_train = default_data->CopyTree("jet_flavor==4 && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
53  std::cout << "Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
54 
55  auto background_train = default_data->CopyTree("(jet_flavor<4||jet_flavor==21) && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
56  std::cout << "Background Tree (Training): " << background_train->GetEntries() << std::endl;
57 
58  // auto u_background_train = default_data->CopyTree("(jet_flavor==2)", "",
59  // TMath::Floor(default_data->GetEntries()/1.0));
60  // std::cout << "u Background Tree (Training): " <<
61  // u_background_train->GetEntries() << std::endl;
62 
63  // auto d_background_train = default_data->CopyTree("(jet_flavor==1)", "",
64  // TMath::Floor(default_data->GetEntries()/1.0));
65  // std::cout << "d Background Tree (Training): " <<
66  // d_background_train->GetEntries() << std::endl;
67 
68  // auto g_background_train = default_data->CopyTree("(jet_flavor==21)", "",
69  // TMath::Floor(default_data->GetEntries()/1.0));
70  // std::cout << "g Background Tree (Training): " <<
71  // g_background_train->GetEntries() << std::endl;
72 
73  // Create the TMVA tools
74  TMVA::Tools::Instance();
75 
76  auto outputFile = TFile::Open("CharmJetClassification_Results.root", "RECREATE");
77 
78  TMVA::Factory factory("TMVAClassification",
79  outputFile,
80  "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification");
81 
82  //
83  // Kaon Tagger
84  //
85 
86  TMVA::DataLoader loader_ktagger("dataset_ktagger");
87  TMVA::DataLoader loader_etagger("dataset_etagger");
88  TMVA::DataLoader loader_mutagger("dataset_mutagger");
89  TMVA::DataLoader loader_ip3dtagger("dataset_ip3dtagger");
90 
91  // loader_ktagger.AddVariable("jet_pt", "Jet p_{T}", "GeV", 'F',
92  // 0.0,
93  // 1000.0);
94  // loader_ktagger.AddVariable("jet_eta", "Jet Pseudorapidity", "", 'F',
95  // -5.0, 5.0);
96 
97  loader_ktagger.AddSpectator("jet_pt");
98  loader_ktagger.AddSpectator("jet_eta");
99  loader_ktagger.AddVariable("jet_k1_pt");
100  loader_ktagger.AddVariable("jet_k1_sIP3D");
101  loader_ktagger.AddVariable("jet_k2_pt");
102  loader_ktagger.AddVariable("jet_k2_sIP3D");
103  loader_ktagger.AddSpectator("jet_flavor");
104  loader_ktagger.AddSpectator("met_et");
105  loader_ktagger.AddSignalTree(signal_train, 1.0);
106  loader_ktagger.AddBackgroundTree(background_train, 1.0);
107 
108  // loader_ktagger.AddVariable("jet_nconstituents");
109 
110  loader_etagger.AddSpectator("jet_pt");
111  loader_etagger.AddSpectator("jet_eta");
112  loader_etagger.AddVariable("jet_e1_pt");
113  loader_etagger.AddVariable("jet_e1_sIP3D");
114  loader_etagger.AddVariable("jet_e2_pt");
115  loader_etagger.AddVariable("jet_e2_sIP3D");
116  loader_etagger.AddSpectator("jet_flavor");
117  loader_etagger.AddSpectator("met_et");
118  loader_etagger.AddSignalTree(signal_train, 1.0);
119  loader_etagger.AddBackgroundTree(background_train, 1.0);
120 
121  loader_mutagger.AddSpectator("jet_pt");
122  loader_mutagger.AddSpectator("jet_eta");
123  loader_mutagger.AddVariable("jet_mu1_pt");
124  loader_mutagger.AddVariable("jet_mu1_sIP3D");
125  loader_mutagger.AddVariable("jet_mu2_pt");
126  loader_mutagger.AddVariable("jet_mu2_sIP3D");
127  loader_mutagger.AddSpectator("jet_flavor");
128  loader_mutagger.AddSpectator("met_et");
129  loader_mutagger.AddSignalTree(signal_train, 1.0);
130  loader_mutagger.AddBackgroundTree(background_train, 1.0);
131 
132 
133  loader_ip3dtagger.AddSpectator("jet_pt");
134  loader_ip3dtagger.AddSpectator("jet_eta");
135  loader_ip3dtagger.AddVariable("jet_t1_pt");
136  loader_ip3dtagger.AddVariable("jet_t1_sIP3D");
137  loader_ip3dtagger.AddVariable("jet_t2_pt");
138  loader_ip3dtagger.AddVariable("jet_t2_sIP3D");
139  loader_ip3dtagger.AddVariable("jet_t3_pt");
140  loader_ip3dtagger.AddVariable("jet_t3_sIP3D");
141  loader_ip3dtagger.AddVariable("jet_t4_pt");
142  loader_ip3dtagger.AddVariable("jet_t4_sIP3D");
143  loader_ip3dtagger.AddSpectator("jet_flavor");
144  loader_ip3dtagger.AddSpectator("met_et");
145  loader_ip3dtagger.AddSignalTree(signal_train, 1.0);
146  loader_ip3dtagger.AddBackgroundTree(background_train, 1.0);
147 
148  // loader.AddVariable("jet_sip3dtag", "sIP3D Jet-Level Tag", "", 'B', -10,
149  // 10);
150  // loader.AddVariable("jet_charge");
151 
152  // loader.AddVariable("jet_ehadoveremratio");
153 
154 
155  // loader.AddTree( signal_train, "strange_jets" );
156  // loader.AddTree( u_background_train, "up jets" );
157  // loader.AddTree( d_background_train, "down jets" );
158  // loader.AddTree( g_background_train, "gluon jets" );
159 
160  loader_ktagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
161  TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
162  "nTrain_Signal=50000:nTrain_Background=500000:nTest_Signal=50000:nTest_Background=500000:SplitMode=Random:NormMode=NumEvents:!V");
163  loader_etagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
164  TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
165  "nTrain_Signal=100000:nTrain_Background=1000000:nTest_Signal=100000:nTest_Background=1000000:SplitMode=Random:NormMode=NumEvents:!V");
166  loader_mutagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
167  TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
168  "nTrain_Signal=100000:nTrain_Background=1000000:nTest_Signal=100000:nTest_Background=1000000:SplitMode=Random:NormMode=NumEvents:!V");
169  loader_ip3dtagger.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
170  TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
171  "nTrain_Signal=10000:nTrain_Background=100000:nTest_Signal=10000:nTest_Background=100000:SplitMode=Random:NormMode=NumEvents:!V");
172 
173  // Declare the classification method(s)
174  // factory.BookMethod(&loader,TMVA::Types::kBDT, "BDT",
175  //
176  //
177  // "!V:NTrees=1000:MinNodeSize=2.5%:MaxDepth=4:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20"
178  // );
179  factory.BookMethod(&loader_ktagger, TMVA::Types::kMLP, "CharmKTagger",
180  "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+12:TestRate=5:!UseRegulator");
181  // factory.BookMethod(&loader_etagger, TMVA::Types::kMLP, "CharmETagger",
182  // "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+12:TestRate=5:!UseRegulator");
183  // factory.BookMethod(&loader_mutagger, TMVA::Types::kMLP, "CharmMuTagger",
184  // "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+12:TestRate=5:!UseRegulator");
185  // factory.BookMethod(&loader_ip3dtagger, TMVA::Types::kMLP, "CharmIP3DTagger",
186  // "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+16:TestRate=5:!UseRegulator");
187 
188  // factory.BookMethod(&loader, TMVA::Types::kMLP,
189  // "CharmETagger","!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=600:HiddenLayers=N+8:TestRate=5:!UseRegulator");
190  // factory.BookMethod(&loader, TMVA::Types::kMLP,
191  // "CharmMuTagger","!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=600:HiddenLayers=N+8:TestRate=5:!UseRegulator");
192 
193  // Train
194  factory.TrainAllMethods();
195 
196  // Test
197  factory.TestAllMethods();
198  factory.EvaluateAllMethods();
199 
200  // Plot a ROC Curve
201  pad->cd();
202  pad = factory.GetROCCurve(&loader_ktagger);
203  // pad = factory.GetROCCurve(&loader_etagger);
204  // pad = factory.GetROCCurve(&loader_mutagger);
205  // pad = factory.GetROCCurve(&loader_ip3dtagger);
206  pad->Draw();
207 
208  pad->SaveAs("CharmJetClassification_ROC.pdf");
209 }