EIC Software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CharmJetGlobalTagger.C
Go to the documentation of this file, or view the newest version of CharmJetGlobalTagger.C in the sPHENIX GitHub repository.
1 #include "TROOT.h"
2 #include "TChain.h"
3 #include "TFile.h"
4 #include "TEfficiency.h"
5 #include "TH1.h"
6 #include "TGraphErrors.h"
7 #include "TCut.h"
8 #include "TCanvas.h"
9 #include "TStyle.h"
10 #include "TLegend.h"
11 #include "TMath.h"
12 #include "TLine.h"
13 #include "TLatex.h"
14 #include "TRatioPlot.h"
15 
16 #include "TMVA/Factory.h"
17 #include "TMVA/DataLoader.h"
18 #include "TMVA/Tools.h"
19 
20 #include <glob.h>
21 #include <iostream>
22 #include <iomanip>
23 #include <vector>
24 
25 #include "PlotFunctions.h"
26 
27 
28 void CharmJetGlobalTagger(TString dir, TString input, TString filePattern = "*/out.root")
29 {
30  // Global options
31  gStyle->SetOptStat(0);
32 
33  // Create the TCanvas
34  TCanvas *pad = new TCanvas("pad",
35  "",
36  800,
37  600);
38  TLegend *legend = nullptr;
39  TH1F *htemplate = nullptr;
40 
41  auto default_data = new TChain("tree");
42  default_data->SetTitle(input.Data());
43  auto files = fileVector(Form("%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
44 
45  for (auto file : files)
46  {
47  default_data->Add(file.c_str());
48  }
49 
50  // Create the signal and background trees
51 
52  auto signal_train = default_data->CopyTree("jet_flavor==4 && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
53  std::cout << "Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
54 
55  auto background_train = default_data->CopyTree("(jet_flavor<4||jet_flavor==21) && jet_n>0", "", TMath::Floor(default_data->GetEntries() / 1.0));
56  std::cout << "Background Tree (Training): " << background_train->GetEntries() << std::endl;
57 
58  // auto u_background_train = default_data->CopyTree("(jet_flavor==2)", "",
59  // TMath::Floor(default_data->GetEntries()/1.0));
60  // std::cout << "u Background Tree (Training): " <<
61  // u_background_train->GetEntries() << std::endl;
62 
63  // auto d_background_train = default_data->CopyTree("(jet_flavor==1)", "",
64  // TMath::Floor(default_data->GetEntries()/1.0));
65  // std::cout << "d Background Tree (Training): " <<
66  // d_background_train->GetEntries() << std::endl;
67 
68  // auto g_background_train = default_data->CopyTree("(jet_flavor==21)", "",
69  // TMath::Floor(default_data->GetEntries()/1.0));
70  // std::cout << "g Background Tree (Training): " <<
71  // g_background_train->GetEntries() << std::endl;
72 
73  // Create the TMVA tools
74  TMVA::Tools::Instance();
75 
76  auto outputFile = TFile::Open("CharmJetGlobalTagger_Results.root", "RECREATE");
77 
78  TMVA::Factory factory("TMVAClassification",
79  outputFile,
80  "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification");
81 
82  //
83  // Kaon Tagger
84  //
85 
86  TMVA::DataLoader loader("dataset");
87 
88  // loader.AddVariable("jet_pt", "Jet p_{T}", "GeV", 'F',
89  // 0.0,
90  // 1000.0);
91  // loader.AddVariable("jet_eta", "Jet Pseudorapidity", "", 'F',
92  // -5.0, 5.0);
93 
94  loader.AddSpectator("jet_pt");
95  loader.AddSpectator("jet_eta");
96  loader.AddSpectator("jet_flavor");
97  loader.AddSpectator("met_et");
98  loader.AddVariable("jet_mlp_ip3dtagger");
99  loader.AddVariable("jet_mlp_ktagger");
100  loader.AddVariable("jet_mlp_eltagger");
101  loader.AddVariable("jet_mlp_mutagger");
102  loader.AddSignalTree(signal_train, 1.0);
103  loader.AddBackgroundTree(background_train, 1.0);
104 
105 
106  // loader.AddVariable("jet_sip3dtag", "sIP3D Jet-Level Tag", "", 'B', -10,
107  // 10);
108  // loader.AddVariable("jet_charge");
109 
110  // loader.AddVariable("jet_ehadoveremratio");
111 
112 
113  // loader.AddTree( signal_train, "strange_jets" );
114  // loader.AddTree( u_background_train, "up jets" );
115  // loader.AddTree( d_background_train, "down jets" );
116  // loader.AddTree( g_background_train, "gluon jets" );
117 
118  loader.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==4 && met_et > 10"),
119  TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<4||jet_flavor==21) && met_et > 10"),
120  "nTrain_Signal=50000:nTrain_Background=500000:nTest_Signal=50000:nTest_Background=500000:SplitMode=Random:NormMode=NumEvents:!V");
121 
122  // Declare the classification method(s)
123  // factory.BookMethod(&loader,TMVA::Types::kBDT, "BDT",
124  //
125  //
126  // "!V:NTrees=1000:MinNodeSize=2.5%:MaxDepth=4:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20"
127  // );
128  factory.BookMethod(&loader, TMVA::Types::kMLP, "CharmGlobalTagger",
129  "!H:!V:NeuronType=ReLU:VarTransform=Norm:NCycles=1000:HiddenLayers=N+8:TestRate=5:!UseRegulator");
130 
131  // Train
132  factory.TrainAllMethods();
133 
134  // Test
135  factory.TestAllMethods();
136  factory.EvaluateAllMethods();
137 
138  // Plot a ROC Curve
139  pad->cd();
140  pad = factory.GetROCCurve(&loader);
141  pad->Draw();
142 
143  pad->SaveAs("CharmJetGlobalTagger_ROC.pdf");
144 }