EIC Software
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
CbmRichTrainAnnElectrons.cxx
Go to the documentation of this file. Or view the newest version in sPHENIX GitHub for file CbmRichTrainAnnElectrons.cxx
1 
9 
10 #include "FairTrackParam.h"
11 #include "CbmGlobalTrack.h"
12 #include "CbmRichRing.h"
13 #include "CbmTrackMatch.h"
14 #include "CbmTrackMatch.h"
15 #include "FairRootManager.h"
16 #include "FairMCPoint.h"
17 #include "CbmMCTrack.h"
18 #include "CbmDrawHist.h"
19 
20 #include "TString.h"
21 #include "TSystem.h"
22 #include "TCanvas.h"
23 #include "TH1D.h"
24 #include "TH2D.h"
25 #include "TClonesArray.h"
26 #include "TMultiLayerPerceptron.h"
27 
28 #include <boost/assign/list_of.hpp>
29 
30 #include <iostream>
31 #include <vector>
32 #include <cmath>
33 #include <string>
34 
35 class CbmStsTrack;
36 
37 using std::cout;
38 using std::endl;
39 using std::vector;
40 using std::fabs;
41 using std::string;
42 using boost::assign::list_of;
43 
45  fEventNum(0),
46 
47  fRichHits(NULL),
48  fRichRings(NULL),
49  fRichPoints(NULL),
50  fMCTracks(NULL),
51  fRichRingMatches(NULL),
52  fRichProj(NULL),
53  fStsTrackMatches(NULL),
54  fGlobalTracks(NULL),
55  fStsTracks(NULL),
56 
57  fMinNofHitsInRichRing(7),
58  fQuota(0.6),
59  fMaxNofTrainSamples(5000),
60 
61  fNofPiLikeEl(0),
62  fNofElLikePi(0),
63  fAnnCut(-0.5),
64 
65  fhAnnOutput(),
66  fhCumProb(),
67 
68  fRElIdParams(),
69 
70  fhAaxis(),
71  fhBaxis(),
72  // fhAaxisCor(),
73  // fhBaxisCor(),
74  fhDistTrueMatch(),
75  fhDistMisMatch(),
76  fhNofHits(),
77  fhChi2(),
78  fhRadPos(),
79  fhAaxisVsMom(),
80  fhBaxisVsMom(),
81  fhPhiVsRadAng(),
82 
83  fHists()
84 {
85  fhAaxis.resize(2);
86  fhBaxis.resize(2);
87  // fhAaxisCor.resize(2);
88  // fhBaxisCor.resize(2);
89  fhDistTrueMatch.resize(2);
90  fhDistMisMatch.resize(2);
91  fhNofHits.resize(2);
92  fhChi2.resize(2);
93  fhRadPos.resize(2);
94  fhAaxisVsMom.resize(2);
95  fhBaxisVsMom.resize(2);
96  fhPhiVsRadAng.resize(2);
97  fhAnnOutput.resize(2);
98  fhCumProb.resize(2);
99  fRElIdParams.resize(2);
100  string ss;
101  for (int i = 0; i < 2; i++){
102  if (i == 0) ss = "Electron";
103  if (i == 1) ss = "Pion";
104  // difference between electrons and pions
105  fhAaxis[i] = new TH1D(string("fhAaxis"+ss).c_str(), "fhAaxis;A axis [cm];Counter", 50, 0., 8.);
106  fHists.push_back(fhAaxis[i]);
107  fhBaxis[i] = new TH1D(string("fhBAxis"+ss).c_str(), "fhBAxis;B axis [cm];Counter", 50, 0., 8.);
108  fHists.push_back(fhBaxis[i]);
109  // fhAaxisCor[i] = new TH1D(string("fhAaxisCor"+ss).c_str(), "fhAaxisCor;A axis [cm];Counter", 30,3,8);
110  // fHists.push_back(fhAaxisCor[i]);
111  // fhBaxisCor[i] = new TH1D(string("fhBAxisCor"+ss).c_str(), "fhBAxisCor;B axis [cm];Counter", 30,3,8);
112  // fHists.push_back(fhBaxisCor[i]);
113  fhDistTrueMatch[i] = new TH1D(string("fhDistTrueMatch"+ss).c_str(), "fhDistTrueMatch;Ring-track distance [cm];Counter", 50, 0., 5.);
114  fHists.push_back(fhDistTrueMatch[i]);
115  fhDistMisMatch[i] = new TH1D(string("fhDistMisMatch"+ss).c_str(), "fhDistMisMatch;Ring-track distance [cm];Counter", 50, 0., 5.);
116  fHists.push_back(fhDistMisMatch[i]);
117  fhNofHits[i] = new TH1D(string("fhNofHits"+ss).c_str(), "fhNofHits;Number of hits;Counter", 50, 0, 50);
118  fHists.push_back(fhNofHits[i]);
119  fhChi2[i] = new TH1D(string("fhChi2"+ss).c_str(), "fhChi2;#Chi^{2};Counter", 100, 0., 1.);
120  fHists.push_back(fhChi2[i]);
121  fhRadPos[i] = new TH1D(string("fhRadPos"+ss).c_str(), "fhRadPos;Radial position [cm];Counter", 150, 0., 150.);
122  fHists.push_back(fhRadPos[i]);
123  fhAaxisVsMom[i] = new TH2D(string("fhAaxisVsMom"+ss).c_str(), "fhAaxisVsMom;P [GeV/c];A axis [cm]",30, 0, 15, 50, 0, 10);
124  fHists.push_back(fhAaxisVsMom[i]);
125  fhBaxisVsMom[i] = new TH2D(string("fhBAxisVsMom"+ss).c_str(), "fhBAxisVsMom;P [GeV/c];B axis [cm]",30, 0, 15, 50, 0, 10);
126  fHists.push_back(fhBaxisVsMom[i]);
127  fhPhiVsRadAng[i] = new TH2D(string("fhPhiVsRadAng"+ss).c_str(), "fhPhiVsRadAng;Phi [rad];Radial angle [rad]", 50, -2. ,2.,50, 0. , 6.3);
128  fHists.push_back(fhPhiVsRadAng[i]);
129  // ANN outputs
130  fhAnnOutput[i] = new TH1D(string("fhAnnOutput"+ss).c_str(),"ANN output;ANN output;Counter",100, -1.2, 1.2);
131  fHists.push_back(fhAnnOutput[i]);
132  fhCumProb[i] = new TH1D(string("fhCumProb"+ss).c_str(),"ANN output;ANN output;Cumulative probability",100, -1.2, 1.2);
133  fHists.push_back(fhCumProb[i]);
134  }
135 }
136 
138 {
139 
140 }
141 
143 {
// Resolve every input branch this task reads through `ioman` and store the
// TClonesArray pointers in the member fields. A missing manager or branch is
// reported through Fatal(); the only return statement yields kSUCCESS.
// NOTE(review): `ioman` is used below but its declaration is not part of this
// listing; presumably it is FairRootManager::Instance() - confirm.
144  cout << "InitStatus CbmRichTrainAnnElectrons::Init()"<<endl;
145 
147  if (NULL == ioman) { Fatal("CbmRichTrainAnnElectrons::Init","RootManager not instantised!");}
148 
149  fRichHits = (TClonesArray*) ioman->GetObject("RichHit");
150  if ( NULL == fRichHits) { Fatal("CbmRichTrainAnnElectrons::Init","No RichHit array!");}
151 
152  fRichRings = (TClonesArray*) ioman->GetObject("RichRing");
153  if ( NULL == fRichRings) { Fatal("CbmRichTrainAnnElectrons::Init","No RichRing array!");}
154 
155  fRichPoints = (TClonesArray*) ioman->GetObject("RichPoint");
156  if ( NULL == fRichPoints) { Fatal("CbmRichTrainAnnElectrons::Init","No RichPoint array!");}
157 
158  fMCTracks = (TClonesArray*) ioman->GetObject("MCTrack");
159  if ( NULL == fMCTracks) { Fatal("CbmRichTrainAnnElectrons::Init","No MCTrack array!");}
160 
161  fRichRingMatches = (TClonesArray*) ioman->GetObject("RichRingMatch");
162  if ( NULL == fRichRingMatches) { Fatal("CbmRichTrainAnnElectrons::Init","No RichRingMatch array!");}
163 
164  fRichProj = (TClonesArray*) ioman->GetObject("RichProjection");
165  if ( NULL == fRichProj) { Fatal("CbmRichTrainAnnElectrons::Init","No RichProjection array!");}
166 
167  fStsTrackMatches = (TClonesArray*) ioman->GetObject("StsTrackMatch");
168  if ( NULL == fStsTrackMatches) { Fatal("CbmRichTrainAnnElectrons::Init","No track match array!");}
169 
170  fGlobalTracks = (TClonesArray*) ioman->GetObject("GlobalTrack");
171  if ( NULL == fGlobalTracks) { Fatal("CbmRichTrainAnnElectrons::Init","No global track array!");}
172 
173  fStsTracks = (TClonesArray*) ioman->GetObject("StsTrack");
174  if ( NULL == fStsTracks) { Fatal("CbmRichTrainAnnElectrons::Init","No STSTrack array!");}
175 
176  return kSUCCESS;
177 }
178 
180  Option_t* option)
181 {
182  cout << endl <<"-I- CbmRichTrainAnnElectrons, event " << fEventNum << endl;
183  DiffElandPi();
184  fEventNum++;
185  cout <<"Nof Electrons = " << fRElIdParams[0].size() << endl;
186  cout <<"Nof Pions = " << fRElIdParams[1].size() << endl;
187 }
188 
190 {
// Walk all global tracks that have both an STS track and a RICH ring. Record
// ring parameters into fRElIdParams[0] for electrons (|pdg| 11, motherId == -1,
// presumably primaries) whose ring is correctly found and belongs to the same
// MC track as the STS track. Record them into fRElIdParams[1] for charged
// pions (|pdg| 211) with an MC-matched ring. The per-species control
// histograms are filled along the way.
191  Int_t nGlTracks = fGlobalTracks->GetEntriesFast();
192 
193  for (Int_t iTrack=0; iTrack < nGlTracks; iTrack++) {
194  CbmGlobalTrack* gTrack = (CbmGlobalTrack*)fGlobalTracks->At(iTrack);
195  if (NULL == gTrack) continue;
196  Int_t stsIndex = gTrack->GetStsTrackIndex();
197  if (stsIndex == -1) continue;
198  CbmStsTrack* stsTrack = (CbmStsTrack*)fStsTracks->At(stsIndex);
199  if (NULL == stsTrack) continue;
200  CbmTrackMatch* stsTrackMatch = (CbmTrackMatch*)fStsTrackMatches->At(stsIndex);
201  if (NULL == stsTrackMatch) continue;
202  Int_t mcIdSts = stsTrackMatch->GetMCTrackId();
203 
204  Int_t richIndex = gTrack->GetRichRingIndex();
205  if (richIndex == -1) continue;
206  CbmRichRing* ring = (CbmRichRing*)fRichRings->At(richIndex);
207  if (NULL == ring) continue;
208  CbmTrackMatch* richRingMatch = (CbmTrackMatch*)fRichRingMatches->At(richIndex);
209  if (NULL == richRingMatch) continue;
210  Int_t mcIdRich = richRingMatch->GetMCTrackId();
211 
// NOTE(review): mcIdSts is not checked for -1 before At(). TObjArray::At
// returns NULL for an out-of-range index (after logging an error), so the NULL
// check below skips such tracks, but noisily.
212  CbmMCTrack* track = (CbmMCTrack*) fMCTracks->At(mcIdSts);
213  if (NULL == track) continue;
214  Int_t pdg = TMath::Abs(track->GetPdgCode());
215  Int_t motherId = track->GetMotherId();
216  Double_t momentum = track->GetP();
217 
// NOTE(review): axisACor / axisBCor are only referenced by commented-out code
// below and are currently unused.
218  Double_t axisACor = ring->GetAaxisCor();
219  Double_t axisBCor= ring->GetBaxisCor();
220 
// Fraction of true hits among all hits matched to the ring. Rings with fewer
// than 3 hits keep lPercTrue = 0, so they never pass the default fQuota of 0.6.
221  Int_t lFoundHits = richRingMatch->GetNofTrueHits() + richRingMatch->GetNofWrongHits()
222  + richRingMatch->GetNofFakeHits();
223  Double_t lPercTrue = 0;
224  if (lFoundHits >= 3){
225  lPercTrue = (Double_t)richRingMatch->GetNofTrueHits() / (Double_t)lFoundHits;
226  }
227  Bool_t isTrueFound = (lPercTrue > fQuota);
228 
// Fill the per-ring feature record `p` used as the ANN training sample.
// NOTE(review): the declaration of `p` is not part of this listing.
// GetChi2()/GetNDF() is not guarded against GetNDF() == 0 - confirm the NDF is
// always positive for stored rings.
230  p.fAaxis = ring->GetAaxis();
231  p.fBaxis = ring->GetBaxis();
232  p.fPhi = ring->GetPhi();
233  p.fRadAngle = ring->GetRadialAngle();
234  p.fChi2 = ring->GetChi2()/ring->GetNDF();
235  p.fRadPos = ring->GetRadialPosition();
236  p.fNofHits = ring->GetNofHits();
237  p.fDistance = ring->GetDistance();
238  p.fMomentum = momentum;
239 
240  // electrons
// Correctly found electron ring matched to the same MC track as the STS track:
// fill the index-0 histograms and keep it as an electron sample.
241  if (pdg == 11 && motherId == -1 && isTrueFound &&
242  mcIdSts == mcIdRich && mcIdRich != -1){
243  fhAaxis[0]->Fill(p.fAaxis);
244  fhBaxis[0]->Fill(p.fBaxis);
245  // fhAaxisCor[0]->Fill(axisACor);
246  // fhBaxisCor[0]->Fill(axisBCor);
247  fhDistTrueMatch[0]->Fill(p.fDistance);
248  fhNofHits[0]->Fill(p.fNofHits);
249  fhChi2[0]->Fill(p.fChi2);
250  fhRadPos[0]->Fill(p.fRadPos);
251  fhAaxisVsMom[0]->Fill(momentum, p.fAaxis);
252  fhBaxisVsMom[0]->Fill(momentum, p.fBaxis);
253  fhPhiVsRadAng[0]->Fill(p.fPhi, p.fRadAngle);
254  fRElIdParams[0].push_back(p);
255  }
256 
// Same electron selection, but the ring belongs to a different MC track than
// the STS track: only the mismatch distance is histogrammed.
257  if (pdg == 11 && motherId == -1 && isTrueFound &&
258  mcIdSts != mcIdRich && mcIdRich != -1){
259  fhDistMisMatch[0]->Fill(p.fDistance);
260  }
261 
262 
263  // pions
// Every charged pion with an MC-matched ring is kept as a pion sample. True
// matches and mismatches are separated only for the distance and momentum plots.
264  if ( pdg == 211 && mcIdRich != -1){
265  fhAaxis[1]->Fill(p.fAaxis);
266  fhBaxis[1]->Fill(p.fBaxis);
267  // fhAaxisCor[1]->Fill(axisACor);
268  // fhBaxisCor[1]->Fill(axisBCor);
269  if (mcIdRich == mcIdSts) {
270  fhDistTrueMatch[1]->Fill(p.fDistance);
271  fhAaxisVsMom[1]->Fill(momentum, p.fAaxis);
272  fhBaxisVsMom[1]->Fill(momentum, p.fBaxis);
273  } else {
274  fhDistMisMatch[1]->Fill(p.fDistance);
275  }
276  fhNofHits[1]->Fill(p.fNofHits);
277  fhChi2[1]->Fill(p.fChi2);
278  fhRadPos[1]->Fill(p.fRadPos);
279  fhPhiVsRadAng[1]->Fill(p.fPhi, p.fRadAngle);
280 
281  fRElIdParams[1].push_back(p);
282  }
283  }// global tracks
284 }
285 
287 {
288  TTree *simu = new TTree ("MonteCarlo","MontecarloData");
289  Double_t x[9];
290  Double_t xOut;
291 
292  simu->Branch("x0", &x[0],"x0/D");
293  simu->Branch("x1", &x[1],"x1/D");
294  simu->Branch("x2", &x[2],"x2/D");
295  simu->Branch("x3", &x[3],"x3/D");
296  simu->Branch("x4", &x[4],"x4/D");
297  simu->Branch("x5", &x[5],"x5/D");
298  simu->Branch("x6", &x[6],"x6/D");
299  simu->Branch("x7", &x[7],"x7/D");
300  simu->Branch("x8", &x[8],"x8/D");
301  simu->Branch("xOut", &xOut,"xOut/D");
302 
303  for (int j = 0; j < 2; j++){
304  for (int i = 0; i < fRElIdParams[j].size(); i++){
305  x[0] = fRElIdParams[j][i].fAaxis / 10.;
306  x[1] = fRElIdParams[j][i].fBaxis / 10.;
307  x[2] = (fRElIdParams[j][i].fPhi + 1.57) / 3.14;
308  x[3] = fRElIdParams[j][i].fRadAngle / 6.28;
309  x[4] = fRElIdParams[j][i].fChi2 / 1.2;
310  x[5] = fRElIdParams[j][i].fRadPos / 110.;
311  x[6] = fRElIdParams[j][i].fNofHits / 45.;
312  x[7] = fRElIdParams[j][i].fDistance / 5.;
313  x[8] = fRElIdParams[j][i].fMomentum / 15.;
314 
315  for (int k = 0; k < 9; k++){
316  if (x[k] < 0.0) x[k] = 0.0;
317  if (x[k] > 1.0) x[k] = 1.0;
318  }
319 
320  if (j == 0) xOut = 1.;
321  if (j == 1) xOut = -1.;
322  simu->Fill();
323  if (i >= fMaxNofTrainSamples) break;
324  }
325  }
326 
327  TMultiLayerPerceptron network("x0,x1,x2,x3,x4,x5,x6,x7,x8:18:xOut",simu,"Entry$+1");
328  //network.LoadWeights("");
329  network.Train(300,"text,update=10");
330  network.DumpWeights("rich_elid_ann_weights.txt");
331  //network.Export();
332 
333  Double_t params[9];
334 
335  fNofPiLikeEl = 0;
336  fNofElLikePi = 0;
337 
338  for (int j = 0; j < 2; j++){
339  for (int i = 0; i < fRElIdParams[j].size(); i++){
340  params[0] = fRElIdParams[j][i].fAaxis / 10.;
341  params[1] = fRElIdParams[j][i].fBaxis / 10.;
342  params[2] = (fRElIdParams[j][i].fPhi + 1.57) / 3.14;
343  params[3] = fRElIdParams[j][i].fRadAngle / 6.28;
344  params[4] = fRElIdParams[j][i].fChi2 / 1.2;
345  params[5] = fRElIdParams[j][i].fRadPos / 110.;
346  params[6] = fRElIdParams[j][i].fNofHits / 45.;
347  params[7] = fRElIdParams[j][i].fDistance / 5.;
348  params[8] = fRElIdParams[j][i].fMomentum / 15.;
349 
350  for (int k = 0; k < 9; k++){
351  if (params[k] < 0.0) params[k] = 0.0;
352  if (params[k] > 1.0) params[k] = 1.0;
353  }
354 
355  Double_t netEval = network.Evaluate(0,params);
356 
357  //if (netEval > maxEval) netEval = maxEval - 0.01;
358  // if (netEval < minEval) netEval = minEval + 0.01;
359 
360  fhAnnOutput[j]->Fill(netEval);
361  if (netEval >= fAnnCut && j == 1) fNofPiLikeEl++;
362  if (netEval < fAnnCut && j == 0) fNofElLikePi++;
363  }
364  }
365 }
366 
368 {
// Print the identification summary, build the cumulative ANN-output
// distributions from fhAnnOutput, and draw the control histograms on canvases.
// NOTE(review): the two ratios printed below divide by counts that can be zero
// (no pion passing the cut, or no electrons collected), giving inf/nan output.
369  cout <<"nof electrons = " << fRElIdParams[0].size() << endl;
370  cout <<"nof pions = " << fRElIdParams[1].size() << endl;
371  cout <<"Pions like electrons = " << fNofPiLikeEl << ", pi supp = " << (Double_t) fRElIdParams[1].size() / fNofPiLikeEl << endl;
372  cout <<"Electrons like pions = " << fNofElLikePi << ", el lost eff = " << 100.* (Double_t)fNofElLikePi / fRElIdParams[0].size()<< endl;
373  cout <<"ANN cut = " << fAnnCut << endl;
374 
// fhCumProb[1]: fraction of pion entries in bins 1..i. fhCumProb[0]: one minus
// the fraction of electron entries in bins 1..i. Under/overflow entries count
// in GetEntries() but are never summed.
// NOTE(review): an empty fhAnnOutput histogram makes nofFake/nofTrue zero and
// the resulting bin contents nan.
375  Double_t cumProbFake = 0.;
376  Double_t cumProbTrue = 0.;
377  Int_t nofFake = (Int_t)fhAnnOutput[1]->GetEntries();
378  Int_t nofTrue = (Int_t)fhAnnOutput[0]->GetEntries();
379  for (Int_t i = 1; i <= fhAnnOutput[1]->GetNbinsX(); i++){
380  cumProbFake += fhAnnOutput[1]->GetBinContent(i);
381  fhCumProb[1]->SetBinContent(i, (Double_t)cumProbFake/nofFake);
382 
383  cumProbTrue += fhAnnOutput[0]->GetBinContent(i);
384  fhCumProb[0]->SetBinContent(i, 1. - (Double_t)cumProbTrue / nofTrue);
385  }
386 
388  TCanvas* c1 = new TCanvas("ann_electrons_ann_output", "ann_electrons_ann_output", 500, 500);
389  DrawH1(list_of(fhAnnOutput[0])(fhAnnOutput[1]), list_of("e^{#pm}")("#pi^{#pm}"),
390  kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
391 
392  TCanvas* c2 = new TCanvas("ann_electrons_cum_prob", "ann_electrons_cum_prob", 500, 500);
393  DrawH1(list_of(fhCumProb[0])(fhCumProb[1]), list_of("e^{#pm}")("#pi^{#pm}"),
394  kLinear, kLinear, true, 0.8, 0.8, 0.99, 0.99);
395 
// Ring A and B axes, electrons vs pions.
396  int c = 1;
397  TCanvas* c3 = new TCanvas("ann_electrons_params_ab", "ann_electrons_params_ab", 1200, 600);
398  c3->Divide(2, 1);
399  c3->cd(c++);
400  DrawH1(list_of(fhAaxis[0])(fhAaxis[1]), list_of("e^{#pm}")("#pi^{#pm}"), kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
401  c3->cd(c++);
402  DrawH1(list_of(fhBaxis[0])(fhBaxis[1]), list_of("e^{#pm}")("#pi^{#pm}"), kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
403 // c3->cd(c++);
404 // DrawH1(list_of(fhAaxisCor[0])(fhAaxisCor[1]), list_of("e^{#pm}")("#pi^{#pm}"), kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
405 // c3->cd(c++);
406 // DrawH1(list_of(fhBaxisCor[0])(fhBaxisCor[1]), list_of("e^{#pm}")("#pi^{#pm}"), kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
407 
408  c = 1;
409  TCanvas* c3_2 = new TCanvas("ann_electrons_params_1", "ann_electrons_params_1", 1500, 600);
410  c3_2->Divide(3, 1);
411  c3_2->cd(c++);
412  //fhAaxisVsMom[0]->SetLineColor(kRed);
413  //fhAaxisVsMom[1]->SetLineColor(kBlue);
415  //DrawH2(fhAaxisVsMom[0], kLinear, kLinear, kLinear, "same");
416  c3_2->cd(c++);
417  //fhBaxisVsMom[0]->SetLineColor(kRed);
418  //fhBaxisVsMom[1]->SetLineColor(kBlue);
419  //DrawH2(fhBaxisVsMom[1], kLinear, kLinear, kLinear);
421 
// NOTE(review): the draw call whose arguments continue after the next line
// (presumably the true/mis-match ring-track distance plot) starts on a line
// that is not part of this listing.
422  c3_2->cd(c++);
424  list_of("e^{#pm} true match")("e^{#pm} mis match")("#pi^{#pm} true match")("#pi^{#pm} mis match"),
425  kLinear, kLog, true, 0.7, 0.7, 0.99, 0.99);
426 
// Hit count, chi2 and radial position, electrons vs pions.
427  c = 1;
428  TCanvas* c3_1 = new TCanvas("ann_electrons_params_2", "ann_electrons_params_2", 1500, 600);
429  c3_1->Divide(3, 1);
430  c3_1->cd(c++);
431  DrawH1(list_of(fhNofHits[0])(fhNofHits[1]), list_of("e^{#pm}")("#pi^{#pm}"), kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
432  c3_1->cd(c++);
433  DrawH1(list_of(fhChi2[0])(fhChi2[1]), list_of("e^{#pm}")("#pi^{#pm}"), kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
434  c3_1->cd(c++);
435  DrawH1(list_of(fhRadPos[0])(fhRadPos[1]), list_of("e^{#pm}")("#pi^{#pm}"), kLinear, kLog, true, 0.8, 0.8, 0.99, 0.99);
436 
// 2D control plots on a 2x3 grid: A axis vs momentum, B axis vs momentum, and
// phi vs radial angle. Each row shows the electron histogram then the pion one.
437  c = 1;
438  TCanvas* c4 = new TCanvas("ann_electrons_params_2d", "ann_electrons_params_2d", 600, 900);
439  c4->Divide(2, 3);
440  c4->cd(c++);
441  DrawH2(fhAaxisVsMom[0]);
442  c4->cd(c++);
443  DrawH2(fhAaxisVsMom[1]);
444  c4->cd(c++);
445  DrawH2(fhBaxisVsMom[0]);
446  c4->cd(c++);
447  DrawH2(fhBaxisVsMom[1]);
448  c4->cd(c++);
449  DrawH2(fhPhiVsRadAng[0]);
450  c4->cd(c++);
451  DrawH2(fhPhiVsRadAng[1]);
452 
453 // TCanvas* c5 = new TCanvas("ann_electrons_b_vs_mom", "ann_electrons_b_vs_mom", 600, 600);
454 // fhBaxisVsMom[0]->Add(fhBaxisVsMom[1]);
455 // DrawH2(fhBaxisVsMom[0]);
456 //
457 // TCanvas* c6 = new TCanvas("ann_electrons_a_vs_mom", "ann_electrons_a_vs_mom", 600, 600);
458 // fhAaxisVsMom[0]->Add(fhAaxisVsMom[1]);
459 // DrawH2(fhAaxisVsMom[0]);
460 
461 }
462 
464 {
465  TrainAndTestAnn();
466  Draw();
467 
468  for (int i = 0; i < fHists.size(); i++){
469  fHists[i]->Scale(1./fHists[i]->Integral());
470  fHists[i]->Write();
471  }
472 }
473