#!/usr/bin/env python3

# SPDX-License-Identifier: LGPL-3.0-or-later
# Copyright (C) 2023 Chao Peng
'''
    A script to run a 2D (eta, phi) scan using Geant4 kernel.
    Scan results are collected in a CSV format file.
'''

import argparse
import errno
import logging
import os
import sys
from collections import OrderedDict as odict
from io import StringIO

import numpy as np
import pandas as pd
from wurlitzer import pipes

import DDG4
import g4units

# pd.set_option('display.max_rows', 1000)
# the scan loop prints a progress message whenever the running scan-line
# index is a multiple of PROGRESS_STEP (1 = report every line)
PROGRESS_STEP = 1


# a class to simplify the material scan, avoid the use of multiple instances of it
class g4MaterialScanner:
    '''
        Material scanner built on the DDG4 (Geant4) kernel.

        The constructor loads the geometry, sets up the sensitive detectors,
        the physics list, a standalone geantino gun and the
        Geant4MaterialScanner stepping action. Every call to scan() re-aims
        the gun, runs the kernel once and parses the captured printout into
        a pandas DataFrame.
    '''

    def __init__(self, compact, phy_list='QGSP_BERT'):
        '''
            compact: path of the top-level compact (xml) file, passed to Kernel.loadGeometry
            phy_list: name of the physics list passed to Geant4.setupPhysics

            Exits the process with errno.EINVAL if a sensitive detector has a type
            that is not listed in geant4.sensitive_types.
        '''
        kernel = DDG4.Kernel()
        kernel.loadGeometry(compact)
        DDG4.Core.setPrintFormat(str("%-32s %6s %s"))
        self.kernel = kernel

        # use DDG4 Geant4 wrapper to simplify the setup
        geant4 = DDG4.Geant4(kernel)
        geant4.setupCshUI(ui=None)
        for i in geant4.description.detectors():
            o = DDG4.DetElement(i.second.ptr())
            sd = geant4.description.sensitiveDetector(o.name())
            if sd.isValid():
                typ = sd.type()
                if typ in geant4.sensitive_types:
                    geant4.setupDetector(o.name(), geant4.sensitive_types[typ])
                else:
                    # the logger is looked up here; no module-level logger exists in this file
                    logging.getLogger(__name__).error(
                        '+++  %-32s type:%-12s  --> Unknown Sensitive type: %s', o.name(), typ, typ)
                    sys.exit(errno.EINVAL)
        geant4.setupPhysics(phy_list)

        # save the gun for micromanagement later
        self.gun = geant4.setupGun('Scanner',
                                   Standalone=True,
                                   particle='geantino',
                                   energy=20*g4units.GeV,
                                   position='(0, 0, 0)',
                                   direction='(0, 0, 1)',
                                   multiplicity=1,
                                   isotrop=False)

        scan = DDG4.SteppingAction(self.kernel, 'Geant4MaterialScanner/MaterialScan')
        self.kernel.steppingAction().adopt(scan)
        self.kernel.configure()
        self.kernel.initialize()
        # one event per kernel.run(), i.e., one scan line per call of scan()
        self.kernel.NumEvents = 1

    def __del__(self):
        # __init__ may have failed before self.kernel was assigned,
        # in that case there is nothing to terminate
        kernel = getattr(self, 'kernel', None)
        if kernel is not None:
            kernel.terminate()

    def scan(self, position, direction, phy_list='QGSP_BERT'):
        '''
            Run one scan line and return the parsed result.

            position: three values, formatted into the gun position as "(x,y,z)"
            direction: three values, formatted into the gun direction as "(x,y,z)"
            phy_list: not used by this method, kept so existing callers keep working

            Returns the DataFrame produced by parse_scan_output().
        '''
        self.gun.position = '({},{},{})'.format(*position)
        self.gun.direction = '({},{},{})'.format(*direction)
        # the scanner reports through the process-level stdout, so capture it with wurlitzer
        with pipes() as (out, err):
            self.kernel.run()
        return self.parse_scan_output(out.getvalue())

    # A parser function to convert the output from Geant4MaterialScanner to a pandas dataframe
    @staticmethod
    def parse_scan_output(output):
        '''
            Convert the captured text of one scan into a pandas DataFrame.

            Only the text after the first line starting with "MaterialScan" is
            considered. Within it, the table rows are the first consecutive run
            of lines that start with a digit (after removing "|" and blanks).
            The rows are read with pandas.read_csv (whitespace separated,
            names=cols, index_col=0) and all columns in cols[1:] are cast to
            float64.
        '''
        # find material scan lines
        lines = []
        add_line = False

        for l in output.split('\n'):
            if add_line:
                lines.append(l.strip())
            if l.strip().startswith('MaterialScan'):
                add_line = True

        # NOTE: the code below depends on the output from DDG4::Geant4MaterialScanner
        # it is valid on 09/15/2023
        drop_parentheses = str.maketrans('', '', '()')
        scans = []
        first_digit = False
        for l in lines:
            line = l.strip('| ')
            is_row = line[:1].isdigit()
            # skip the table header, i.e., everything before the first row
            if not first_digit and not is_row:
                continue
            first_digit = True
            # the first non-row line after the rows closes the table
            if not is_row:
                break
            # break coordinates for endpoint, which has a format of (x, y, z)
            scans.append(line.translate(drop_parentheses).replace(',', ' '))

        cols = [
                'material', 'Z', 'A', 'density',
                'rad_length', 'int_length', 'thickness', 'path_length',
                'int_X0', 'int_lambda', 'end_x', 'end_y', 'end_z'
                ]

        # raw string for the regex separator, '\s' is not a valid string escape
        dft = pd.read_csv(StringIO('\n'.join(scans)), sep=r'\s+', header=None, index_col=0, names=cols)
        return dft.astype({key: np.float64 for key in cols[1:]})


