4 #include "TEfficiency.h"
6 #include "TGraphErrors.h"
14 #include "TRatioPlot.h"
16 #include "TMVA/Factory.h"
17 #include "TMVA/DataLoader.h"
18 #include "TMVA/Tools.h"
// Global style: SetOptStat(0) turns off the statistics box on histograms drawn later.
31 gStyle->SetOptStat(0);
// 800x600 canvas named "pad" (empty title).
// NOTE(review): `pad` is reassigned to the canvas returned by
// factory.GetROCCurve() further down; if nothing in between draws on this
// canvas, it is never used — confirm against the lines not shown here.
34 TCanvas* pad =
new TCanvas(
"pad",
"",800,600);
// Plotting placeholders, initialised to nullptr; no use of either is visible
// in this excerpt.
35 TLegend * legend =
nullptr;
36 TH1F* htemplate =
nullptr;
// Input chain: a TChain over the tree called "tree", titled with the `input`
// sample label.
38 auto default_data =
new TChain(
"tree");
39 default_data->SetTitle(input.Data());
// Files to chain, located via the path "<dir>/<input>/<filePattern>".
// fileVector() is defined outside this excerpt — presumably it expands the
// pattern into a list of file names; verify.
40 auto files =
fileVector(Form(
"%s/%s/%s", dir.Data(), input.Data(), filePattern.Data()));
// Append one file to the chain.
// NOTE(review): `file` is not declared in this excerpt; presumably it is the
// variable of a loop over `files` whose header is not shown.
44 default_data->Add(
file.c_str());
// Split the chain by flavour label with TTree::CopyTree: entries with
// jet_flavor==3 form the signal tree; entries with jet_flavor<3 or
// jet_flavor==21 form the background tree.
// NOTE(review): the labels look like PDG codes (3 = s quark, 21 = gluon) —
// confirm. If jet_flavor can be negative (anti-quarks), `jet_flavor<3` also
// puts anti-strange (-3) jets into the background sample.
// NOTE(review): GetEntries()/1.0 is a no-op, so every entry of the chain is
// scanned; the divisor looks like a leftover knob for using only a fraction
// of the input. Despite the "_train" names, these trees hold all selected
// entries — the train/test split is done later by PrepareTrainingAndTestTree.
49 auto signal_train = default_data->CopyTree(
"jet_flavor==3",
"", TMath::Floor(default_data->GetEntries()/1.0));
50 std::cout <<
"Signal Tree (Training): " << signal_train->GetEntries() << std::endl;
52 auto background_train = default_data->CopyTree(
"(jet_flavor<3||jet_flavor==21)",
"", TMath::Floor(default_data->GetEntries()/1.0));
53 std::cout <<
"Background Tree (Training): " << background_train->GetEntries() << std::endl;
// Make sure the TMVA::Tools singleton exists before the Factory is created.
65 TMVA::Tools::Instance();
// Output file for the Factory's results; "RECREATE" overwrites an existing
// file of the same name.
// NOTE(review): the TFile::Open result is not checked for nullptr, and no
// outputFile->Close() is visible in this excerpt — confirm the file is
// validated and closed after evaluation so all results are written.
67 auto outputFile = TFile::Open(
"StrangeJetClassification_Results.root",
"RECREATE");
// Factory options: not verbose (!V) but not silent, ROC computation on,
// input-variable correlations off, coloured output with a progress bar,
// classification analysis.
69 TMVA::Factory factory(
"TMVAClassification", outputFile,
70 "!V:ROC:!Correlations:!Silent:Color:DrawProgressBar:AnalysisType=Classification" );
// DataLoader (named "dataset") describing the classifier inputs.
73 TMVA::DataLoader
loader(
"dataset");
// Input variables the classifier trains on.
75 loader.AddVariable(
"jet_pt");
76 loader.AddVariable(
"jet_eta");
77 loader.AddVariable(
"jet_nconstituents");
78 loader.AddVariable(
"jet_Ks_leading_zhadron");
79 loader.AddVariable(
"jet_K_leading_zhadron");
80 loader.AddVariable(
"jet_charge");
81 loader.AddVariable(
"jet_Ks_sumpt");
82 loader.AddVariable(
"jet_K_sumpt");
83 loader.AddVariable(
"jet_ehadoveremratio");
// jet_flavor is registered as a spectator: TMVA carries it into its output
// trees but does not train on it.
84 loader.AddSpectator(
"jet_flavor");
// Signal and background trees, each with a global event weight of 1.0.
86 loader.AddSignalTree( signal_train, 1.0 );
87 loader.AddBackgroundTree( background_train, 1.0 );
// Preselection and train/test split. Both classes require jet_pt > 5.0 and
// |jet_eta| < 3.0. Events are assigned randomly (SplitMode=Random) to
// 50000 training + 50000 testing signal events and 100000 + 100000
// background events, with NormMode=NumEvents.
// NOTE(review): the jet_flavor clauses duplicate the CopyTree selections
// above and are redundant (harmless). The requested counts need at least
// 100000 signal and 200000 background events to survive these cuts —
// confirm the input sample is large enough.
93 loader.PrepareTrainingAndTestTree(TCut(
"jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && jet_flavor==3"),
94 TCut(
"jet_pt>5.0 && TMath::Abs(jet_eta) < 3.0 && (jet_flavor<3||jet_flavor==21)"),
95 "nTrain_Signal=50000:nTrain_Background=100000:nTest_Signal=50000:nTest_Background=100000:SplitMode=Random:NormMode=NumEvents:!V" );
// Book one boosted decision tree, titled "BDT":
// - 1000 trees, each at most 4 levels deep;
// - AdaBoost with beta 0.5, plus bagging with a 50% sample fraction per tree;
// - Gini-index splitting, scanned over 20 cut values per variable;
// - a minimum leaf size of 2.5% of the training sample.
98 factory.BookMethod(&loader,TMVA::Types::kBDT,
"BDT",
99 "!V:NTrees=1000:MinNodeSize=2.5%:MaxDepth=4:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20" );
// Train every booked method, apply it to the test sample, then compute the
// evaluation metrics.
102 factory.TrainAllMethods();
105 factory.TestAllMethods();
106 factory.EvaluateAllMethods();
// Replace `pad` with the ROC-curve canvas the Factory builds for this
// loader's methods, and save it as a PDF.
// NOTE(review): `pad` is dereferenced without a null check; confirm
// GetROCCurve cannot return nullptr here (e.g. when no method was evaluated).
110 pad = factory.GetROCCurve(&loader);
113 pad->SaveAs(
"StrangeJetClassification_ROC.pdf");