EIC Software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StrangeJetClassification.C
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file StrangeJetClassification.C
1 #include "TROOT.h"
2 #include "TChain.h"
3 #include "TFile.h"
4 #include "TEfficiency.h"
5 #include "TH1.h"
6 #include "TGraphErrors.h"
7 #include "TCut.h"
8 #include "TCanvas.h"
9 #include "TStyle.h"
10 #include "TLegend.h"
11 #include "TMath.h"
12 #include "TLine.h"
13 #include "TLatex.h"
14 #include "TRatioPlot.h"
15 
16 #include "TMVA/Factory.h"
17 #include "TMVA/DataLoader.h"
18 #include "TMVA/Tools.h"
19 
20 #include <glob.h>
21 #include <iostream>
22 #include <iomanip>
23 #include <vector>
24 
25 #include "PlotFunctions.h"
26 
27 
28 void StrangeJetClassification(TString dir, TString input, TString filePattern = "*/out.root")
29 {
30  // Global options
31  gStyle->SetOptStat(0);
32 
33  // Create the TCanvas
34  TCanvas* pad = new TCanvas("pad","",800,600);
35  TLegend * legend = nullptr;
36  TH1F* htemplate = nullptr;
37 
38  auto default_data = new TChain("tree");
39  default_data->SetTitle(input.Data());
40  auto files = fileVector(Form("%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
41 
42  for (auto file : files)
43  {
44  default_data->Add(file.c_str());
45  }
46 
47  // Create the signal and background trees
48 
49  auto signal_train = default_data->CopyTree("jet_flavor==3", "", TMath::Floor(default_data->GetEntries()/1.0));
50  std::cout << "Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
51 
52  auto background_train = default_data->CopyTree("(jet_flavor<3||jet_flavor==21)", "", TMath::Floor(default_data->GetEntries()/1.0));
53  std::cout << "Background Tree (Training): " << background_train->GetEntries() << std::endl;
54 
55  // auto u_background_train = default_data->CopyTree("(jet_flavor==2)", "", TMath::Floor(default_data->GetEntries()/1.0));
56  // std::cout << "u Background Tree (Training): " << u_background_train->GetEntries() << std::endl;
57 
58  // auto d_background_train = default_data->CopyTree("(jet_flavor==1)", "", TMath::Floor(default_data->GetEntries()/1.0));
59  // std::cout << "d Background Tree (Training): " << d_background_train->GetEntries() << std::endl;
60 
61  // auto g_background_train = default_data->CopyTree("(jet_flavor==21)", "", TMath::Floor(default_data->GetEntries()/1.0));
62  // std::cout << "g Background Tree (Training): " << g_background_train->GetEntries() << std::endl;
63 
64  // Create the TMVA tools
65  TMVA::Tools::Instance();
66 
67  auto outputFile = TFile::Open("StrangeJetClassification_Results.root", "RECREATE");
68 
69  TMVA::Factory factory("TMVAClassification", outputFile,
70  "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification" );
71 
72 
73  TMVA::DataLoader loader("dataset");
74 
75  loader.AddVariable("jet_pt");
76  loader.AddVariable("jet_eta");
77  loader.AddVariable("jet_nconstituents");
78  loader.AddVariable("jet_Ks_leading_zhadron");
79  loader.AddVariable("jet_K_leading_zhadron");
80  loader.AddVariable("jet_charge");
81  loader.AddVariable("jet_Ks_sumpt");
82  loader.AddVariable("jet_K_sumpt");
83  loader.AddVariable("jet_ehadoveremratio");
84  loader.AddSpectator("jet_flavor");
85 
86  loader.AddSignalTree( signal_train, 1.0 );
87  loader.AddBackgroundTree( background_train, 1.0 );
88  // loader.AddTree( signal_train, "strange_jets" );
89  // loader.AddTree( u_background_train, "up jets" );
90  // loader.AddTree( d_background_train, "down jets" );
91  // loader.AddTree( g_background_train, "gluon jets" );
92 
93  loader.PrepareTrainingAndTestTree(TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==3"),
94  TCut("jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<3||jet_flavor==21)"),
95  "nTrain_Signal=50000:nTrain_Background=100000:nTest_Signal=50000:nTest_Background=100000:SplitMode=Random:NormMode=NumEvents:!V" );
96 
97  // Declare the classification method(s)
98  factory.BookMethod(&loader,TMVA::Types::kBDT, "BDT",
99  "!V:NTrees=1000:MinNodeSize=2.5%:MaxDepth=4:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20" );
100 
101  // Train
102  factory.TrainAllMethods();
103 
104  // Test
105  factory.TestAllMethods();
106  factory.EvaluateAllMethods();
107 
108  // Plot a ROC Curve
109  pad->cd();
110  pad = factory.GetROCCurve(&loader);
111  pad->Draw();
112 
113  pad->SaveAs("StrangeJetClassification_ROC.pdf");
114 
115 
116 }
117