def args_array(arg, step=1, include_end=True):
    '''
        A helper function to convert a string (<min>[:<max>[:<step>]]) to an array

        arg: the string to convert, fields are separated by ":"
        step: step size used when arg does not provide one
        include_end: if True, <max> itself is part of the result whenever it is
                     reachable from <min> in whole steps

        Returns a numpy array of floats <min>, <min> + step, ... that never goes
        beyond <max>. A single value gives a one-element array, a range that
        points away from the step direction gives an empty array.
        Raises ValueError for a field that is not a number or for a zero step.
    '''
    vals = [float(x.strip()) for x in arg.split(':')]
    # only one value
    if len(vals) < 2:
        return np.array(vals)
    # has step input
    if len(vals) > 2:
        step = vals[2]
    if step == 0:
        raise ValueError('Step size must be non-zero, got "{}".'.format(arg))

    start, stop = vals[0], vals[1]
    # Count the steps explicitly instead of using np.arange(start, stop + step, step):
    # with a non-integer step the latter may return a value beyond <max> because of
    # floating point round-off. The tolerance absorbs that round-off.
    nsteps = (stop - start) / step
    tolerance = 1e-9
    if include_end:
        count = int(np.floor(nsteps + tolerance)) + 1
    else:
        count = int(np.ceil(nsteps - tolerance))
    return start + step * np.arange(max(count, 0))


