26 #include <CL/sycl.hpp>
28 namespace Acts::Sycl {
// Kernel-name types for the SYCL `parallel_for` launches in this file.
// They are only declared here; their sole use in this file is as the
// template argument that gives each device kernel a unique name.

// Stage 1: middle-top / middle-bottom duplet search.
class duplet_search_top_kernel;
class duplet_search_bottom_kernel;

// Stage 2: copy per-middle duplet indices into flat, prefix-sum-indexed arrays.
class ind_copy_top_kernel;
class ind_copy_bottom_kernel;

// Stage 3: compute the linear circle-equation parameters for each duplet.
class transform_coord_top_kernel;
class transform_coord_bottom_kernel;

// Stage 4: triplet search, then the seed filter over fixed bottom/middle pairs.
class triplet_search_kernel;
class filter_2sp_fixed_kernel;
43 const std::vector<detail::DeviceSpacePoint>& bottomSPs,
44 const std::vector<detail::DeviceSpacePoint>& middleSPs,
45 const std::vector<detail::DeviceSpacePoint>& topSPs,
46 std::vector<std::vector<detail::SeedData>>& seeds) {
59 std::vector<uint32_t> sumBotCompUptoMid(M + 1, 0);
60 std::vector<uint32_t> sumTopCompUptoMid(M + 1, 0);
61 std::vector<uint32_t> sumBotTopCombined(M + 1, 0);
67 std::vector<uint32_t> indMidBotComp;
68 std::vector<uint32_t> indMidTopComp;
71 uint64_t edgesBottom = 0;
72 uint64_t edgesTop = 0;
77 uint64_t edgesComb = 0;
81 uint64_t globalBufferSize =
82 q->get_device().get_info<cl::sycl::info::device::global_mem_size>();
83 uint64_t maxWorkGroupSize =
84 q->get_device().get_info<cl::sycl::info::device::max_work_group_size>();
88 cl::sycl::malloc_device<detail::DeviceSpacePoint>(B, *q);
90 cl::sycl::malloc_device<detail::DeviceSpacePoint>(
M, *q);
92 cl::sycl::malloc_device<detail::DeviceSpacePoint>(
T, *q);
99 uint32_t* deviceTmpIndBot = cl::sycl::malloc_device<uint32_t>(M * B, *q);
100 uint32_t* deviceTmpIndTop = cl::sycl::malloc_device<uint32_t>(M *
T, *q);
102 q->memcpy(deviceBottomSPs, bottomSPs.data(),
104 q->memcpy(deviceMiddleSPs, middleSPs.data(),
106 q->memcpy(deviceTopSPs, topSPs.data(),
114 cl::sycl::nd_range<2> bottomDupletNDRange =
116 cl::sycl::nd_range<2> topDupletNDRange =
120 std::vector<uint32_t> countBotDuplets(M, 0);
121 std::vector<uint32_t> countTopDuplets(M, 0);
128 sycl::buffer<uint32_t> countBotBuf(countBotDuplets.data(),
M);
129 sycl::buffer<uint32_t> countTopBuf(countTopDuplets.data(),
M);
131 auto countBotDupletsAcc =
132 sycl::ONEAPI::atomic_accessor<
uint32_t, 1,
133 sycl::ONEAPI::memory_order::relaxed,
134 sycl::ONEAPI::memory_scope::device>(
136 h.parallel_for<duplet_search_bottom_kernel>(
137 bottomDupletNDRange, [=](cl::sycl::nd_item<2> item) {
138 const auto mid = item.get_global_id(0);
139 const auto bot = item.get_global_id(1);
144 if (
mid < M && bot < B) {
145 const auto midSP = deviceMiddleSPs[
mid];
146 const auto botSP = deviceBottomSPs[bot];
148 const auto deltaR = midSP.r - botSP.r;
149 const auto cotTheta = (midSP.z - botSP.z) /
deltaR;
150 const auto zOrigin = midSP.z - midSP.r * cotTheta;
158 const auto ind = countBotDupletsAcc[
mid].fetch_add(1);
159 deviceTmpIndBot[
mid * B + ind] = bot;
166 auto countTopDupletsAcc =
167 sycl::ONEAPI::atomic_accessor<
uint32_t, 1,
168 sycl::ONEAPI::memory_order::relaxed,
169 sycl::ONEAPI::memory_scope::device>(
171 h.parallel_for<duplet_search_top_kernel>(
172 topDupletNDRange, [=](cl::sycl::nd_item<2> item) {
173 const auto mid = item.get_global_id(0);
174 const auto top = item.get_global_id(1);
177 if (
mid < M && top < T) {
178 const auto midSP = deviceMiddleSPs[
mid];
179 const auto topSP = deviceTopSPs[top];
181 const auto deltaR = topSP.r - midSP.r;
182 const auto cotTheta = (topSP.z - midSP.z) /
deltaR;
183 const auto zOrigin = midSP.z - midSP.r * cotTheta;
191 const auto ind = countTopDupletsAcc[
mid].fetch_add(1);
192 deviceTmpIndTop[
mid * T + ind] = top;
209 for (
uint32_t i = 1; i < M + 1; ++i) {
210 sumBotCompUptoMid[i] +=
211 sumBotCompUptoMid[i - 1] + countBotDuplets[i - 1];
212 sumTopCompUptoMid[i] +=
213 sumTopCompUptoMid[i - 1] + countTopDuplets[i - 1];
214 sumBotTopCombined[i] += sumBotTopCombined[i - 1] +
215 countTopDuplets[i - 1] * countBotDuplets[i - 1];
218 edgesBottom = sumBotCompUptoMid[
M];
219 edgesTop = sumTopCompUptoMid[
M];
220 edgesComb = sumBotTopCombined[
M];
222 indMidBotComp.reserve(edgesBottom);
223 indMidTopComp.reserve(edgesTop);
227 std::fill_n(std::back_inserter(indMidBotComp), countBotDuplets[
mid],
229 std::fill_n(std::back_inserter(indMidTopComp), countTopDuplets[mid],
234 if (edgesBottom > 0 && edgesTop > 0) {
237 cl::sycl::nd_range<1> edgesBotNdRange =
241 cl::sycl::nd_range<1> edgesTopNdRange =
310 sycl::buffer<uint32_t> numTopDupletsBuf(countTopDuplets.data(),
M);
316 cl::sycl::malloc_device<uint32_t>(edgesBottom, *q);
317 uint32_t* deviceIndTop = cl::sycl::malloc_device<uint32_t>(edgesTop, *q);
323 cl::sycl::malloc_device<uint32_t>(edgesBottom, *q);
325 cl::sycl::malloc_device<uint32_t>(edgesTop, *q);
328 uint32_t* deviceSumBot = cl::sycl::malloc_device<uint32_t>(M + 1, *q);
329 uint32_t* deviceSumTop = cl::sycl::malloc_device<uint32_t>(M + 1, *q);
333 uint32_t* deviceSumComb = cl::sycl::malloc_device<uint32_t>(M + 1, *q);
337 cl::sycl::malloc_device<detail::DeviceLinEqCircle>(edgesBottom, *q);
339 cl::sycl::malloc_device<detail::DeviceLinEqCircle>(edgesTop, *q);
341 q->memcpy(deviceMidIndPerBot, indMidBotComp.data(),
343 q->memcpy(deviceMidIndPerTop, indMidTopComp.data(),
345 q->memcpy(deviceSumBot, sumBotCompUptoMid.data(),
347 q->memcpy(deviceSumTop, sumTopCompUptoMid.data(),
349 q->memcpy(deviceSumComb, sumBotTopCombined.data(),
357 h.parallel_for<ind_copy_bottom_kernel>(
358 edgesBotNdRange, [=](cl::sycl::nd_item<1> item) {
359 auto idx = item.get_global_linear_id();
360 if (idx < edgesBottom) {
361 auto mid = deviceMidIndPerBot[idx];
362 auto ind = deviceTmpIndBot[
mid * B + idx - deviceSumBot[
mid]];
363 deviceIndBot[idx] = ind;
369 h.parallel_for<ind_copy_top_kernel>(
370 edgesTopNdRange, [=](cl::sycl::nd_item<1> item) {
371 auto idx = item.get_global_linear_id();
372 if (idx < edgesTop) {
373 auto mid = deviceMidIndPerTop[idx];
374 auto ind = deviceTmpIndTop[
mid * T + idx - deviceSumTop[
mid]];
375 deviceIndTop[idx] = ind;
392 h.parallel_for<transform_coord_bottom_kernel>(
393 edgesBotNdRange, [=](cl::sycl::nd_item<1> item) {
394 auto idx = item.get_global_linear_id();
395 if (idx < edgesBottom) {
396 const auto midSP = deviceMiddleSPs[deviceMidIndPerBot[idx]];
397 const auto botSP = deviceBottomSPs[deviceIndBot[idx]];
399 const auto xM = midSP.x;
400 const auto yM = midSP.y;
401 const auto zM = midSP.z;
402 const auto rM = midSP.r;
403 const auto varianceZM = midSP.varZ;
404 const auto varianceRM = midSP.varR;
405 const auto cosPhiM = xM / rM;
406 const auto sinPhiM = yM / rM;
408 const auto deltaX = botSP.x - xM;
409 const auto deltaY = botSP.y - yM;
410 const auto deltaZ = botSP.z - zM;
412 const auto x = deltaX * cosPhiM + deltaY * sinPhiM;
413 const auto y = deltaY * cosPhiM - deltaX * sinPhiM;
414 const auto iDeltaR2 = 1.f / (deltaX * deltaX + deltaY * deltaY);
415 const auto iDeltaR = cl::sycl::sqrt(iDeltaR2);
416 const auto cot_theta = -(deltaZ * iDeltaR);
420 L.
zo = zM - rM * cot_theta;
424 L.
er = ((varianceZM + botSP.varZ) +
425 (cot_theta * cot_theta) * (varianceRM + botSP.varR)) *
428 deviceLinBot[idx] =
L;
435 h.parallel_for<transform_coord_top_kernel>(
436 edgesTopNdRange, [=](cl::sycl::nd_item<1> item) {
437 auto idx = item.get_global_linear_id();
438 if (idx < edgesTop) {
439 const auto midSP = deviceMiddleSPs[deviceMidIndPerTop[idx]];
440 const auto topSP = deviceTopSPs[deviceIndTop[idx]];
442 const auto xM = midSP.x;
443 const auto yM = midSP.y;
444 const auto zM = midSP.z;
445 const auto rM = midSP.r;
446 const auto varianceZM = midSP.varZ;
447 const auto varianceRM = midSP.varR;
448 const auto cosPhiM = xM / rM;
449 const auto sinPhiM = yM / rM;
451 const auto deltaX = topSP.x - xM;
452 const auto deltaY = topSP.y - yM;
453 const auto deltaZ = topSP.z - zM;
455 const auto x = deltaX * cosPhiM + deltaY * sinPhiM;
456 const auto y = deltaY * cosPhiM - deltaX * sinPhiM;
457 const auto iDeltaR2 = 1.f / (deltaX * deltaX + deltaY * deltaY);
458 const auto iDeltaR = cl::sycl::sqrt(iDeltaR2);
459 const auto cot_theta = deltaZ * iDeltaR;
463 L.
zo = zM - rM * cot_theta;
467 L.
er = ((varianceZM + topSP.varZ) +
468 (cot_theta * cot_theta) * (varianceRM + topSP.varR)) *
471 deviceLinTop[idx] =
L;
552 const auto maxMemoryAllocation =
559 cl::sycl::malloc_device<detail::DeviceTriplet>(maxMemoryAllocation,
566 cl::sycl::malloc_device<detail::SeedData>(maxMemoryAllocation, *q);
573 std::vector<uint32_t> deviceCountTriplets(edgesBottom, 0);
579 for (
uint32_t firstMiddle = 0; firstMiddle <
M;
580 firstMiddle = lastMiddle) {
583 while (lastMiddle + 1 <= M && (sumBotTopCombined[lastMiddle + 1] -
584 sumBotTopCombined[firstMiddle] <
585 maxMemoryAllocation)) {
589 const auto numTripletSearchThreads =
590 sumBotTopCombined[lastMiddle] - sumBotTopCombined[firstMiddle];
592 if (numTripletSearchThreads == 0)
596 deviceCountTriplets.resize(edgesBottom, 0);
598 const auto numTripletFilterThreads =
599 sumBotCompUptoMid[lastMiddle] - sumBotCompUptoMid[firstMiddle];
601 const auto sumCombUptoFirstMiddle = sumBotTopCombined[firstMiddle];
605 cl::sycl::nd_range<1> tripletSearchNDRange =
608 cl::sycl::nd_range<1> tripletFilterNDRange =
611 sycl::buffer<uint32_t> countTripletsBuf(deviceCountTriplets.data(),
615 h.depends_on({linB, linT});
616 auto countTripletsAcc =
617 sycl::ONEAPI::atomic_accessor<
uint32_t, 1,
618 sycl::ONEAPI::memory_order::relaxed,
619 sycl::ONEAPI::memory_scope::device>(
620 countTripletsBuf,
h);
621 auto numTopDupletsAcc = numTopDupletsBuf.get_access<
623 h.parallel_for<triplet_search_kernel>(
624 tripletSearchNDRange, [=](cl::sycl::nd_item<1> item) {
625 const uint32_t idx = item.get_global_linear_id();
626 if (idx < numTripletSearchThreads) {
629 auto L = firstMiddle;
636 if (idx + sumCombUptoFirstMiddle < deviceSumComb[
mid]) {
644 const auto numT = numTopDupletsAcc[
mid];
645 const auto threadIdxForMiddleSP =
646 (idx - deviceSumComb[
mid] + sumCombUptoFirstMiddle);
694 deviceSumBot[
mid] + (threadIdxForMiddleSP / numT);
696 deviceSumTop[
mid] + (threadIdxForMiddleSP % numT);
698 const auto linBotEq = deviceLinBot[ib];
699 const auto linTopEq = deviceLinTop[
it];
700 const auto midSP = deviceMiddleSPs[
mid];
702 const auto Vb = linBotEq.v;
703 const auto Ub = linBotEq.u;
704 const auto Erb = linBotEq.er;
705 const auto cotThetab = linBotEq.cotTheta;
706 const auto iDeltaRb = linBotEq.iDeltaR;
708 const auto Vt = linTopEq.v;
709 const auto Ut = linTopEq.u;
710 const auto Ert = linTopEq.er;
711 const auto cotThetat = linTopEq.cotTheta;
712 const auto iDeltaRt = linTopEq.iDeltaR;
714 const auto rM = midSP.r;
715 const auto varianceRM = midSP.varR;
716 const auto varianceZM = midSP.varZ;
718 auto iSinTheta2 = (1.f + cotThetab * cotThetab);
719 auto scatteringInRegion2 =
725 2.f * (cotThetab * cotThetat * varianceRM + varianceZM) *
727 auto deltaCotTheta = cotThetab - cotThetat;
728 auto deltaCotTheta2 = deltaCotTheta * deltaCotTheta;
731 auto error = cl::sycl::sqrt(error2);
732 auto dCotThetaMinusError2 =
733 deltaCotTheta2 + error2 - 2.f * deltaCotTheta * error;
736 if ((!(deltaCotTheta2 - error2 > 0.f) ||
737 !(dCotThetaMinusError2 > scatteringInRegion2)) &&
739 auto A = (Vt - Vb) / dU;
740 auto S2 = 1.f + A * A;
741 auto B = Vb - A * Ub;
744 auto iHelixDiameter2 = B2 / S2;
747 auto p2scatter = pT2scatter * iSinTheta2;
751 !((deltaCotTheta2 - error2 > 0.f) &&
752 (dCotThetaMinusError2 >
756 const auto top = deviceIndTop[
it];
759 auto t = countTripletsAcc[ib].fetch_add(1);
779 const auto tripletIdx = deviceSumComb[
mid] -
780 sumCombUptoFirstMiddle +
781 (((idx - deviceSumComb[
mid] +
782 sumCombUptoFirstMiddle) /
788 T.curvature = B / cl::sycl::sqrt(S2);
791 deviceCurvImpact[tripletIdx] =
T;
799 sycl::buffer<uint32_t> countSeedsBuf(&sumSeeds, 1);
801 h.depends_on(tripletKernel);
802 auto countSeedsAcc = sycl::ONEAPI::atomic_accessor<
803 uint32_t, 1, sycl::ONEAPI::memory_order::relaxed,
804 sycl::ONEAPI::memory_scope::device>(countSeedsBuf,
h);
805 auto countTripletsAcc = countTripletsBuf.get_access<
808 auto numTopDupletsAcc = numTopDupletsBuf.get_access<
811 h.parallel_for<filter_2sp_fixed_kernel>(
812 tripletFilterNDRange, [=](cl::sycl::nd_item<1> item) {
813 if (item.get_global_linear_id() < numTripletFilterThreads) {
815 deviceSumBot[firstMiddle] + item.get_global_linear_id();
816 const auto mid = deviceMidIndPerBot[idx];
817 const auto bot = deviceIndBot[idx];
819 const auto tripletBegin =
820 deviceSumComb[
mid] - sumCombUptoFirstMiddle +
821 (idx - deviceSumBot[
mid]) * numTopDupletsAcc[
mid];
822 const auto tripletEnd =
823 tripletBegin + countTripletsAcc[idx];
825 for (
auto i1 = tripletBegin; i1 < tripletEnd; ++i1) {
826 const auto current = deviceCurvImpact[i1];
829 const auto invHelixDiameter = current.curvature;
830 const auto lowerLimitCurv =
833 const auto upperLimitCurv =
836 const auto currentTop_r = deviceTopSPs[top].r;
837 auto weight = -(current.impact *
840 uint32_t compatCounter = 0;
845 float compatibleSeedR[2];
846 for (
auto i2 = tripletBegin;
850 const auto other = deviceCurvImpact[i2];
853 const auto otherTop_r =
854 deviceTopSPs[other.topSPIndex].r;
858 otherCurv >= lowerLimitCurv &&
859 otherCurv <= upperLimitCurv) {
867 if (c == compatCounter) {
868 compatibleSeedR[
c] = otherTop_r;
877 const auto bottomSP = deviceBottomSPs[bot];
878 const auto middleSP = deviceMiddleSPs[
mid];
879 const auto topSP = deviceTopSPs[top];
882 deviceCuts.
seedWeight(bottomSP, middleSP, topSP);
886 const auto i = countSeedsAcc[0].fetch_add(1);
892 deviceSeedArray[i] = D;
902 std::vector<detail::SeedData> hostSeedArray(sumSeeds);
903 auto e0 = q->memcpy(&hostSeedArray[0], deviceSeedArray,
908 auto m = hostSeedArray[
t].middle;
909 seeds[
m].push_back(hostSeedArray[
t]);
925 cl::sycl::free(deviceLinBot, *q);
926 cl::sycl::free(deviceLinTop, *q);
928 cl::sycl::free(deviceIndBot, *q);
929 cl::sycl::free(deviceIndTop, *q);
930 cl::sycl::free(deviceMidIndPerBot, *q);
931 cl::sycl::free(deviceMidIndPerTop, *q);
932 cl::sycl::free(deviceSumBot, *q);
933 cl::sycl::free(deviceSumTop, *q);
934 cl::sycl::free(deviceSumComb, *q);
936 cl::sycl::free(deviceCurvImpact, *q);
937 cl::sycl::free(deviceSeedArray, *q);
940 cl::sycl::free(deviceTmpIndBot, *q);
941 cl::sycl::free(deviceTmpIndTop, *q);
942 cl::sycl::free(deviceBottomSPs, *q);
943 cl::sycl::free(deviceMiddleSPs, *q);
944 cl::sycl::free(deviceTopSPs, *q);
945 }
catch (cl::sycl::exception
const&
e) {
948 ACTS_FATAL(
"Caught synchronous SYCL exception:\n" << e.what())