# Script entry: parse the command line, scan every (eta, phi) direction and
# save the per-material X0 / lambda sums to a CSV file.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
            prog='g4MaterialScan_to_csv',
            description = 'A python script to scan materials of a DD4Hep geometry in geant4 kernel.'
                        + '\n       ' # 7 spaces for "usage: "
                        + 'The scanned results are saved in a CSV file.'
            )
    parser.add_argument(
            '-c', '--compact', dest='compact', required=True,
            help='Top-level xml file of the detector description.'
            )
    parser.add_argument(
            '-o', '--output', default='g4mat_scan.csv',
            help='Path of the output csv file. Support python formatting of args, e.g., g4mat_scan_{eta}_{phi}.csv.'
            )
    parser.add_argument(
            '--start-point', default='0,0,0',
            help='Start point of the scan, use the format \"x,y,z\", unit is cm.'
            )
    parser.add_argument(
            '--eta', default='-4.0:4.0:0.1',
            help='Eta range, in the format of \"<min>[:<max>[:<step>]]\".'
            )
    parser.add_argument(
            '--eta-values', default=None,
            help='a list of eta values, separated by \",\", this option overwrites --eta.'
            )
    parser.add_argument(
            '--phi', default='0:30:1',
            help='Phi angle range, in the format of \"<min>[:<max>[:<step>]]\" (degree).'
            )
    parser.add_argument(
            '--phi-values', default=None,
            help='a list of phi values, separated by \",\", this option overwrites --phi.'
            )
    parser.add_argument(
            '--mat-buffer-size', type=int, default=50,
            help='Maximum number of materials included in the aggregated output.'
            )
    parser.add_argument(
            '--raw-output', action='store_true',
            help='Turn on to save the raw outputs from scan.'
            )
    parser.add_argument(
            '--sep', default='\t',
            help='Seperator for the CSV file.'
            )
    args = parser.parse_args()

    if not os.path.exists(args.compact):
        print('Cannot find compact file {}'.format(args.compact))
        # sys.exit instead of the exit() helper, which is only injected by the site module
        sys.exit(-1)

    start_point = np.array([float(v.strip()) for v in args.start_point.split(',')])
    # explicit value lists take precedence over the range options
    if args.eta_values is not None:
        etas = np.array([float(xval.strip()) for xval in args.eta_values.split(',')])
    else:
        etas = args_array(args.eta)

    if args.phi_values is not None:
        phis = np.array([float(xval.strip()) for xval in args.phi_values.split(',')])
    else:
        phis = args_array(args.phi)
    # sanity check, both arrays are indexed with [0] and [-1] below
    if not len(etas):
        print('No eta values from the input {}, aborted!'.format(args.eta))
        sys.exit(-1)
    if not len(phis):
        print('No phi values from the input {}, aborted!'.format(args.phi))
        sys.exit(-1)

    n_scans = len(etas)*len(phis)
    eta_phi = []
    # material name -> column index in the data buffer, in order of first appearance
    mats_indices = odict()
    # should exist in the scan output as int_{value_type}
    value_types = ['X0', 'lambda']
    # a data buffer for the X0 values of ((eta, phi), materials, (X0, Lambda))
    data = np.zeros(shape=(n_scans, args.mat_buffer_size, len(value_types)))

    # scan over eta and phi
    scanner = g4MaterialScanner(args.compact)
    print('Scanning {:d} eta values in [{:.2f}, {:.2f}] and {:d} phi values in [{:.2f}, {:.2f}] (degree).'\
          .format(len(etas), etas[0], etas[-1], len(phis), phis[0], phis[-1]))
    for i, eta in enumerate(etas):
        for j, phi in enumerate(phis):

            nlines = i*len(phis) + j
            if nlines % PROGRESS_STEP == 0:
                print('Scanned {:d}/{:d} lines.'.format(nlines, n_scans), end='\r', flush=True)

            # scan along (cos(phi), sin(phi), sinh(eta)), phi is given in degree
            phi_rad = phi/180.*np.pi
            direction = (np.cos(phi_rad), np.sin(phi_rad), np.sinh(eta))
            dfa = scanner.scan(start_point, direction)
            # int_{vt} is a running sum along the scan line, the difference of
            # consecutive rows is the contribution of each row; the first row
            # has no predecessor and keeps its integrated value
            for vt in value_types:
                integrated = dfa['int_{}'.format(vt)]
                dfa.loc[:, vt] = integrated.diff(1).fillna(integrated)

            if args.raw_output:
                dfa.to_csv('scan_raw_eta={:.3f}_phi={:.3f}.csv'.format(eta, phi), sep=args.sep, float_format='%g')
            # group by materials
            single_scan = dfa.groupby('material')[value_types].sum()
            # print(single_scan)
            for mat, xvals in single_scan.iterrows():
                if mat not in mats_indices:
                    # the buffer has a fixed number of material slots, new materials beyond it are dropped
                    if len(mats_indices) >= args.mat_buffer_size:
                        print('Number of materials exceeds MAT_BUFFER_SIZE ({:d}), dropped material {}.'.format(args.mat_buffer_size, mat))
                        print('Hint: increase the buffer size with --mat-buffer-size.')
                        continue
                    mats_indices[mat] = len(mats_indices)
                k = mats_indices.get(mat)
                # len(eta_phi) is the row index of the scan line being processed
                data[len(eta_phi), k] = xvals.values
            eta_phi.append((eta, phi))
    print('Scanned {:d}/{:d} lines.'.format(n_scans, n_scans))

    # save results, one table per value type, stacked under a 'value_type' index level
    indices = pd.MultiIndex.from_tuples(eta_phi, names=['eta', 'phi'])
    result = dict()
    for i, vt in enumerate(value_types):
        result[vt] = pd.DataFrame(columns=mats_indices.keys(), index=indices, data=data[:, :len(mats_indices), i])
    # the output path may contain {placeholders} for any parsed command line argument
    out_path = args.output.format(**vars(args))
    pd.concat(result, names=['value_type']).to_csv(out_path, sep=args.sep, float_format='%g')
    print('Scanned results saved to \"{}\"'.format(out_path))
