diff --git a/Snakefile b/Snakefile
index 17fa780..3c7d486 100644
--- a/Snakefile
+++ b/Snakefile
@@ -10,7 +10,7 @@ This includes:
 - the same for partition level tiers
 """
 
-import pathlib, os, json, sys
+import pathlib, os, json, sys, glob
 import scripts.util as ds
 from scripts.util.pars_loading import pars_catalog
 from scripts.util.patterns import get_pattern_tier_raw
@@ -21,6 +21,8 @@ from scripts.util.utils import (
     chan_map_path,
     filelist_path,
     metadata_path,
+    tmp_log_path,
+    pars_path,
 )
 from datetime import datetime
 from collections import OrderedDict
@@ -41,68 +43,59 @@ part = ds.dataset_file(setup, os.path.join(configs, "partitions.json"))
 
 basedir = workflow.basedir
 
-include: "rules/common.smk"
-include: "rules/main.smk"
-
-
-localrules:
-    gen_filelist,
-    autogen_output,
-
-
-ds.pars_key_resolve.write_par_catalog(
-    ["-*-*-*-cal"],
-    os.path.join(pars_path(setup), "pht", "validity.jsonl"),
-    get_pattern_tier_raw(setup),
-    {"cal": ["par_pht"], "lar": ["par_pht"]},
-)
+wildcard_constraints:
+    experiment="\w+",
+    period="p\d{2}",
+    run="r\d{3}",
+    datatype="\w{3}",
+    timestamp="\d{8}T\d{6}Z",
 
+include: "rules/common.smk"
+include: "rules/main.smk"
 include: "rules/tcm.smk"
 include: "rules/dsp.smk"
+include: "rules/psp.smk"
 include: "rules/hit.smk"
 include: "rules/pht.smk"
 include: "rules/evt.smk"
 include: "rules/skm.smk"
 include: "rules/blinding_calibration.smk"
+include: "rules/qc_phy.smk"
+
+
+localrules:
+    gen_filelist,
+    autogen_output,
 
 
 onstart:
     print("Starting workflow")
-    shell(f"rm {pars_path(setup)}/dsp/validity.jsonl || true")
-    shell(f"rm {pars_path(setup)}/hit/validity.jsonl || true")
-    shell(f"rm {pars_path(setup)}/pht/validity.jsonl || true")
-    shell(f"rm {pars_path(setup)}/raw/validity.jsonl || true")
-    ds.pars_key_resolve.write_par_catalog(
-        ["-*-*-*-cal"],
-        os.path.join(pars_path(setup), "raw", "validity.jsonl"),
-        get_pattern_tier_raw(setup),
-        {"cal": ["par_raw"]},
-    )
-    ds.pars_key_resolve.write_par_catalog(
-        ["-*-*-*-cal"],
-        os.path.join(pars_path(setup), "dsp", "validity.jsonl"),
-        get_pattern_tier_raw(setup),
-        {"cal": ["par_dsp"], "lar": ["par_dsp"]},
-    )
+    if os.path.isfile(os.path.join(pars_path(setup), "hit", "validity.jsonl")):
+        os.remove(os.path.join(pars_path(setup), "hit", "validity.jsonl"))
+
     ds.pars_key_resolve.write_par_catalog(
         ["-*-*-*-cal"],
         os.path.join(pars_path(setup), "hit", "validity.jsonl"),
         get_pattern_tier_raw(setup),
         {"cal": ["par_hit"], "lar": ["par_hit"]},
     )
+
+    if os.path.isfile(os.path.join(pars_path(setup), "dsp", "validity.jsonl")):
+        os.remove(os.path.join(pars_path(setup), "dsp", "validity.jsonl"))
     ds.pars_key_resolve.write_par_catalog(
         ["-*-*-*-cal"],
-        os.path.join(pars_path(setup), "pht", "validity.jsonl"),
+        os.path.join(pars_path(setup), "dsp", "validity.jsonl"),
         get_pattern_tier_raw(setup),
-        {"cal": ["par_pht"], "lar": ["par_pht"]},
+        {"cal": ["par_dsp"], "lar": ["par_dsp"]},
     )
 
 
 onsuccess:
     from snakemake.report import auto_report
 
-    rep_dir = f"{log_path(setup)}/report-{datetime.strftime(datetime.utcnow(), '%Y%m%dT%H%M%SZ')}"
+    rep_dir = f"{log_path(setup)}/report-{datetime.strftime(datetime.utcnow() , '%Y%m%dT%H%M%SZ')}"
     pathlib.Path(rep_dir).mkdir(parents=True, exist_ok=True)
     # auto_report(workflow.persistence.dag, f"{rep_dir}/report.html")
     with open(os.path.join(rep_dir, "dag.txt"), "w") as f:
@@ -112,8 +105,32 @@ onsuccess:
         f.writelines(str(workflow.persistence.dag.rule_dot()))
     # shell(f"cat {rep_dir}/rg.txt | dot -Tpdf > {rep_dir}/rg.pdf")
     print("Workflow finished, no error")
-    shell("rm *.gen || true")
-    shell(f"rm {filelist_path(setup)}/* || true")
+
+    # remove .gen files
+    files = glob.glob("*.gen")
+    for file in files:
+        if os.path.isfile(file):
+            os.remove(file)
+
+    # remove filelists
+    files = glob.glob(os.path.join(filelist_path(setup), "*"))
+    for file in files:
+        if os.path.isfile(file):
+            os.remove(file)
+    if os.path.exists(filelist_path(setup)):
+        os.rmdir(filelist_path(setup))
+
+    # remove logs
+    files = glob.glob(os.path.join(tmp_log_path(setup), "*", "*.log"))
+    for file in files:
+        if os.path.isfile(file):
+            os.remove(file)
+    dirs = glob.glob(os.path.join(tmp_log_path(setup), "*"))
+    for d in dirs:
+        if os.path.isdir(d):
+            os.rmdir(d)
+    if os.path.exists(tmp_log_path(setup)):
+        os.rmdir(tmp_log_path(setup))
 
 
 # Placeholder, can email or maybe put message in slack
diff --git a/Snakefile-build-raw b/Snakefile-build-raw
index 02362c6..edbc7d8 100644
--- a/Snakefile-build-raw
+++ b/Snakefile-build-raw
@@ -40,6 +40,14 @@ meta = metadata_path(setup)
 
 basedir = workflow.basedir
 
+wildcard_constraints:
+    experiment="\w+",
+    period="p\d{2}",
+    run="r\d{3}",
+    datatype="\w{3}",
+    timestamp="\d{8}T\d{6}Z",
+
+
 localrules:
     gen_filelist,
     autogen_output,
diff --git a/rules/common.smk b/rules/common.smk
index b5fba4d..6cb5d40 100644
--- a/rules/common.smk
+++ b/rules/common.smk
@@ -11,6 +11,7 @@ from scripts.util.patterns import (
     get_pattern_tier_raw,
     get_pattern_plts_tmp_channel,
 )
+from scripts.util import ProcessingFileKey
 
 
 def read_filelist(wildcards):
@@ -39,6 +40,15 @@ def read_filelist_cal(wildcards, tier):
     return files
 
 
+def read_filelist_fft(wildcards, tier):
+    label = f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-fft"
+    with checkpoints.gen_filelist.get(label=label, tier=tier, extension="file").output[
+        0
+    ].open() as f:
+        files = f.read().splitlines()
+    return files
+
+
 def read_filelist_pars_cal_channel(wildcards, tier):
     """
     This function will read the filelist of the channels and return a list of dsp files one for each channel
@@ -99,3 +109,63 @@ def get_pattern(tier):
             return get_pattern_tier_daq(setup)
     else:
         return get_pattern_tier_raw(setup)
+
+
+def set_last_rule_name(workflow, new_name):
+    """Sets the name of the most recently created rule to be `new_name`.
+    Useful when creating rules dynamically (i.e. unnamed).
+
+    Warning
+    -------
+    This could mess up the workflow. Use at your own risk.
+ """ + rules = workflow._rules + last_key = next(reversed(rules)) + assert last_key == rules[last_key].name + + rules[new_name] = rules.pop(last_key) + rules[new_name].name = new_name + + if workflow.default_target == last_key: + workflow.default_target = new_name + + if last_key in workflow._localrules: + workflow._localrules.remove(last_key) + workflow._localrules.add(new_name) + + workflow.check_localrules() + + +def get_svm_file(wildcards, tier, name): + par_overwrite_file = os.path.join(par_overwrite_path(setup), tier, "validity.jsonl") + pars_files_overwrite = pars_catalog.get_calib_files( + par_overwrite_file, wildcards.timestamp + ) + for pars_file in pars_files_overwrite: + if name in pars_file: + return os.path.join(par_overwrite_path(setup), tier, pars_file) + raise ValueError(f"Could not find model in {pars_files_overwrite}") + + +def get_overwrite_file(tier, wildcards=None, timestamp=None, name=None): + par_overwrite_file = os.path.join(par_overwrite_path(setup), tier, "validity.jsonl") + if timestamp is not None: + pars_files_overwrite = pars_catalog.get_calib_files( + par_overwrite_file, timestamp + ) + else: + pars_files_overwrite = pars_catalog.get_calib_files( + par_overwrite_file, wildcards.timestamp + ) + if name is None: + fullname = f"{tier}-overwrite.json" + else: + fullname = f"{tier}_{name}-overwrite.json" + out_files = [] + for pars_file in pars_files_overwrite: + if fullname in pars_file: + out_files.append(os.path.join(par_overwrite_path(setup), tier, pars_file)) + if len(out_files) == 0: + raise ValueError(f"Could not find name in {pars_files_overwrite}") + else: + return out_files diff --git a/rules/dsp.smk b/rules/dsp.smk index fc8ecf5..d44a6db 100644 --- a/rules/dsp.smk +++ b/rules/dsp.smk @@ -7,6 +7,7 @@ Snakemake rules for processing dsp tier. 
This is done in 4 steps: """ from scripts.util.pars_loading import pars_catalog +from scripts.util.utils import par_dsp_path from scripts.util.patterns import ( get_pattern_pars_tmp_channel, get_pattern_plts_tmp_channel, @@ -19,6 +20,8 @@ from scripts.util.patterns import ( get_pattern_pars_tmp, get_pattern_log, get_pattern_pars, + get_pattern_pars_overwrite, + get_pattern_pars_svm, ) @@ -27,7 +30,7 @@ rule build_pars_dsp_tau: files=os.path.join( filelist_path(setup), "all-{experiment}-{period}-{run}-cal-raw.filelist" ), - tcm_files=lambda wildcards: read_filelist_cal(wildcards, "tcm"), + pulser=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), params: timestamp="{timestamp}", datatype="cal", @@ -51,10 +54,46 @@ rule build_pars_dsp_tau: "--channel {params.channel} " "--plot_path {output.plots} " "--output_file {output.decay_const} " - "--tcm_files {input.tcm_files} " + "--pulser_file {input.pulser} " "--raw_files {input.files}" +rule build_pars_event_selection: + input: + files=os.path.join( + filelist_path(setup), "all-{experiment}-{period}-{run}-cal-raw.filelist" + ), + pulser_file=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), + database=get_pattern_pars_tmp_channel(setup, "dsp", "decay_constant"), + raw_cal=get_blinding_curve_file, + params: + timestamp="{timestamp}", + datatype="cal", + channel="{channel}", + output: + peak_file=temp(get_pattern_pars_tmp_channel(setup, "dsp", "peaks", "lh5")), + log: + get_pattern_log_channel(setup, "par_dsp_event_selection"), + group: + "par-dsp" + resources: + runtime=300, + mem_swap=70, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/pars_dsp_event_selection.py')} " + "--configs {configs} " + "--log {log} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--peak_file {output.peak_file} " + "--pulser_file {input.pulser_file} " + "--decay_const {input.database} " + "--raw_cal {input.raw_cal} " + "--raw_filelist {input.files}" + + # This rule builds the optimal energy filter parameters for the dsp using fft files rule build_pars_dsp_nopt: input: @@ -93,23 +132,60 @@ rule build_pars_dsp_nopt: "--raw_filelist {input.files}" -# This rule builds the optimal energy filter parameters for the dsp using calibration dsp files -rule build_pars_dsp_eopt: +# This rule builds the dplms energy filter for the dsp using fft and cal files +rule build_pars_dsp_dplms: input: - files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-raw.filelist" - ), - tcm_filelist=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" + fft_files=os.path.join( + filelist_path(setup), "all-{experiment}-{period}-{run}-fft-raw.filelist" ), - decay_const=get_pattern_pars_tmp_channel(setup, "dsp", "noise_optimization"), + peak_file=get_pattern_pars_tmp_channel(setup, "dsp", "peaks", "lh5"), + database=get_pattern_pars_tmp_channel(setup, "dsp", "noise_optimization"), inplots=get_pattern_plts_tmp_channel(setup, "dsp", "noise_optimization"), params: timestamp="{timestamp}", datatype="cal", channel="{channel}", output: - dsp_pars=temp(get_pattern_pars_tmp_channel(setup, "dsp")), + dsp_pars=temp(get_pattern_pars_tmp_channel(setup, "dsp", "dplms")), + lh5_path=temp( + get_pattern_pars_tmp_channel(setup, "dsp", "dplms", extension="lh5") + ), + plots=temp(get_pattern_plts_tmp_channel(setup, "dsp", "dplms")), + log: + get_pattern_log_channel(setup, "pars_dsp_dplms"), + group: + "par-dsp" + resources: + runtime=300, + shell: + "{swenv} python3 -B 
" + f"{workflow.source_path('../scripts/pars_dsp_dplms.py')} " + "--fft_raw_filelist {input.fft_files} " + "--peak_file {input.peak_file} " + "--database {input.database} " + "--inplots {input.inplots} " + "--configs {configs} " + "--log {log} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--dsp_pars {output.dsp_pars} " + "--lh5_path {output.lh5_path} " + "--plot_path {output.plots} " + + +# This rule builds the optimal energy filter parameters for the dsp using calibration dsp files +rule build_pars_dsp_eopt: + input: + peak_file=get_pattern_pars_tmp_channel(setup, "dsp", "peaks", "lh5"), + decay_const=get_pattern_pars_tmp_channel(setup, "dsp", "dplms"), + inplots=get_pattern_plts_tmp_channel(setup, "dsp", "dplms"), + params: + timestamp="{timestamp}", + datatype="cal", + channel="{channel}", + output: + dsp_pars=temp(get_pattern_pars_tmp_channel(setup, "dsp_eopt")), qbb_grid=temp( get_pattern_pars_tmp_channel(setup, "dsp", "objects", extension="pkl") ), @@ -128,8 +204,7 @@ rule build_pars_dsp_eopt: "--datatype {params.datatype} " "--timestamp {params.timestamp} " "--channel {params.channel} " - "--raw_filelist {input.files} " - "--tcm_filelist {input.tcm_filelist} " + "--peak_file {input.peak_file} " "--inplots {input.inplots} " "--decay_const {input.decay_const} " "--plot_path {output.plots} " @@ -137,13 +212,68 @@ rule build_pars_dsp_eopt: "--final_dsp_pars {output.dsp_pars}" -rule build_pars_dsp: +rule build_svm_dsp: + input: + hyperpars=lambda wildcards: get_svm_file(wildcards, "dsp", "svm_hyperpars"), + train_data=lambda wildcards: get_svm_file( + wildcards, "dsp", "svm_hyperpars" + ).replace("hyperpars.json", "train.lh5"), + output: + dsp_pars=get_pattern_pars(setup, "dsp", "svm", "pkl"), + log: + get_pattern_log(setup, "pars_dsp_svm").replace("{datatype}", "cal"), + group: + "par-dsp-svm" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/pars_dsp_build_svm.py')} " + "--log {log} " + "--train_data {input.train_data} " + "--train_hyperpars {input.hyperpars} " + "--output_file {output.dsp_pars}" + + +rule build_pars_dsp_svm: + input: + dsp_pars=get_pattern_pars_tmp_channel(setup, "dsp_eopt"), + svm_file=get_pattern_pars(setup, "dsp", "svm", "pkl"), + output: + dsp_pars=temp(get_pattern_pars_tmp_channel(setup, "dsp")), + log: + get_pattern_log_channel(setup, "pars_dsp_svm"), + group: + "par-dsp" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/pars_dsp_svm.py')} " + "--log {log} " + "--input_file {input.dsp_pars} " + "--output_file {output.dsp_pars} " + "--svm_file {input.svm_file}" + + +rule build_plts_dsp: input: - lambda wildcards: read_filelist_pars_cal_channel(wildcards, "dsp"), lambda wildcards: read_filelist_plts_cal_channel(wildcards, "dsp"), + output: + get_pattern_plts(setup, "dsp"), + group: + "merge-dsp" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +rule build_pars_dsp_objects: + input: lambda wildcards: read_filelist_pars_cal_channel(wildcards, "dsp_objects_pkl"), output: - get_pattern_pars(setup, "dsp", check_in_cycle=check_in_cycle), get_pattern_pars( setup, "dsp", @@ -151,20 +281,75 @@ rule build_pars_dsp: extension="dir", check_in_cycle=check_in_cycle, ), - get_pattern_plts(setup, "dsp"), group: "merge-dsp" shell: "{swenv} python3 -B " - f"{workflow.source_path('../scripts/merge_channels.py')} " + 
f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +rule build_pars_dsp_db: + input: + lambda wildcards: read_filelist_pars_cal_channel(wildcards, "dsp"), + output: + temp( + get_pattern_pars_tmp( + setup, + "dsp", + datatype="cal", + ) + ), + group: + "merge-dsp" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " "--input {input} " "--output {output} " +rule build_pars_dsp: + input: + in_files=lambda wildcards: read_filelist_pars_cal_channel( + wildcards, "dsp_dplms_lh5" + ), + in_db=get_pattern_pars_tmp( + setup, + "dsp", + datatype="cal", + ), + plts=get_pattern_plts(setup, "dsp"), + objects=get_pattern_pars( + setup, + "dsp", + name="objects", + extension="dir", + check_in_cycle=check_in_cycle, + ), + output: + out_file=get_pattern_pars( + setup, + "dsp", + extension="lh5", + check_in_cycle=check_in_cycle, + ), + out_db=get_pattern_pars(setup, "dsp", check_in_cycle=check_in_cycle), + group: + "merge-dsp" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--output {output.out_file} " + "--in_db {input.in_db} " + "--out_db {output.out_db} " + "--input {input.in_files} " + + rule build_dsp: input: raw_file=get_pattern_tier_raw(setup), - tcm_file=get_pattern_tier_tcm(setup), pars_file=ancient( lambda wildcards: pars_catalog.get_par_file( setup, wildcards.timestamp, "dsp" @@ -182,7 +367,7 @@ rule build_dsp: "tier-dsp" resources: runtime=300, - mem_swap=50, + mem_swap=25, shell: "{swenv} python3 -B " f"{workflow.source_path('../scripts/build_dsp.py')} " @@ -193,4 +378,4 @@ rule build_dsp: "--input {input.raw_file} " "--output {output.tier_file} " "--db_file {output.db_file} " - "--pars_file {input.pars_file}" + "--pars_file {input.pars_file} " diff --git a/rules/evt.smk b/rules/evt.smk index 9cc6e13..7454957 100644 --- a/rules/evt.smk +++ b/rules/evt.smk @@ -11,70 +11,71 @@ from scripts.util.patterns import ( get_pattern_tier, get_pattern_log, get_pattern_pars, + get_pattern_log_concat, ) -rule build_evt: - input: - dsp_file=get_pattern_tier_dsp(setup), - hit_file=get_pattern_tier_hit(setup), - tcm_file=get_pattern_tier_tcm(setup), - output: - evt_file=get_pattern_tier(setup, "evt", check_in_cycle=check_in_cycle), - params: - timestamp="{timestamp}", - datatype="{datatype}", - tier="evt", - log: - get_pattern_log(setup, "tier_evt"), - group: - "tier-evt" - resources: - runtime=300, - mem_swap=70, - shell: - "{swenv} python3 -B " - f"{workflow.source_path('../scripts/build_evt.py')} " - "--configs {configs} " - "--metadata {meta} " - "--log {log} " - "--tier {params.tier} " - "--datatype {params.datatype} " - "--timestamp {params.timestamp} " - "--hit_file {input.hit_file} " - "--tcm_file {input.tcm_file} " - "--dsp_file {input.dsp_file} " - "--output {output.evt_file} " +for tier in ("evt", "pet"): + rule: + input: + dsp_file=( + get_pattern_tier_dsp(setup) + if tier == "evt" + else get_pattern_tier_psp(setup) + ), + hit_file=( + get_pattern_tier_hit(setup) + if tier == "evt" + else get_pattern_tier_pht(setup) + ), + tcm_file=get_pattern_tier_tcm(setup), + output: + evt_file=get_pattern_tier(setup, tier, check_in_cycle=check_in_cycle), + params: + timestamp="{timestamp}", + datatype="{datatype}", + tier=tier, + log: + get_pattern_log(setup, f"tier_{tier}"), + group: + "tier-evt" + resources: + runtime=300, + mem_swap=50, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/build_evt.py')} " + "--configs {configs} " + "--metadata {meta} " + "--log {log} " + "--tier {params.tier} 
" + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--hit_file {input.hit_file} " + "--tcm_file {input.tcm_file} " + "--dsp_file {input.dsp_file} " + "--output {output.evt_file} " -rule build_pet: - input: - dsp_file=get_pattern_tier_dsp(setup), - hit_file=get_pattern_tier_pht(setup), - tcm_file=get_pattern_tier_tcm(setup), - output: - evt_file=get_pattern_tier(setup, "pet", check_in_cycle=check_in_cycle), - params: - timestamp="{timestamp}", - datatype="{datatype}", - tier="pet", - log: - get_pattern_log(setup, "tier_pet"), - group: - "tier-evt" - resources: - runtime=300, - mem_swap=70, - shell: - "{swenv} python3 -B " - f"{workflow.source_path('../scripts/build_evt.py')} " - "--configs {configs} " - "--log {log} " - "--tier {params.tier} " - "--datatype {params.datatype} " - "--timestamp {params.timestamp} " - "--metadata {meta} " - "--hit_file {input.hit_file} " - "--tcm_file {input.tcm_file} " - "--dsp_file {input.dsp_file} " - "--output {output.evt_file} " + set_last_rule_name(workflow, f"build_{tier}") + + rule: + wildcard_constraints: + timestamp="(?!\d{8}T\d{6}Z)", + input: + lambda wildcards: sorted(read_filelist_phy(wildcards, tier)), + output: + get_pattern_tier(setup, f"{tier}_concat", check_in_cycle=check_in_cycle), + params: + timestamp="all", + datatype="{datatype}", + log: + get_pattern_log_concat(setup, f"tier_{tier}_concat"), + group: + "tier-evt" + shell: + "{swenv} lh5concat --verbose --overwrite " + "--output {output} " + "-- {input} &> {log}" + + set_last_rule_name(workflow, f"concat_{tier}") diff --git a/rules/hit.smk b/rules/hit.smk index de918b3..26d9acb 100644 --- a/rules/hit.smk +++ b/rules/hit.smk @@ -21,18 +21,54 @@ from scripts.util.patterns import ( ) +# This rule builds the qc using the calibration dsp files and fft files +rule build_qc: + input: + files=lambda wildcards: read_filelist_cal(wildcards, "dsp"), + fft_files=lambda wildcards: read_filelist_fft(wildcards, "dsp"), + pulser=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), + params: + timestamp="{timestamp}", + datatype="cal", + channel="{channel}", + output: + qc_file=temp(get_pattern_pars_tmp_channel(setup, "hit", "qc")), + plot_file=temp(get_pattern_plts_tmp_channel(setup, "hit", "qc")), + log: + get_pattern_log_channel(setup, "pars_hit_qc"), + group: + "par-hit" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/pars_hit_qc.py')} " + "--log {log} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--configs {configs} " + "--plot_path {output.plot_file} " + "--save_path {output.qc_file} " + "--pulser_file {input.pulser} " + "--cal_files {input.files} " + "--fft_files {input.fft_files} " + + # This rule builds the energy calibration using the calibration dsp files rule build_energy_calibration: input: - files=lambda wildcards: read_filelist_cal(wildcards, "dsp"), - tcm_filelist=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" + files=os.path.join( + filelist_path(setup), "all-{experiment}-{period}-{run}-cal-dsp.filelist" ), + pulser=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), ctc_dict=ancient( lambda wildcards: pars_catalog.get_par_file( setup, wildcards.timestamp, "dsp" ) ), + inplots=get_pattern_plts_tmp_channel(setup, "hit", "qc"), + in_hit_dict=get_pattern_pars_tmp_channel(setup, "hit", "qc"), params: timestamp="{timestamp}", datatype="cal", @@ -59,11 +95,14 @@ rule build_energy_calibration: "--timestamp 
{params.timestamp} " "--channel {params.channel} " "--configs {configs} " + "--metadata {meta} " "--plot_path {output.plot_file} " "--results_path {output.results_file} " "--save_path {output.ecal_file} " + "--inplot_dict {input.inplots} " + "--in_hit_dict {input.in_hit_dict} " "--ctc_dict {input.ctc_dict} " - "--tcm_filelist {input.tcm_filelist} " + "--pulser_file {input.pulser} " "--files {input.files}" @@ -73,9 +112,7 @@ rule build_aoe_calibration: files=os.path.join( filelist_path(setup), "all-{experiment}-{period}-{run}-cal-dsp.filelist" ), - tcm_filelist=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" - ), + pulser=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), ecal_file=get_pattern_pars_tmp_channel(setup, "hit", "energy_cal"), eres_file=get_pattern_pars_tmp_channel( setup, "hit", "energy_cal_objects", extension="pkl" @@ -112,7 +149,7 @@ rule build_aoe_calibration: "--eres_file {input.eres_file} " "--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " - "--tcm_filelist {input.tcm_filelist} " + "--pulser_file {input.pulser} " "--ecal_file {input.ecal_file} " "{input.files}" @@ -123,9 +160,7 @@ rule build_lq_calibration: files=os.path.join( filelist_path(setup), "all-{experiment}-{period}-{run}-cal-dsp.filelist" ), - tcm_filelist=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" - ), + pulser=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), ecal_file=get_pattern_pars_tmp_channel(setup, "hit", "aoe_cal"), eres_file=get_pattern_pars_tmp_channel( setup, "hit", "aoe_cal_objects", extension="pkl" @@ -160,33 +195,33 @@ rule build_lq_calibration: "--eres_file {input.eres_file} " "--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " - "--tcm_filelist {input.tcm_filelist} " + "--pulser_file {input.pulser} " "--ecal_file {input.ecal_file} " "{input.files}" -rule build_pars_hit: - input: - lambda wildcards: read_filelist_pars_cal_channel(wildcards, "hit"), - lambda wildcards: read_filelist_plts_cal_channel(wildcards, "hit"), - lambda wildcards: read_filelist_pars_cal_channel(wildcards, "hit_objects_pkl"), - output: - get_pattern_pars(setup, "hit", check_in_cycle=check_in_cycle), - get_pattern_pars( - setup, - "hit", - name="objects", - extension="dir", - check_in_cycle=check_in_cycle, - ), - get_pattern_plts(setup, "hit"), - group: - "merge-hit" - shell: - "{swenv} python3 -B " - f"{workflow.source_path('../scripts/merge_channels.py')} " - "--input {input} " - "--output {output} " +# rule build_pars_hit: +# input: +# lambda wildcards: read_filelist_pars_cal_channel(wildcards, "hit"), +# lambda wildcards: read_filelist_plts_cal_channel(wildcards, "hit"), +# lambda wildcards: read_filelist_pars_cal_channel(wildcards, "hit_objects_pkl"), +# output: +# get_pattern_pars(setup, "hit", check_in_cycle=check_in_cycle), +# get_pattern_pars( +# setup, +# "hit", +# name="objects", +# extension="dir", +# check_in_cycle=check_in_cycle, +# ), +# get_pattern_plts(setup, "hit"), +# group: +# "merge-hit" +# shell: +# "{swenv} python3 -B " +# f"{workflow.source_path('../scripts/merge_channels.py')} " +# "--input {input} " +# "--output {output} " rule build_hit: diff --git a/rules/main.smk b/rules/main.smk index b67ea46..86d940a 100644 --- a/rules/main.smk +++ b/rules/main.smk @@ -29,10 +29,10 @@ rule autogen_output: gen_output="{label}-{tier}.gen", summary_log=f"{log_path(setup)}/summary-" + "{label}-{tier}" - + f"-{datetime.strftime(datetime.utcnow(), '%Y%m%dT%H%M%SZ')}.log", + + 
f"-{datetime.strftime(datetime.utcnow() , '%Y%m%dT%H%M%SZ')}.log", warning_log=f"{log_path(setup)}/warning-" + "{label}-{tier}" - + f"-{datetime.strftime(datetime.utcnow(), '%Y%m%dT%H%M%SZ')}.log", + + f"-{datetime.strftime(datetime.utcnow() , '%Y%m%dT%H%M%SZ')}.log", params: log_path=tmp_log_path(setup), tmp_par_path=os.path.join(tmp_par_path(setup), "*_db.json"), diff --git a/rules/pht.smk b/rules/pht.smk index 71f9acd..e3efae7 100644 --- a/rules/pht.smk +++ b/rules/pht.smk @@ -7,6 +7,7 @@ Snakemake rules for processing pht (partition hit) tier data. This is done in 4 """ from scripts.util.pars_loading import pars_catalog +from scripts.util.create_pars_keylist import pars_key_resolve from scripts.util.utils import filelist_path, par_pht_path, set_last_rule_name from scripts.util.patterns import ( get_pattern_pars_tmp_channel, @@ -14,24 +15,199 @@ from scripts.util.patterns import ( get_pattern_log_channel, get_pattern_par_pht, get_pattern_plts, - get_pattern_tier_dsp, get_pattern_tier, get_pattern_pars_tmp, get_pattern_log, get_pattern_pars, ) +ds.pars_key_resolve.write_par_catalog( + ["-*-*-*-cal"], + os.path.join(pars_path(setup), "pht", "validity.jsonl"), + get_pattern_tier_raw(setup), + {"cal": ["par_pht"], "lar": ["par_pht"]}, +) + +intier = "psp" + + +rule pht_checkpoint: + input: + files=lambda wildcards: read_filelist_cal(wildcards, intier), + output: + temp(get_pattern_pars_tmp_channel(setup, "pht", "check")), + shell: + "touch {output}" + + +qc_pht_rules = {} +for key, dataset in part.datasets.items(): + for partition in dataset.keys(): + + rule: + input: + cal_files=part.get_filelists(partition, key, intier), + fft_files=part.get_filelists(partition, key, intier, datatype="fft"), + pulser_files=[ + file.replace("pht", "tcm") + for file in part.get_par_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="pulser_ids", + ) + ], + check_files=part.get_par_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="check", + ), + overwrite_files=get_overwrite_file( + "pht", + timestamp=part.get_timestamp( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + ), + ), + wildcard_constraints: + channel=part.get_wildcard_constraints(partition, key), + params: + datatype="cal", + channel="{channel}" if key == "default" else key, + timestamp=part.get_timestamp( + f"{par_pht_path(setup)}/validity.jsonl", partition, key, tier="pht" + ), + output: + hit_pars=[ + temp(file) + for file in part.get_par_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="qc", + ) + ], + plot_file=[ + temp(file) + for file in part.get_plt_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="qc", + ) + ], + log: + part.get_log_file( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + "pht", + name="par_pht_qc", + ), + group: + "par-pht" + resources: + mem_swap=len(part.get_filelists(partition, key, intier)) * 15, + runtime=300, + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/pars_pht_qc.py " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--save_path {output.hit_pars} " + "--plot_path {output.plot_file} " + "--overwrite_files {input.overwrite_files} " + "--pulser_files {input.pulser_files} " + "--fft_files {input.fft_files} " + "--cal_files {input.cal_files}" + + set_last_rule_name(workflow, 
f"{key}-{partition}-build_pht_qc") + + if key in qc_pht_rules: + qc_pht_rules[key].append(list(workflow.rules)[-1]) + else: + qc_pht_rules[key] = [list(workflow.rules)[-1]] + + +# Merged energy and a/e supercalibrations to reduce number of rules as they have same inputs/outputs +# This rule builds the a/e calibration using the calibration dsp files for the whole partition +rule build_pht_qc: + input: + cal_files=os.path.join( + filelist_path(setup), + "all-{experiment}-{period}-{run}-cal-" + f"{intier}.filelist", + ), + fft_files=os.path.join( + filelist_path(setup), + "all-{experiment}-{period}-{run}-fft-" + f"{intier}.filelist", + ), + pulser_files=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), + check_file=get_pattern_pars_tmp_channel(setup, "pht", "check"), + overwrite_files=lambda wildcards: get_overwrite_file("pht", wildcards=wildcards), + params: + datatype="cal", + channel="{channel}", + timestamp="{timestamp}", + output: + hit_pars=temp(get_pattern_pars_tmp_channel(setup, "pht", "qc")), + plot_file=temp(get_pattern_plts_tmp_channel(setup, "pht", "qc")), + log: + get_pattern_log_channel(setup, "par_pht_qc"), + group: + "par-pht" + resources: + mem_swap=60, + runtime=300, + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/pars_pht_qc.py " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--save_path {output.hit_pars} " + "--plot_path {output.plot_file} " + "--overwrite_files {input.overwrite_files} " + "--pulser_files {input.pulser_files} " + "--fft_files {input.fft_files} " + "--cal_files {input.cal_files}" + + +fallback_qc_rule = list(workflow.rules)[-1] + +rule_order_list = [] +ordered = OrderedDict(qc_pht_rules) +ordered.move_to_end("default") +for key, items in ordered.items(): + rule_order_list += [item.name for item in items] +rule_order_list.append(fallback_qc_rule.name) +workflow._ruleorder.add(*rule_order_list) # [::-1] + # This rule builds the energy calibration using the calibration dsp files rule build_per_energy_calibration: input: - files=lambda wildcards: read_filelist_cal(wildcards, "dsp"), - tcm_filelist=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" + files=os.path.join( + filelist_path(setup), + "all-{experiment}-{period}-{run}-cal-" + f"{intier}.filelist", ), + pulser=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), + pht_dict=get_pattern_pars_tmp_channel(setup, "pht", "qc"), + inplots=get_pattern_plts_tmp_channel(setup, "pht", "qc"), ctc_dict=ancient( lambda wildcards: pars_catalog.get_par_file( - setup, wildcards.timestamp, "dsp" + setup, wildcards.timestamp, intier ) ), params: @@ -48,7 +224,7 @@ rule build_per_energy_calibration: ), plot_file=temp(get_pattern_plts_tmp_channel(setup, "pht", "energy_cal")), log: - get_pattern_log_channel(setup, "pars_pht_energy_cal"), + get_pattern_log_channel(setup, "par_pht_energy_cal"), group: "par-pht" resources: @@ -62,86 +238,34 @@ rule build_per_energy_calibration: "--channel {params.channel} " "--configs {configs} " "--tier {params.tier} " + "--metadata {meta} " "--plot_path {output.plot_file} " "--results_path {output.results_file} " "--save_path {output.ecal_file} " + "--inplot_dict {input.inplots} " + "--in_hit_dict {input.pht_dict} " "--ctc_dict {input.ctc_dict} " - "--tcm_filelist {input.tcm_filelist} " + "--pulser_file {input.pulser} " "--files {input.files}" -rule build_pars_pht: - input: - lambda wildcards: read_filelist_pars_cal_channel(wildcards, 
"pht"), - lambda wildcards: read_filelist_plts_cal_channel(wildcards, "pht"), - lambda wildcards: read_filelist_pars_cal_channel( - wildcards, - "pht_objects_pkl", - ), - output: - get_pattern_pars(setup, "pht", check_in_cycle=check_in_cycle), - get_pattern_pars( - setup, - "pht", - name="objects", - extension="dir", - check_in_cycle=check_in_cycle, - ), - get_pattern_plts(setup, "pht"), - group: - "merge-hit" - shell: - "{swenv} python3 -B " - f"{workflow.source_path('../scripts/merge_channels.py')} " - "--input {input} " - "--output {output} " - - -rule build_pht: - input: - dsp_file=get_pattern_tier_dsp(setup), - #hit_file = get_pattern_tier_hit(setup), - pars_file=lambda wildcards: pars_catalog.get_par_file( - setup, wildcards.timestamp, "pht" - ), - output: - tier_file=get_pattern_tier(setup, "pht", check_in_cycle=check_in_cycle), - db_file=get_pattern_pars_tmp(setup, "pht_db"), - params: - timestamp="{timestamp}", - datatype="{datatype}", - tier="pht", - log: - get_pattern_log(setup, "tier_pht"), - group: - "tier-pht" - resources: - runtime=300, - shell: - "{swenv} python3 -B " - f"{workflow.source_path('../scripts/build_hit.py')} " - "--configs {configs} " - "--log {log} " - "--tier {params.tier} " - "--datatype {params.datatype} " - "--timestamp {params.timestamp} " - "--pars_file {input.pars_file} " - "--output {output.tier_file} " - "--input {input.dsp_file} " - "--db_file {output.db_file}" - - part_pht_rules = {} for key, dataset in part.datasets.items(): for partition in dataset.keys(): - print( - part.get_wildcard_constraints(partition, key), - ) rule: input: - files=part.get_filelists(partition, key, "dsp"), - tcm_files=part.get_filelists(partition, key, "tcm"), + files=part.get_filelists(partition, key, intier), + pulser_files=[ + file.replace("pht", "tcm") + for file in part.get_par_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="pulser_ids", + ) + ], ecal_file=part.get_par_files( f"{par_pht_path(setup)}/validity.jsonl", partition, @@ -215,7 +339,7 @@ for key, dataset in part.datasets.items(): group: "par-pht" resources: - mem_swap=300, + mem_swap=len(part.get_filelists(partition, key, intier)) * 15, runtime=300, shell: "{swenv} python3 -B " @@ -226,12 +350,13 @@ for key, dataset in part.datasets.items(): "--timestamp {params.timestamp} " "--inplots {input.inplots} " "--channel {params.channel} " + "--metadata {meta} " "--fit_results {output.partcal_results} " "--eres_file {input.eres_file} " "--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " "--ecal_file {input.ecal_file} " - "--tcm_filelist {input.tcm_files} " + "--pulser_files {input.pulser_files} " "--input_files {input.files}" set_last_rule_name( @@ -249,11 +374,10 @@ for key, dataset in part.datasets.items(): rule build_pht_energy_super_calibrations: input: files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-dsp.filelist" - ), - tcm_files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" + filelist_path(setup), + "all-{experiment}-{period}-{run}-cal" + f"-{intier}.filelist", ), + pulser_files=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), ecal_file=get_pattern_pars_tmp_channel(setup, "pht", "energy_cal"), eres_file=get_pattern_pars_tmp_channel( setup, "pht", "energy_cal_objects", extension="pkl" @@ -272,7 +396,7 @@ rule build_pht_energy_super_calibrations: ), plot_file=temp(get_pattern_plts_tmp_channel(setup, "pht", "partcal")), log: - get_pattern_log_channel(setup, 
"pars_pht_partcal"), + get_pattern_log_channel(setup, "par_pht_partcal"), group: "par-pht" resources: @@ -286,13 +410,14 @@ rule build_pht_energy_super_calibrations: "--datatype {params.datatype} " "--timestamp {params.timestamp} " "--channel {params.channel} " + "--metadata {meta} " "--inplots {input.inplots} " "--fit_results {output.partcal_results} " "--eres_file {input.eres_file} " "--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " "--ecal_file {input.ecal_file} " - "--tcm_filelist {input.tcm_files} " + "--pulser_files {input.pulser_files} " "--input_files {input.files}" @@ -312,8 +437,17 @@ for key, dataset in part.datasets.items(): rule: input: - files=part.get_filelists(partition, key, "dsp"), - tcm_files=part.get_filelists(partition, key, "tcm"), + files=part.get_filelists(partition, key, intier), + pulser_files=[ + file.replace("pht", "tcm") + for file in part.get_par_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="pulser_ids", + ) + ], ecal_file=part.get_par_files( f"{par_pht_path(setup)}/validity.jsonl", partition, @@ -387,7 +521,7 @@ for key, dataset in part.datasets.items(): group: "par-pht" resources: - mem_swap=300, + mem_swap=len(part.get_filelists(partition, key, intier)) * 15, runtime=300, shell: "{swenv} python3 -B " @@ -403,7 +537,7 @@ for key, dataset in part.datasets.items(): "--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " "--ecal_file {input.ecal_file} " - "--tcm_filelist {input.tcm_files} " + "--pulser_files {input.pulser_files} " "--input_files {input.files}" set_last_rule_name( @@ -421,11 +555,10 @@ for key, dataset in part.datasets.items(): rule build_pht_aoe_calibrations: input: files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-dsp.filelist" - ), - tcm_filelist=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" + filelist_path(setup), + "all-{experiment}-{period}-{run}-cal-" + f"{intier}.filelist", ), + pulser_files=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), ecal_file=get_pattern_pars_tmp_channel(setup, "pht", "partcal"), eres_file=get_pattern_pars_tmp_channel( setup, "pht", "partcal_objects", extension="pkl" @@ -444,7 +577,7 @@ rule build_pht_aoe_calibrations: ), plot_file=temp(get_pattern_plts_tmp_channel(setup, "pht", "aoecal")), log: - get_pattern_log_channel(setup, "pars_pht_aoe_cal"), + get_pattern_log_channel(setup, "par_pht_aoe_cal"), group: "par-pht" resources: @@ -464,7 +597,7 @@ rule build_pht_aoe_calibrations: "--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " "--ecal_file {input.ecal_file} " - "--tcm_filelist {input.tcm_filelist} " + "--pulser_files {input.pulser_files} " "--input_files {input.files}" @@ -484,8 +617,17 @@ for key, dataset in part.datasets.items(): rule: input: - files=part.get_filelists(partition, key, "dsp"), - tcm_files=part.get_filelists(partition, key, "tcm"), + files=part.get_filelists(partition, key, intier), + pulser_files=[ + file.replace("pht", "tcm") + for file in part.get_par_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="pulser_ids", + ) + ], ecal_file=part.get_par_files( f"{par_pht_path(setup)}/validity.jsonl", partition, @@ -557,7 +699,7 @@ for key, dataset in part.datasets.items(): group: "par-pht" resources: - mem_swap=300, + mem_swap=len(part.get_filelists(partition, key, intier)) * 15, runtime=300, shell: "{swenv} python3 -B " @@ -573,7 +715,7 @@ for key, dataset in part.datasets.items(): 
"--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " "--ecal_file {input.ecal_file} " - "--tcm_filelist {input.tcm_files} " + "--pulser_files {input.pulser_files} " "--input_files {input.files}" set_last_rule_name(workflow, f"{key}-{partition}-build_pht_lq_calibration") @@ -588,11 +730,10 @@ for key, dataset in part.datasets.items(): rule build_pht_lq_calibration: input: files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-dsp.filelist" - ), - tcm_filelist=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-cal-tcm.filelist" + filelist_path(setup), + "all-{experiment}-{period}-{run}-cal-" + f"{intier}.filelist", ), + pulser_files=get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids"), ecal_file=get_pattern_pars_tmp_channel(setup, "pht", "aoecal"), eres_file=get_pattern_pars_tmp_channel( setup, "pht", "aoecal_objects", extension="pkl" @@ -609,7 +750,7 @@ rule build_pht_lq_calibration: ), plot_file=temp(get_pattern_plts_tmp_channel(setup, "pht")), log: - get_pattern_log_channel(setup, "pars_pht_lq_cal"), + get_pattern_log_channel(setup, "par_pht_lq_cal"), group: "par-pht" resources: @@ -629,7 +770,7 @@ rule build_pht_lq_calibration: "--hit_pars {output.hit_pars} " "--plot_file {output.plot_file} " "--ecal_file {input.ecal_file} " - "--tcm_filelist {input.tcm_filelist} " + "--pulser_files {input.pulser_files} " "--input_files {input.files}" @@ -642,3 +783,95 @@ for key, items in ordered.items(): rule_order_list += [item.name for item in items] rule_order_list.append(fallback_pht_rule.name) workflow._ruleorder.add(*rule_order_list) # [::-1] + + +rule build_pars_pht_objects: + input: + lambda wildcards: read_filelist_pars_cal_channel( + wildcards, + "pht_objects_pkl", + ), + output: + get_pattern_pars( + setup, + "pht", + name="objects", + extension="dir", + check_in_cycle=check_in_cycle, + ), + group: + "merge-hit" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +rule build_plts_pht: + input: + lambda wildcards: read_filelist_plts_cal_channel(wildcards, "pht"), + output: + get_pattern_plts(setup, "pht"), + group: + "merge-hit" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +# rule build_pars_pht: +# input: +# infiles=lambda wildcards: read_filelist_pars_cal_channel(wildcards, "pht"), +# plts=get_pattern_plts(setup, "pht"), +# objects=get_pattern_pars( +# setup, +# "pht", +# name="objects", +# extension="dir", +# check_in_cycle=check_in_cycle, +# ), +# output: +# get_pattern_pars(setup, "pht", check_in_cycle=check_in_cycle), +# group: +# "merge-hit" +# shell: +# "{swenv} python3 -B " +# f"{basedir}/../scripts/merge_channels.py " +# "--input {input.infiles} " +# "--output {output} " + + +rule build_pht: + input: + dsp_file=get_pattern_tier(setup, intier, check_in_cycle=False), + pars_file=lambda wildcards: pars_catalog.get_par_file( + setup, wildcards.timestamp, "pht" + ), + output: + tier_file=get_pattern_tier(setup, "pht", check_in_cycle=check_in_cycle), + db_file=get_pattern_pars_tmp(setup, "pht_db"), + params: + timestamp="{timestamp}", + datatype="{datatype}", + tier="pht", + log: + get_pattern_log(setup, "tier_pht"), + group: + "tier-pht" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/build_hit.py')} " + "--configs {configs} " + "--log {log} " + "--tier {params.tier} " + "--datatype {params.datatype} " + "--timestamp 
{params.timestamp} " + "--pars_file {input.pars_file} " + "--output {output.tier_file} " + "--input {input.dsp_file} " + "--db_file {output.db_file}" diff --git a/rules/psp.smk b/rules/psp.smk new file mode 100644 index 0000000..d581107 --- /dev/null +++ b/rules/psp.smk @@ -0,0 +1,350 @@ +""" +Snakemake rules for processing pht (partition hit) tier data. This is done in 4 steps: +- extraction of calibration curves(s) for each run for each channel from cal data +- extraction of psd calibration parameters and partition level energy fitting for each channel over whole partition from cal data +- combining of all channels into single pars files with associated plot and results files +- running build hit over all channels using par file +""" + +from scripts.util.pars_loading import pars_catalog +from scripts.util.create_pars_keylist import pars_key_resolve +from scripts.util.utils import par_psp_path, par_dsp_path, set_last_rule_name +from scripts.util.patterns import ( + get_pattern_pars_tmp_channel, + get_pattern_plts_tmp_channel, + get_pattern_log_channel, + get_pattern_plts, + get_pattern_tier, + get_pattern_pars_tmp, + get_pattern_log, + get_pattern_pars, +) + +pars_key_resolve.write_par_catalog( + ["-*-*-*-cal"], + os.path.join(pars_path(setup), "dsp", "validity.jsonl"), + get_pattern_tier_raw(setup), + {"cal": ["par_dsp"], "lar": ["par_dsp"]}, +) + +pars_key_resolve.write_par_catalog( + ["-*-*-*-cal"], + os.path.join(pars_path(setup), "psp", "validity.jsonl"), + get_pattern_tier_raw(setup), + {"cal": ["par_psp"], "lar": ["par_psp"]}, +) + +psp_rules = {} +for key, dataset in part.datasets.items(): + for partition in dataset.keys(): + + rule: + input: + dsp_pars=part.get_par_files( + f"{par_dsp_path(setup)}/validity.jsonl", + partition, + key, + tier="dsp", + name="eopt", + ), + dsp_objs=part.get_par_files( + f"{par_dsp_path(setup)}/validity.jsonl", + partition, + key, + tier="dsp", + name="objects", + extension="pkl", + ), + dsp_plots=part.get_plt_files( + f"{par_dsp_path(setup)}/validity.jsonl", partition, key, tier="dsp" + ), + wildcard_constraints: + channel=part.get_wildcard_constraints(partition, key), + params: + datatype="cal", + channel="{channel}" if key == "default" else key, + timestamp=part.get_timestamp( + f"{par_psp_path(setup)}/validity.jsonl", partition, key, tier="psp" + ), + output: + psp_pars=temp( + part.get_par_files( + f"{par_psp_path(setup)}/validity.jsonl", + partition, + key, + tier="psp", + name="eopt", + ) + ), + psp_objs=temp( + part.get_par_files( + f"{par_psp_path(setup)}/validity.jsonl", + partition, + key, + tier="psp", + name="objects", + extension="pkl", + ) + ), + psp_plots=temp( + part.get_plt_files( + f"{par_psp_path(setup)}/validity.jsonl", + partition, + key, + tier="psp", + ) + ), + log: + part.get_log_file( + f"{par_psp_path(setup)}/validity.jsonl", + partition, + key, + "psp", + name="par_psp", + ), + group: + "par-psp" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/par_psp.py " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--in_plots {input.dsp_plots} " + "--out_plots {output.psp_plots} " + "--in_obj {input.dsp_objs} " + "--out_obj {output.psp_objs} " + "--input {input.dsp_pars} " + "--output {output.psp_pars} " + + set_last_rule_name(workflow, f"{key}-{partition}-build_par_psp") + + if key in psp_rules: + psp_rules[key].append(list(workflow.rules)[-1]) + else: + psp_rules[key] = [list(workflow.rules)[-1]] + 
+ +# Merged energy and a/e supercalibrations to reduce number of rules as they have same inputs/outputs +# This rule builds the a/e calibration using the calibration dsp files for the whole partition +rule build_par_psp: + input: + dsp_pars=get_pattern_pars_tmp_channel(setup, "dsp", "eopt"), + dsp_objs=get_pattern_pars_tmp_channel(setup, "dsp", "objects", extension="pkl"), + dsp_plots=get_pattern_plts_tmp_channel(setup, "dsp"), + params: + datatype="cal", + channel="{channel}", + timestamp="{timestamp}", + output: + psp_pars=temp(get_pattern_pars_tmp_channel(setup, "psp", "eopt")), + psp_objs=temp( + get_pattern_pars_tmp_channel(setup, "psp", "objects", extension="pkl") + ), + psp_plots=temp(get_pattern_plts_tmp_channel(setup, "psp")), + log: + get_pattern_log_channel(setup, "pars_psp"), + group: + "par-psp" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/par_psp.py " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--in_plots {input.dsp_plots} " + "--out_plots {output.psp_plots} " + "--in_obj {input.dsp_objs} " + "--out_obj {output.psp_objs} " + "--input {input.dsp_pars} " + "--output {output.psp_pars} " + + +fallback_psp_rule = list(workflow.rules)[-1] +rule_order_list = [] +ordered = OrderedDict(psp_rules) +ordered.move_to_end("default") +for key, items in ordered.items(): + rule_order_list += [item.name for item in items] +rule_order_list.append(fallback_psp_rule.name) +workflow._ruleorder.add(*rule_order_list) # [::-1] + + +rule build_svm_psp: + input: + hyperpars=lambda wildcards: get_svm_file(wildcards, "psp", "svm_hyperpars"), + train_data=lambda wildcards: get_svm_file( + wildcards, "psp", "svm_hyperpars" + ).replace("hyperpars.json", "train.lh5"), + output: + dsp_pars=get_pattern_pars(setup, "psp", "svm", "pkl"), + log: + get_pattern_log(setup, "pars_psp_svm").replace("{datatype}", "cal"), + group: + "par-dsp-svm" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/pars_dsp_build_svm.py')} " + "--log {log} " + "--train_data {input.train_data} " + "--train_hyperpars {input.hyperpars} " + "--output_file {output.dsp_pars}" + + +rule build_pars_psp_svm: + input: + dsp_pars=get_pattern_pars_tmp_channel(setup, "psp_eopt"), + svm_model=get_pattern_pars(setup, "psp", "svm", "pkl"), + output: + dsp_pars=temp(get_pattern_pars_tmp_channel(setup, "psp")), + log: + get_pattern_log_channel(setup, "pars_dsp_svm"), + group: + "par-dsp" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/pars_dsp_svm.py')} " + "--log {log} " + "--input_file {input.dsp_pars} " + "--output_file {output.dsp_pars} " + "--svm_file {input.svm_model}" + + +rule build_pars_psp_objects: + input: + lambda wildcards: read_filelist_pars_cal_channel( + wildcards, + "psp_objects_pkl", + ), + output: + get_pattern_pars( + setup, + "psp", + name="objects", + extension="dir", + check_in_cycle=check_in_cycle, + ), + group: + "merge-psp" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +rule build_plts_psp: + input: + lambda wildcards: read_filelist_plts_cal_channel(wildcards, "psp"), + output: + get_pattern_plts(setup, "psp"), + group: + "merge-psp" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +rule build_pars_psp_db: + input: + lambda wildcards: 
read_filelist_pars_cal_channel(wildcards, "psp"), + output: + temp( + get_pattern_pars_tmp( + setup, + "psp", + datatype="cal", + ) + ), + group: + "merge-psp" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +rule build_pars_psp: + input: + in_files=lambda wildcards: read_filelist_pars_cal_channel( + wildcards, "dsp_dplms_lh5" + ), + in_db=get_pattern_pars_tmp( + setup, + "psp", + datatype="cal", + ), + plts=get_pattern_plts(setup, "psp"), + objects=get_pattern_pars( + setup, + "psp", + name="objects", + extension="dir", + check_in_cycle=check_in_cycle, + ), + output: + out_file=get_pattern_pars( + setup, + "psp", + extension="lh5", + check_in_cycle=check_in_cycle, + ), + out_db=get_pattern_pars(setup, "psp", check_in_cycle=check_in_cycle), + group: + "merge-psp" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--output {output.out_file} " + "--in_db {input.in_db} " + "--out_db {output.out_db} " + "--input {input.in_files} " + + +rule build_psp: + input: + raw_file=get_pattern_tier_raw(setup), + pars_file=ancient( + lambda wildcards: pars_catalog.get_par_file( + setup, wildcards.timestamp, "psp" + ) + ), + params: + timestamp="{timestamp}", + datatype="{datatype}", + output: + tier_file=get_pattern_tier(setup, "psp", check_in_cycle=check_in_cycle), + db_file=get_pattern_pars_tmp(setup, "psp_db"), + log: + get_pattern_log(setup, "tier_psp"), + group: + "tier-dsp" + resources: + runtime=300, + mem_swap=25, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/build_dsp.py')} " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--input {input.raw_file} " + "--output {output.tier_file} " + "--db_file {output.db_file} " + "--pars_file {input.pars_file} " diff --git a/rules/qc_phy.smk b/rules/qc_phy.smk new file mode 100644 index 0000000..6cb1272 --- /dev/null +++ b/rules/qc_phy.smk @@ -0,0 +1,160 @@ +from scripts.util.pars_loading import pars_catalog +from scripts.util.create_pars_keylist import pars_key_resolve +from scripts.util.utils import filelist_path, par_pht_path, set_last_rule_name +from scripts.util.patterns import ( + get_pattern_pars_tmp_channel, + get_pattern_plts_tmp_channel, + get_pattern_log_channel, + get_pattern_par_pht, + get_pattern_plts, + get_pattern_tier, + get_pattern_pars_tmp, + get_pattern_log, + get_pattern_pars, +) + +intier = "psp" + + +qc_pht_rules = {} +for key, dataset in part.datasets.items(): + for partition in dataset.keys(): + + rule: + input: + phy_files=part.get_filelists(partition, key, intier, datatype="phy"), + wildcard_constraints: + channel=part.get_wildcard_constraints(partition, key), + params: + datatype="cal", + channel="{channel}" if key == "default" else key, + timestamp=part.get_timestamp( + f"{par_pht_path(setup)}/validity.jsonl", partition, key, tier="pht" + ), + output: + hit_pars=[ + temp(file) + for file in part.get_par_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="qcphy", + ) + ], + plot_file=[ + temp(file) + for file in part.get_plt_files( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + tier="pht", + name="qcphy", + ) + ], + log: + part.get_log_file( + f"{par_pht_path(setup)}/validity.jsonl", + partition, + key, + "pht", + name="par_pht_qc_phy", + ), + group: + "par-pht" + resources: + mem_swap=len(part.get_filelists(partition, key, intier)) * 20, + runtime=300, + shell: + "{swenv} python3 
-B " + f"{basedir}/../scripts/pars_pht_qc_phy.py " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--save_path {output.hit_pars} " + "--plot_path {output.plot_file} " + "--phy_files {input.phy_files}" + + set_last_rule_name(workflow, f"{key}-{partition}-build_pht_qc_phy") + + if key in qc_pht_rules: + qc_pht_rules[key].append(list(workflow.rules)[-1]) + else: + qc_pht_rules[key] = [list(workflow.rules)[-1]] + + +# Merged energy and a/e supercalibrations to reduce number of rules as they have same inputs/outputs +# This rule builds the a/e calibration using the calibration dsp files for the whole partition +rule build_pht_qc_phy: + input: + phy_files=os.path.join( + filelist_path(setup), + "all-{experiment}-{period}-{run}-phy-" + f"{intier}.filelist", + ), + params: + datatype="cal", + channel="{channel}", + timestamp="{timestamp}", + output: + hit_pars=temp(get_pattern_pars_tmp_channel(setup, "pht", "qcphy")), + plot_file=temp(get_pattern_plts_tmp_channel(setup, "pht", "qcphy")), + log: + get_pattern_log_channel(setup, "pars_pht_qc_phy"), + group: + "par-pht" + resources: + mem_swap=60, + runtime=300, + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/pars_pht_qc_phy.py " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--save_path {output.hit_pars} " + "--plot_path {output.plot_file} " + "--phy_files {input.phy_files}" + + +fallback_qc_rule = list(workflow.rules)[-1] + +rule_order_list = [] +ordered = OrderedDict(qc_pht_rules) +ordered.move_to_end("default") +for key, items in ordered.items(): + rule_order_list += [item.name for item in items] +rule_order_list.append(fallback_qc_rule.name) +workflow._ruleorder.add(*rule_order_list) # [::-1] + + +rule build_plts_pht_phy: + input: + lambda wildcards: read_filelist_plts_cal_channel(wildcards, "pht_qcphy"), + output: + get_pattern_plts(setup, "pht", "qc_phy"), + group: + "merge-hit" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input} " + "--output {output} " + + +rule build_pars_pht_phy: + input: + infiles=lambda wildcards: read_filelist_pars_cal_channel(wildcards, "pht_qcphy"), + plts=get_pattern_plts(setup, "pht", "qc_phy"), + output: + get_pattern_pars(setup, "pht", name="qc_phy", check_in_cycle=check_in_cycle), + group: + "merge-hit" + shell: + "{swenv} python3 -B " + f"{basedir}/../scripts/merge_channels.py " + "--input {input.infiles} " + "--output {output} " diff --git a/rules/skm.smk b/rules/skm.smk index c4356fa..3c9a619 100644 --- a/rules/skm.smk +++ b/rules/skm.smk @@ -6,28 +6,20 @@ from scripts.util.patterns import ( get_pattern_tier, get_pattern_log, get_pattern_pars, + get_pattern_log_concat, ) rule build_skm: input: - dsp_files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-phy-dsp.filelist" - ), - hit_files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-phy-pht.filelist" - ), - tcm_files=os.path.join( - filelist_path(setup), "all-{experiment}-{period}-{run}-phy-tcm.filelist" - ), - evt_files=lambda wildcards: read_filelist_phy(wildcards, "pet"), + evt_file=get_pattern_tier(setup, "pet_concat", check_in_cycle=False), output: skm_file=get_pattern_tier(setup, "skm", check_in_cycle=check_in_cycle), params: - timestamp="{timestamp}", - datatype="{datatype}", + timestamp="all", + datatype="phy", log: - get_pattern_log(setup, "tier_skm"), + 
get_pattern_log_concat(setup, "tier_skm"), group: "tier-skm" resources: @@ -39,9 +31,5 @@ rule build_skm: "--metadata {meta} " "--log {log} " "--datatype {params.datatype} " - "--timestamp {params.timestamp} " - "--hit_files {input.hit_files} " - "--tcm_files {input.tcm_files} " - "--dsp_files {input.dsp_files} " - "--evt_files {input.evt_files} " + "--evt_file {input.evt_file} " "--output {output.skm_file} " diff --git a/rules/tcm.smk b/rules/tcm.smk index 380c882..cfdf72c 100644 --- a/rules/tcm.smk +++ b/rules/tcm.smk @@ -6,6 +6,8 @@ from scripts.util.patterns import ( get_pattern_tier_raw, get_pattern_tier, get_pattern_log, + get_pattern_pars_tmp_channel, + get_pattern_log_channel, ) @@ -33,3 +35,31 @@ rule build_tier_tcm: "--timestamp {params.timestamp} " "{input} " "{output}" + + +# This rule builds the tcm files each raw file +rule build_pulser_ids: + input: + tcm_files=lambda wildcards: read_filelist_cal(wildcards, "tcm"), + params: + timestamp="{timestamp}", + datatype="cal", + channel="{channel}", + output: + pulser=temp(get_pattern_pars_tmp_channel(setup, "tcm", "pulser_ids")), + log: + get_pattern_log_channel(setup, "tcm_pulsers"), + group: + "tier-tcm" + resources: + runtime=300, + shell: + "{swenv} python3 -B " + f"{workflow.source_path('../scripts/pars_tcm_pulser.py')} " + "--log {log} " + "--configs {configs} " + "--datatype {params.datatype} " + "--timestamp {params.timestamp} " + "--channel {params.channel} " + "--tcm_files {input.tcm_files} " + "--pulser_file {output.pulser} " diff --git a/scripts/build_dsp.py b/scripts/build_dsp.py index 9906782..2fd2248 100644 --- a/scripts/build_dsp.py +++ b/scripts/build_dsp.py @@ -18,6 +18,18 @@ from legendmeta import LegendMetadata from legendmeta.catalog import Props + +def replace_list_with_array(dic): + for key, value in dic.items(): + if isinstance(value, dict): + dic[key] = replace_list_with_array(value) + elif isinstance(value, list): + dic[key] = np.array(value, dtype="float32") + else: + pass + return dic + + warnings.filterwarnings(action="ignore", category=RuntimeWarning) argparser = argparse.ArgumentParser() @@ -43,20 +55,14 @@ "inputs" ]["processing_chain"] -database_dic = Props.read_from(args.pars_file) - - -def replace_list_with_array(dic): - for key, value in dic.items(): - if isinstance(value, dict): - dic[key] = replace_list_with_array(value) - elif isinstance(value, list): - dic[key] = np.array(value) - else: - pass - return dic - +channel_dict = {chan: Props.read_from(file) for chan, file in channel_dict.items()} +db_files = [ + par_file + for par_file in args.pars_file + if os.path.splitext(par_file)[1] == ".json" or os.path.splitext(par_file)[1] == ".yml" +] +database_dic = Props.read_from(db_files, subst_pathvar=True) database_dic = replace_list_with_array(database_dic) pathlib.Path(os.path.dirname(args.output)).mkdir(parents=True, exist_ok=True) @@ -88,8 +94,8 @@ def replace_list_with_array(dic): outputs = {} channels = [] -for channel, file in channel_dict.items(): - output = Props.read_from(file)["outputs"] +for channel, chan_dict in channel_dict.items(): + output = chan_dict["outputs"] in_dict = False for entry in outputs: if outputs[entry]["fields"] == output: diff --git a/scripts/build_evt.py b/scripts/build_evt.py index e5febca..bba8084 100644 --- a/scripts/build_evt.py +++ b/scripts/build_evt.py @@ -9,23 +9,26 @@ import numpy as np from legendmeta import LegendMetadata from legendmeta.catalog import Props -from lgdo.types import Table -from pygama.evt.build_evt import build_evt +from lgdo.types import 
Array +from pygama.evt import build_evt sto = lh5.LH5Store() -def replace_evt_with_key(dic, new_key): - for key, d in dic.items(): - if isinstance(d, dict): - dic[key] = replace_evt_with_key(d, new_key) - elif isinstance(d, list): - dic[key] = [item.replace("evt", new_key) for item in d] - elif isinstance(d, str): - dic[key] = d.replace("evt", new_key) - else: - pass - return dic +def find_matching_values_with_delay(arr1, arr2, jit_delay): + matching_values = [] + + # Create an array with all possible delay values + delays = np.arange(0, int(1e9 * jit_delay)) * jit_delay + + for delay in delays: + arr2_delayed = arr2 + delay + + # Find matching values and indices + mask = np.isin(arr1, arr2_delayed, assume_unique=True) + matching_values.extend(arr1[mask]) + + return np.unique(matching_values) argparser = argparse.ArgumentParser() @@ -45,8 +48,13 @@ def replace_evt_with_key(dic, new_key): argparser.add_argument("--output", help="output file", type=str) args = argparser.parse_args() -pathlib.Path(os.path.dirname(args.log)).mkdir(parents=True, exist_ok=True) -logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") +if args.log is not None: + pathlib.Path(os.path.dirname(args.log)).mkdir(parents=True, exist_ok=True) + logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") +else: + logging.basicConfig(level=logging.DEBUG) + +logging.getLogger("legendmeta").setLevel(logging.INFO) logging.getLogger("numba").setLevel(logging.INFO) logging.getLogger("parse").setLevel(logging.INFO) logging.getLogger("lgdo").setLevel(logging.INFO) @@ -58,9 +66,10 @@ def replace_evt_with_key(dic, new_key): # load in config configs = LegendMetadata(path=args.configs) if args.tier == "evt" or args.tier == "pet": - evt_config_file = configs.on(args.timestamp, system=args.datatype)["snakemake_rules"][ - "tier_evt" - ]["inputs"]["evt_config"] + config_dict = configs.on(args.timestamp, system=args.datatype)["snakemake_rules"]["tier_evt"][ + "inputs" + ] + evt_config_file = config_dict["evt_config"] else: msg = "unknown tier" raise ValueError(msg) @@ -68,31 +77,48 @@ def replace_evt_with_key(dic, new_key): meta = LegendMetadata(path=args.metadata) chmap = meta.channelmap(args.timestamp) -if isinstance(evt_config_file, dict): - evt_config = {} - for key, _evt_config in evt_config_file.items(): - if _evt_config is not None: - _evt_config = Props.read_from(_evt_config) - # block for snakemake to fill in channel lists - for field, dic in _evt_config["channels"].items(): - if isinstance(dic, dict): - chans = chmap.map("system", unique=False)[dic["system"]] - if "selectors" in dic: - try: - for k, val in dic["selectors"].items(): - chans = chans.map(k, unique=False)[val] - except KeyError: - chans = None - if chans is not None: - chans = [f"ch{chan}" for chan in list(chans.map("daq.rawid"))] - else: - chans = [] - _evt_config["channels"][field] = chans - evt_config[key] = replace_evt_with_key(_evt_config, f"evt/{key}") -else: - evt_config = {"all": Props.read_from(evt_config_file)} +evt_config = Props.read_from(evt_config_file) + +# block for snakemake to fill in channel lists +for field, dic in evt_config["channels"].items(): + if isinstance(dic, dict): + chans = chmap.map("system", unique=False)[dic["system"]] + if "selectors" in dic: + try: + for k, val in dic["selectors"].items(): + chans = chans.map(k, unique=False)[val] + except KeyError: + chans = None + if chans is not None: + chans = [f"ch{chan}" for chan in list(chans.map("daq.rawid"))] + else: + chans = [] + 
evt_config["channels"][field] = chans + +log.debug(json.dumps(evt_config["channels"], indent=2)) + +t_start = time.time() +pathlib.Path(os.path.dirname(args.output)).mkdir(parents=True, exist_ok=True) + +rng = np.random.default_rng() +rand_num = f"{rng.integers(0,99999):05d}" +temp_output = f"{args.output}.{rand_num}" + +table = build_evt( + { + "tcm": (args.tcm_file, "hardware_tcm_1", "ch{}"), + "dsp": (args.dsp_file, "dsp", "ch{}"), + "hit": (args.hit_file, "hit", "ch{}"), + "evt": (None, "evt"), + }, + evt_config, +) + +if "muon_config" in config_dict and config_dict["muon_config"] is not None: + muon_config = Props.read_from(config_dict["muon_config"]["evt_config"]) + field_config = Props.read_from(config_dict["muon_config"]["field_config"]) # block for snakemake to fill in channel lists - for field, dic in evt_config["channels"].items(): + for field, dic in muon_config["channels"].items(): if isinstance(dic, dict): chans = chmap.map("system", unique=False)[dic["system"]] if "selectors" in dic: @@ -105,34 +131,38 @@ def replace_evt_with_key(dic, new_key): chans = [f"ch{chan}" for chan in list(chans.map("daq.rawid"))] else: chans = [] - evt_config["channels"][field] = chans - -log.debug(json.dumps(evt_config, indent=2)) - -t_start = time.time() -pathlib.Path(os.path.dirname(args.output)).mkdir(parents=True, exist_ok=True) - -rng = np.random.default_rng() -rand_num = f"{rng.integers(0,99999):05d}" -temp_output = f"{args.output}.{rand_num}" - -tables = {} -for key, config in evt_config.items(): - tables[key] = build_evt( - f_tcm=args.tcm_file, - f_dsp=args.dsp_file, - f_hit=args.hit_file, - f_evt=None, - evt_config=config, - evt_group=f"evt/{key}" if key != "all" else "evt", - tcm_group="hardware_tcm_1", - dsp_group="dsp", - hit_group="hit", - tcm_id_table_pattern="ch{}", + muon_config["channels"][field] = chans + + trigger_timestamp = table[field_config["ged_timestamp"]["table"]][ + field_config["ged_timestamp"]["field"] + ].nda + if "hardware_tcm_2" in lh5.ls(args.tcm_file): + muon_table = build_evt( + { + "tcm": (args.tcm_file, "hardware_tcm_2", "ch{}"), + "dsp": (args.dsp_file, "dsp", "ch{}"), + "hit": (args.hit_file, "hit", "ch{}"), + "evt": (None, "evt"), + }, + muon_config, + ) + + muon_timestamp = muon_table[field_config["muon_timestamp"]["field"]].nda + muon_tbl_flag = muon_table[field_config["muon_flag"]["field"]].nda + if len(muon_timestamp[muon_tbl_flag]) > 0: + is_muon_veto_triggered = find_matching_values_with_delay( + trigger_timestamp, muon_timestamp[muon_tbl_flag], field_config["jitter"] + ) + muon_flag = np.isin(trigger_timestamp, is_muon_veto_triggered) + else: + muon_flag = np.zeros(len(trigger_timestamp), dtype=bool) + else: + muon_flag = np.zeros(len(trigger_timestamp), dtype=bool) + table[field_config["output_field"]["table"]].add_column( + field_config["output_field"]["field"], Array(muon_flag) ) -tbl = Table(col_dict=tables) -sto.write(obj=tbl, name="evt", lh5_file=temp_output, wo_mode="a") +sto.write(obj=table, name="evt", lh5_file=temp_output, wo_mode="a") os.rename(temp_output, args.output) t_elap = time.time() - t_start diff --git a/scripts/check_blinding.py b/scripts/check_blinding.py index 550f5a8..4829608 100644 --- a/scripts/check_blinding.py +++ b/scripts/check_blinding.py @@ -71,7 +71,7 @@ # bin with 1 keV bins and get maxs hist, bins, var = get_hist(daqenergy_cal, np.arange(0, 3000, 1)) -maxs = get_i_local_maxima(hist, delta=5) +maxs = get_i_local_maxima(hist, delta=25) log.info(f"peaks found at : {maxs}") # plot the energy spectrum to check calibration 
diff --git a/scripts/complete_run.py b/scripts/complete_run.py index b266d50..722b244 100644 --- a/scripts/complete_run.py +++ b/scripts/complete_run.py @@ -148,7 +148,9 @@ def build_valid_keys(input_files, output_dir): with open(out_file, "w") as w: w.write(out_string) - os.system(f"rm {input_files}") + for input_file in input_files: + if os.path.isfile(input_file): + os.remove(input_file) def build_file_dbs(input_files, output_dir): @@ -165,6 +167,13 @@ def build_file_dbs(input_files, output_dir): setup = snakemake.params.setup basedir = snakemake.params.basedir +check_log_files( + snakemake.params.log_path, + snakemake.output.summary_log, + snakemake.output.gen_output, + warning_file=snakemake.output.warning_log, +) + if os.getenv("PRODENV") in snakemake.params.filedb_path: file_db_config = { "data_dir": "$PRODENV", @@ -188,7 +197,7 @@ def build_file_dbs(input_files, output_dir): ut.tier_hit_path(setup), "" ), "pht": pat.get_pattern_tier(setup, "pht", check_in_cycle=False).replace( - ut.tier_hit_path(setup), "" + ut.tier_pht_path(setup), "" ), "evt": pat.get_pattern_tier(setup, "evt", check_in_cycle=False).replace( ut.tier_evt_path(setup), "" @@ -233,7 +242,7 @@ def build_file_dbs(input_files, output_dir): ut.tier_hit_path(setup), "" ), "pht": pat.get_pattern_tier(setup, "pht", check_in_cycle=False).replace( - ut.tier_hit_path(setup), "" + ut.tier_pht_path(setup), "" ), "evt": pat.get_pattern_tier(setup, "evt", check_in_cycle=False).replace( ut.tier_evt_path(setup), "" @@ -256,20 +265,13 @@ def build_file_dbs(input_files, output_dir): }, } -check_log_files( - snakemake.params.log_path, - snakemake.output.summary_log, - snakemake.output.gen_output, - warning_file=snakemake.output.warning_log, -) - if snakemake.wildcards.tier != "daq": os.makedirs(snakemake.params.filedb_path, exist_ok=True) with open(os.path.join(snakemake.params.filedb_path, "file_db_config.json"), "w") as w: json.dump(file_db_config, w, indent=2) build_file_dbs(snakemake.params.tmp_par_path, snakemake.params.filedb_path) - os.system(f"rm {os.path.join(snakemake.params.filedb_path, 'file_db_config.json')}") + os.remove(os.path.join(snakemake.params.filedb_path, "file_db_config.json")) build_valid_keys(snakemake.params.tmp_par_path, snakemake.params.valid_keys_path) diff --git a/scripts/create_filelist.py b/scripts/create_filelist.py index 8900343..a40b77c 100644 --- a/scripts/create_filelist.py +++ b/scripts/create_filelist.py @@ -57,6 +57,10 @@ other_filenames = [] if tier == "blind": fn_pattern = get_pattern_tier(setup, "raw", check_in_cycle=False) +elif tier == "skm" or tier == "pet_concat": + fn_pattern = get_pattern_tier(setup, "pet", check_in_cycle=False) +elif tier == "evt_concat": + fn_pattern = get_pattern_tier(setup, "evt", check_in_cycle=False) else: fn_pattern = get_pattern_tier(setup, tier, check_in_cycle=False) @@ -70,7 +74,7 @@ else: if tier == "blind" and _key.datatype == "phy": filename = FileKey.get_path_from_filekey(_key, get_pattern_tier_raw_blind(setup)) - elif tier == "skm" and _key.datatype != "phy": + elif tier == "skm": # and _key.datatype != "phy" filename = FileKey.get_path_from_filekey( _key, get_pattern_tier(setup, "pet", check_in_cycle=False) ) @@ -101,17 +105,16 @@ phy_filenames = sorted(phy_filenames) other_filenames = sorted(other_filenames) -if tier == "skm": +if tier == "skm" or tier == "pet_concat" or tier == "evt_concat": sorted_phy_filenames = run_grouper(phy_filenames) phy_filenames = [] for run in sorted_phy_filenames: - run_files = sorted( - run, - key=lambda filename: 
FileKey.get_filekey_from_pattern( - filename, fn_pattern - ).get_unix_timestamp(), - ) - phy_filenames.append(run_files[0]) + key = FileKey.get_filekey_from_pattern(run[0], fn_pattern) + out_key = FileKey.get_path_from_filekey( + key, get_pattern_tier(setup, tier, check_in_cycle=False) + )[0] + + phy_filenames.append(out_key) filenames = phy_filenames + other_filenames diff --git a/scripts/merge_channels.py b/scripts/merge_channels.py index 1d43e6f..b169d29 100644 --- a/scripts/merge_channels.py +++ b/scripts/merge_channels.py @@ -5,84 +5,135 @@ import pickle as pkl import shelve +import lgdo.lh5 as lh5 +import numpy as np +from legendmeta.catalog import Props +from util.FileKey import ChannelProcKey + + +def replace_path(d, old_path, new_path): + if isinstance(d, dict): + for k, v in d.items(): + d[k] = replace_path(v, old_path, new_path) + elif isinstance(d, list): + for i in range(len(d)): + d[i] = replace_path(d[i], old_path, new_path) + elif isinstance(d, str) and old_path in d: + d = d.replace(old_path, new_path) + d = d.replace(new_path, f"$_/{os.path.basename(new_path)}") + return d + + argparser = argparse.ArgumentParser() -argparser.add_argument("--input", help="input file", nargs="*", type=str) -argparser.add_argument("--output", help="output file", nargs="*", type=str) +argparser.add_argument("--input", help="input file", nargs="*", type=str, required=True) +argparser.add_argument("--output", help="output file", type=str, required=True) +argparser.add_argument( + "--in_db", + help="in db file (used for when lh5 files referred to in db)", + type=str, + required=False, +) +argparser.add_argument( + "--out_db", + help="lh5 file (used for when lh5 files referred to in db)", + type=str, + required=False, +) args = argparser.parse_args() +# change to only have 1 output file for multiple inputs +# don't care about processing step, check if extension matches + channel_files = args.input -for _i, out_file in enumerate(args.output): - file_extension = pathlib.Path(out_file).suffix - processing_step = os.path.splitext(out_file)[0].split("-")[-1] - if file_extension == ".json": - out_dict = {} - for channel in channel_files: - if os.path.splitext(channel)[0].split("-")[-1] == processing_step: - with open(channel) as r: - channel_dict = json.load(r) - ( - experiment, - period, - run, - datatype, - timestamp, - channel_name, - name, - ) = os.path.basename(channel).split("-") - out_dict[channel_name] = channel_dict - else: - pass - - pathlib.Path(os.path.dirname(out_file)).mkdir(parents=True, exist_ok=True) - with open(out_file, "w") as w: - json.dump(out_dict, w, indent=4) - - elif file_extension == ".pkl": - out_dict = {} + +file_extension = pathlib.Path(args.output).suffix + +if file_extension == ".dat" or file_extension == ".dir": + out_file = os.path.splitext(args.output)[0] +else: + out_file = args.output + +rng = np.random.default_rng() +rand_num = f"{rng.integers(0,99999):05d}" +temp_output = f"{out_file}.{rand_num}" + +pathlib.Path(os.path.dirname(args.output)).mkdir(parents=True, exist_ok=True) + + +if file_extension == ".json": + out_dict = {} + for channel in channel_files: + if pathlib.Path(channel).suffix == file_extension: + channel_dict = Props.read_from(channel) + + fkey = ChannelProcKey.get_filekey_from_pattern(os.path.basename(channel)) + channel_name = fkey.channel + out_dict[channel_name] = channel_dict + else: + msg = "Output file extension does not match input file extension" + raise RuntimeError(msg) + + with open(temp_output, "w") as w: + json.dump(out_dict, w, 
indent=4) + + os.rename(temp_output, out_file) + +elif file_extension == ".pkl": + out_dict = {} + for channel in channel_files: + with open(channel, "rb") as r: + channel_dict = pkl.load(r) + fkey = ChannelProcKey.get_filekey_from_pattern(os.path.basename(channel)) + channel_name = fkey.channel + out_dict[channel_name] = channel_dict + + with open(temp_output, "wb") as w: + pkl.dump(out_dict, w, protocol=pkl.HIGHEST_PROTOCOL) + + os.rename(temp_output, out_file) + +elif file_extension == ".dat" or file_extension == ".dir": + common_dict = {} + with shelve.open(out_file, "c", protocol=pkl.HIGHEST_PROTOCOL) as shelf: for channel in channel_files: - if os.path.splitext(channel)[0].split("-")[-1] == processing_step: - with open(channel, "rb") as r: - channel_dict = pkl.load(r) - ( - experiment, - period, - run, - datatype, - timestamp, - channel_name, - name, - ) = os.path.basename(channel).split("-") - out_dict[channel_name] = channel_dict - else: - pass - pathlib.Path(os.path.dirname(out_file)).mkdir(parents=True, exist_ok=True) - with open(out_file, "wb") as w: - pkl.dump(out_dict, w, protocol=pkl.HIGHEST_PROTOCOL) - - elif file_extension == ".dat" or file_extension == ".dir": - _out_file = os.path.splitext(out_file)[0] - pathlib.Path(os.path.dirname(_out_file)).mkdir(parents=True, exist_ok=True) - common_dict = {} - with shelve.open(_out_file, "c", protocol=pkl.HIGHEST_PROTOCOL) as shelf: - for channel in channel_files: - if os.path.splitext(channel)[0].split("-")[-1] == processing_step: - with open(channel, "rb") as r: - channel_dict = pkl.load(r) - ( - experiment, - period, - run, - datatype, - timestamp, - channel_name, - name, - ) = os.path.basename(channel).split("-") - if isinstance(channel_dict, dict) and "common" in list(channel_dict): - chan_common_dict = channel_dict.pop("common") - common_dict[channel_name] = chan_common_dict - shelf[channel_name] = channel_dict - else: - pass - if len(common_dict) > 0: - shelf["common"] = common_dict + with open(channel, "rb") as r: + channel_dict = pkl.load(r) + fkey = ChannelProcKey.get_filekey_from_pattern(os.path.basename(channel)) + channel_name = fkey.channel + if isinstance(channel_dict, dict) and "common" in list(channel_dict): + chan_common_dict = channel_dict.pop("common") + common_dict[channel_name] = chan_common_dict + shelf[channel_name] = channel_dict + if len(common_dict) > 0: + shelf["common"] = common_dict + + +elif file_extension == ".lh5": + sto = lh5.LH5Store() + + if args.in_db: + db_dict = Props.read_from(args.in_db) + for channel in channel_files: + if pathlib.Path(channel).suffix == file_extension: + fkey = ChannelProcKey.get_filekey_from_pattern(os.path.basename(channel)) + channel_name = fkey.channel + + tb_in = sto.read(f"{channel_name}", channel)[0] + + sto.write( + tb_in, + name=channel_name, + lh5_file=temp_output, + wo_mode="a", + ) + if args.in_db: + db_dict[channel_name] = replace_path(db_dict[channel_name], channel, args.output) + else: + msg = "Output file extension does not match input file extension" + raise RuntimeError(msg) + if args.out_db: + with open(args.out_db, "w") as w: + json.dump(db_dict, w, indent=4) + + os.rename(temp_output, out_file) diff --git a/scripts/par_psp.py b/scripts/par_psp.py new file mode 100644 index 0000000..03bfeaf --- /dev/null +++ b/scripts/par_psp.py @@ -0,0 +1,148 @@ +import argparse +import json +import os +import pickle as pkl +from datetime import datetime + +import matplotlib as mpl +import matplotlib.dates as mdates +import matplotlib.pyplot as plt +import numpy as np 
+from legendmeta import LegendMetadata +from legendmeta.catalog import Props +from util.FileKey import ChannelProcKey + +mpl.use("Agg") + + +argparser = argparse.ArgumentParser() +argparser.add_argument("--input", help="input files", nargs="*", type=str, required=True) +argparser.add_argument("--output", help="output file", nargs="*", type=str, required=True) +argparser.add_argument("--in_plots", help="input plot files", nargs="*", type=str, required=False) +argparser.add_argument( + "--out_plots", help="output plot files", nargs="*", type=str, required=False +) +argparser.add_argument("--in_obj", help="input object files", nargs="*", type=str, required=False) +argparser.add_argument( + "--out_obj", help="output object files", nargs="*", type=str, required=False +) + +argparser.add_argument("--log", help="log_file", type=str) +argparser.add_argument("--configs", help="configs", type=str, required=True) + +argparser.add_argument("--datatype", help="Datatype", type=str, required=True) +argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True) +argparser.add_argument("--channel", help="Channel", type=str, required=True) +args = argparser.parse_args() + +conf = LegendMetadata(path=args.configs) +configs = conf.on(args.timestamp, system=args.datatype) +merge_config = Props.read_from( + configs["snakemake_rules"]["pars_psp"]["inputs"]["psp_config"][args.channel] +) + +ave_fields = merge_config["average_fields"] + +# partitions could be different for different channels - do separately for each channel +in_dicts = {} +for file in args.input: + tstamp = ChannelProcKey.get_filekey_from_pattern(os.path.basename(file)).timestamp + in_dicts[tstamp] = Props.read_from(file) + +plot_dict = {} +for field in ave_fields: + keys = field.split(".") + vals = [] + for tstamp in in_dicts: + val = in_dicts[tstamp] + for key in keys: + val = val[key] + vals.append(val) + if "dsp" in in_dicts[tstamp]: + tmp_dict = in_dicts[tstamp]["dsp"] + else: + tmp_dict = {} + in_dicts[tstamp]["dsp"] = tmp_dict + for i, key in enumerate(keys): + if i == len(keys) - 1: + tmp_dict[key] = val + else: + if key in tmp_dict: + tmp_dict = tmp_dict[key] + else: + tmp_dict[key] = {} + tmp_dict = tmp_dict[key] + if isinstance(vals[0], str): + if "*" in vals[0]: + unit = vals[0].split("*")[1] + rounding = len(val.split("*")[0].split(".")[-1]) if "." 
in vals[0] else 16 + vals = np.array([float(val.split("*")[0]) for val in vals]) + else: + unit = None + rounding = 16 + else: + vals = np.array(vals) + unit = None + rounding = 16 + + mean_val = np.nan if len(vals[~np.isnan(vals)]) == 0 else np.nanmedian(vals) + mean = f"{round(mean_val, rounding)}*{unit}" if unit is not None else mean_val + + for tstamp in in_dicts: + val = in_dicts[tstamp] + for i, key in enumerate(keys): + if i == len(keys) - 1: + val[key] = mean + else: + val = val[key] + + fig = plt.figure() + plt.scatter([datetime.strptime(tstamp, "%Y%m%dT%H%M%SZ") for tstamp in in_dicts], vals) + plt.axhline(y=mean_val, color="r", linestyle="-") + plt.xlabel("time") + if unit is not None: + plt.ylabel(f"value {unit}") + else: + plt.ylabel("value") + plt.gca().xaxis.set_major_formatter(mdates.DateFormatter("%d/%m/%y")) + plt.gcf().autofmt_xdate() + plt.title(f"{field}") + plot_dict[field] = fig + plt.close() + +for file in args.output: + tstamp = ChannelProcKey.get_filekey_from_pattern(os.path.basename(file)).timestamp + with open(file, "w") as f: + json.dump(in_dicts[tstamp], f, indent=2) + + +if args.out_plots: + for file in args.out_plots: + tstamp = ChannelProcKey.get_filekey_from_pattern(os.path.basename(file)).timestamp + if args.in_plots: + for infile in args.in_plots: + if tstamp in infile: + with open(infile, "rb") as f: + old_plot_dict = pkl.load(f) + break + old_plot_dict.update({"psp": plot_dict}) + new_plot_dict = old_plot_dict + else: + new_plot_dict = {"psp": plot_dict} + with open(file, "wb") as f: + pkl.dump(new_plot_dict, f, protocol=pkl.HIGHEST_PROTOCOL) + +if args.out_obj: + for file in args.out_obj: + tstamp = ChannelProcKey.get_filekey_from_pattern(os.path.basename(file)).timestamp + if args.in_obj: + for infile in args.in_obj: + if tstamp in infile: + with open(infile, "rb") as f: + old_obj_dict = pkl.load(f) + break + new_obj_dict = old_obj_dict + else: + new_obj_dict = {} + with open(file, "wb") as f: + pkl.dump(new_obj_dict, f, protocol=pkl.HIGHEST_PROTOCOL) diff --git a/scripts/pars_dsp_build_svm.py b/scripts/pars_dsp_build_svm.py new file mode 100644 index 0000000..6a44fec --- /dev/null +++ b/scripts/pars_dsp_build_svm.py @@ -0,0 +1,59 @@ +import argparse +import json +import logging +import os +import pickle as pkl + +os.environ["LGDO_CACHE"] = "false" +os.environ["LGDO_BOUNDSCHECK"] = "false" +os.environ["DSPEED_CACHE"] = "false" +os.environ["DSPEED_BOUNDSCHECK"] = "false" + +import lgdo.lh5 as lh5 +from sklearn.svm import SVC + +argparser = argparse.ArgumentParser() +argparser.add_argument("--log", help="log file", type=str) +argparser.add_argument("--output_file", help="output SVM file", type=str, required=True) +argparser.add_argument("--train_data", help="input data file", type=str, required=True) +argparser.add_argument("--train_hyperpars", help="input hyperparameter file", required=True) +args = argparser.parse_args() + +logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") +logging.getLogger("parse").setLevel(logging.INFO) +logging.getLogger("lgdo").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) + +sto = lh5.LH5Store() +log = logging.getLogger(__name__) + +# Load files +tb, _ = sto.read("ml_train/dsp", args.train_data) +log.debug("loaded data") + +with open(args.train_hyperpars) as hyperpars_file: + hyperpars = json.load(hyperpars_file) + +# Define training inputs +dwts_norm = tb["dwt_norm"].nda +labels = tb["dc_label"].nda + + +log.debug("training model") +# Initialize and train SVM +svm = SVC( + 
random_state=int(hyperpars["random_state"]), + kernel=hyperpars["kernel"], + decision_function_shape=hyperpars["decision_function_shape"], + class_weight=hyperpars["class_weight"], + C=float(hyperpars["C"]), + gamma=float(hyperpars["gamma"]), +) + +svm.fit(dwts_norm, labels) + +log.debug("trained model") + +# Save trained model with pickle +with open(args.output_file, "wb") as svm_file: + pkl.dump(svm, svm_file, protocol=pkl.HIGHEST_PROTOCOL) diff --git a/scripts/pars_dsp_dplms.py b/scripts/pars_dsp_dplms.py new file mode 100644 index 0000000..bcf1ac0 --- /dev/null +++ b/scripts/pars_dsp_dplms.py @@ -0,0 +1,149 @@ +import argparse +import json +import logging +import os +import pathlib +import pickle as pkl +import time + +os.environ["LGDO_CACHE"] = "false" +os.environ["LGDO_BOUNDSCHECK"] = "false" +os.environ["DSPEED_CACHE"] = "false" +os.environ["DSPEED_BOUNDSCHECK"] = "false" +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + +import lgdo.lh5 as lh5 +import numpy as np +from legendmeta import LegendMetadata +from legendmeta.catalog import Props +from lgdo import Array, Table +from pygama.pargen.dplms_ge_dict import dplms_ge_dict + +argparser = argparse.ArgumentParser() +argparser.add_argument("--fft_raw_filelist", help="fft_raw_filelist", type=str) +argparser.add_argument("--peak_file", help="tcm_filelist", type=str, required=True) +argparser.add_argument("--inplots", help="in_plot_path", type=str) + +argparser.add_argument("--log", help="log_file", type=str) +argparser.add_argument("--database", help="database", type=str, required=True) +argparser.add_argument("--configs", help="configs", type=str, required=True) + +argparser.add_argument("--datatype", help="Datatype", type=str, required=True) +argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True) +argparser.add_argument("--channel", help="Channel", type=str, required=True) + +argparser.add_argument("--dsp_pars", help="dsp_pars", type=str, required=True) +argparser.add_argument("--lh5_path", help="lh5_path", type=str, required=True) +argparser.add_argument("--plot_path", help="plot_path", type=str) + +args = argparser.parse_args() + +logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") +logging.getLogger("numba").setLevel(logging.INFO) +logging.getLogger("parse").setLevel(logging.INFO) +logging.getLogger("lgdo").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("dspeed.processing_chain").setLevel(logging.INFO) +logging.getLogger("legendmeta").setLevel(logging.INFO) + +log = logging.getLogger(__name__) +sto = lh5.LH5Store() + +conf = LegendMetadata(path=args.configs) +configs = conf.on(args.timestamp, system=args.datatype) +dsp_config = configs["snakemake_rules"]["pars_dsp_dplms"]["inputs"]["proc_chain"][args.channel] + +dplms_json = configs["snakemake_rules"]["pars_dsp_dplms"]["inputs"]["dplms_pars"][args.channel] +dplms_dict = Props.read_from(dplms_json) + +db_dict = Props.read_from(args.database) + +if dplms_dict["run_dplms"] is True: + with open(args.fft_raw_filelist) as f: + fft_files = sorted(f.read().splitlines()) + + t0 = time.time() + log.info("\nLoad fft data") + energies = sto.read(f"{args.channel}/raw/daqenergy", fft_files)[0] + idxs = np.where(energies.nda == 0)[0] + raw_fft = sto.read( + f"{args.channel}/raw", fft_files, n_rows=dplms_dict["n_baselines"], idx=idxs + )[0] + t1 = time.time() + log.info(f"Time to load fft data {(t1-t0):.2f} s, total 
events {len(raw_fft)}") + + log.info("\nRunning event selection") + peaks_kev = np.array(dplms_dict["peaks_kev"]) + kev_widths = [tuple(kev_width) for kev_width in dplms_dict["kev_widths"]] + + peaks_rounded = [int(peak) for peak in peaks_kev] + peaks = sto.read(f"{args.channel}/raw", args.peak_file, field_mask=["peak"])[0]["peak"].nda + ids = np.in1d(peaks, peaks_rounded) + peaks = peaks[ids] + idx_list = [np.where(peaks == peak)[0] for peak in peaks_rounded] + + raw_cal = sto.read(f"{args.channel}/raw", args.peak_file, idx=ids)[0] + log.info(f"Time to run event selection {(time.time()-t1):.2f} s, total events {len(raw_cal)}") + + if isinstance(dsp_config, (str, list)): + dsp_config = Props.read_from(dsp_config) + + if args.plot_path: + out_dict, plot_dict = dplms_ge_dict( + raw_fft, + raw_cal, + dsp_config, + db_dict, + dplms_dict, + display=1, + ) + if args.inplots: + with open(args.inplots, "rb") as r: + inplot_dict = pkl.load(r) + inplot_dict.update({"dplms": plot_dict}) + + else: + out_dict = dplms_ge_dict( + raw_fft, + raw_cal, + dsp_config, + db_dict, + dplms_dict, + ) + + coeffs = out_dict["dplms"].pop("coefficients") + dplms_pars = Table(col_dict={"coefficients": Array(coeffs)}) + out_dict["dplms"][ + "coefficients" + ] = f"loadlh5('{args.lh5_path}', '{args.channel}/dplms/coefficients')" + + log.info(f"DPLMS creation finished in {(time.time()-t0)/60} minutes") +else: + out_dict = {} + dplms_pars = Table(col_dict={"coefficients": Array([])}) + if args.inplots: + with open(args.inplots, "rb") as r: + inplot_dict = pkl.load(r) + else: + inplot_dict = {} + +db_dict.update(out_dict) + +pathlib.Path(os.path.dirname(args.lh5_path)).mkdir(parents=True, exist_ok=True) +sto.write( + Table(col_dict={"dplms": dplms_pars}), + name=args.channel, + lh5_file=args.lh5_path, + wo_mode="overwrite", +) + +pathlib.Path(os.path.dirname(args.dsp_pars)).mkdir(parents=True, exist_ok=True) +with open(args.dsp_pars, "w") as w: + json.dump(db_dict, w, indent=2) + +if args.plot_path: + pathlib.Path(os.path.dirname(args.plot_path)).mkdir(parents=True, exist_ok=True) + with open(args.plot_path, "wb") as f: + pkl.dump(inplot_dict, f, protocol=pkl.HIGHEST_PROTOCOL) diff --git a/scripts/pars_dsp_eopt.py b/scripts/pars_dsp_eopt.py index 59fe8ec..86b5f7b 100644 --- a/scripts/pars_dsp_eopt.py +++ b/scripts/pars_dsp_eopt.py @@ -11,22 +11,31 @@ os.environ["LGDO_BOUNDSCHECK"] = "false" os.environ["DSPEED_CACHE"] = "false" os.environ["DSPEED_BOUNDSCHECK"] = "false" +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" import lgdo.lh5 as lh5 import numpy as np -import pygama.math.peak_fitting as pgf -import pygama.pargen.energy_optimisation as om +import pygama.pargen.energy_optimisation as om # noqa: F401 import sklearn.gaussian_process.kernels as ker +from dspeed.units import unit_registry as ureg from legendmeta import LegendMetadata from legendmeta.catalog import Props -from pygama.pargen.dsp_optimize import run_one_dsp -from pygama.pargen.utils import get_tcm_pulser_ids +from pygama.math.distributions import hpge_peak +from pygama.pargen.dsp_optimize import ( + BayesianOptimizer, + run_bayesian_optimisation, + run_one_dsp, +) warnings.filterwarnings(action="ignore", category=RuntimeWarning) +warnings.filterwarnings(action="ignore", category=np.RankWarning) + argparser = argparse.ArgumentParser() -argparser.add_argument("--raw_filelist", help="raw_filelist", type=str) -argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=True) + 
+argparser.add_argument("--peak_file", help="tcm_filelist", type=str, required=True) + argparser.add_argument("--decay_const", help="decay_const", type=str, required=True) argparser.add_argument("--configs", help="configs", type=str, required=True) argparser.add_argument("--inplots", help="in_plot_path", type=str) @@ -41,7 +50,6 @@ argparser.add_argument("--qbb_grid_path", help="qbb_grid_path", type=str) argparser.add_argument("--plot_path", help="plot_path", type=str) - argparser.add_argument("--plot_save_path", help="plot_save_path", type=str, required=False) args = argparser.parse_args() @@ -51,12 +59,12 @@ logging.getLogger("lgdo").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO) -logging.getLogger("pygama.dsp.processing_chain").setLevel(logging.INFO) +logging.getLogger("dspeed.processing_chain").setLevel(logging.INFO) +logging.getLogger("legendmeta").setLevel(logging.INFO) log = logging.getLogger(__name__) - - +sto = lh5.LH5Store() t0 = time.time() conf = LegendMetadata(path=args.configs) @@ -70,106 +78,51 @@ db_dict = Props.read_from(args.decay_const) if opt_dict.pop("run_eopt") is True: - with open(args.raw_filelist) as f: - files = f.read().splitlines() - - raw_files = sorted(files) - - # get pulser mask from tcm files - with open(args.tcm_filelist) as f: - tcm_files = f.read().splitlines() - tcm_files = sorted(np.unique(tcm_files)) - ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, opt_dict.pop("pulser_multiplicity_threshold") - ) - - peaks_keV = np.array(opt_dict["peaks"]) + peaks_kev = np.array(opt_dict["peaks"]) kev_widths = [tuple(kev_width) for kev_width in opt_dict["kev_widths"]] kwarg_dicts_cusp = [] kwarg_dicts_trap = [] kwarg_dicts_zac = [] - for peak in peaks_keV: - peak_idx = np.where(peaks_keV == peak)[0][0] + for peak in peaks_kev: + peak_idx = np.where(peaks_kev == peak)[0][0] kev_width = kev_widths[peak_idx] - if peak == 238.632: - kwarg_dicts_cusp.append( - { - "parameter": "cuspEmax", - "func": pgf.extended_gauss_step_pdf, - "gof_func": pgf.gauss_step_pdf, - "peak": peak, - "kev_width": kev_width, - } - ) - kwarg_dicts_zac.append( - { - "parameter": "zacEmax", - "func": pgf.extended_gauss_step_pdf, - "gof_func": pgf.gauss_step_pdf, - "peak": peak, - "kev_width": kev_width, - } - ) - kwarg_dicts_trap.append( - { - "parameter": "trapEmax", - "func": pgf.extended_gauss_step_pdf, - "gof_func": pgf.gauss_step_pdf, - "peak": peak, - "kev_width": kev_width, - } - ) - else: - kwarg_dicts_cusp.append( - { - "parameter": "cuspEmax", - "func": pgf.extended_radford_pdf, - "gof_func": pgf.radford_pdf, - "peak": peak, - "kev_width": kev_width, - } - ) - kwarg_dicts_zac.append( - { - "parameter": "zacEmax", - "func": pgf.extended_radford_pdf, - "gof_func": pgf.radford_pdf, - "peak": peak, - "kev_width": kev_width, - } - ) - kwarg_dicts_trap.append( - { - "parameter": "trapEmax", - "func": pgf.extended_radford_pdf, - "gof_func": pgf.radford_pdf, - "peak": peak, - "kev_width": kev_width, - } - ) - sto = lh5.LH5Store() - idx_events, idx_list = om.event_selection( - raw_files, - f"{args.channel}/raw", - dsp_config, - db_dict, - peaks_keV, - np.arange(0, len(peaks_keV), 1).tolist(), - kev_widths, - pulser_mask=mask, - cut_parameters=opt_dict["cut_parameters"], - n_events=opt_dict["n_events"], - threshold=opt_dict["threshold"], - wf_field=opt_dict["wf_field"], - ) - tb_data = sto.read( - f"{args.channel}/raw", - raw_files, - idx=idx_events, - n_rows=opt_dict["n_events"], - )[0] + 
kwarg_dicts_cusp.append( + { + "parameter": "cuspEmax", + "func": hpge_peak, + "peak": peak, + "kev_width": kev_width, + "bin_width": 5, + } + ) + kwarg_dicts_zac.append( + { + "parameter": "zacEmax", + "func": hpge_peak, + "peak": peak, + "kev_width": kev_width, + "bin_width": 5, + } + ) + kwarg_dicts_trap.append( + { + "parameter": "trapEmax", + "func": hpge_peak, + "peak": peak, + "kev_width": kev_width, + "bin_width": 5, + } + ) + + peaks_rounded = [int(peak) for peak in peaks_kev] + peaks = sto.read(f"{args.channel}/raw", args.peak_file, field_mask=["peak"])[0]["peak"].nda + ids = np.in1d(peaks, peaks_rounded) + peaks = peaks[ids] + idx_list = [np.where(peaks == peak)[0] for peak in peaks_rounded] + + tb_data = sto.read(f"{args.channel}/raw", args.peak_file, idx=ids)[0] t1 = time.time() log.info(f"Data Loaded in {(t1-t0)/60} minutes") @@ -182,6 +135,7 @@ init_data = run_one_dsp(tb_data, dsp_config, db_dict=db_dict, verbosity=0) full_dt = (init_data["tp_99"].nda - init_data["tp_0_est"].nda)[idx_list[-1]] flat_val = np.ceil(1.1 * np.nanpercentile(full_dt, 99) / 100) / 10 + if flat_val < 1.0: flat_val = 1.0 elif flat_val > 4: @@ -201,26 +155,27 @@ kwarg_dict = [ { "peak_dicts": kwarg_dicts_cusp, - "ctc_param": "QDrift", + "ctc_param": "dt_eff", "idx_list": idx_list, - "peaks_keV": peaks_keV, + "peaks_kev": peaks_kev, }, { "peak_dicts": kwarg_dicts_zac, - "ctc_param": "QDrift", + "ctc_param": "dt_eff", "idx_list": idx_list, - "peaks_keV": peaks_keV, + "peaks_kev": peaks_kev, }, { "peak_dicts": kwarg_dicts_trap, - "ctc_param": "QDrift", + "ctc_param": "dt_eff", "idx_list": idx_list, - "peaks_keV": peaks_keV, + "peaks_kev": peaks_kev, }, ] fom = eval(opt_dict["fom"]) - + out_field = opt_dict["fom_field"] + out_err_field = opt_dict["fom_err_field"] sample_x = np.array(opt_dict["initial_samples"]) results_cusp = [] @@ -246,18 +201,18 @@ res = fom(tb_out, kwarg_dict[0]) results_cusp.append(res) - sample_y_cusp.append(res["y_val"]) - err_y_cusp.append(res["y_err"]) + sample_y_cusp.append(res[out_field]) + err_y_cusp.append(res[out_err_field]) res = fom(tb_out, kwarg_dict[1]) results_zac.append(res) - sample_y_zac.append(res["y_val"]) - err_y_zac.append(res["y_err"]) + sample_y_zac.append(res[out_field]) + err_y_zac.append(res[out_err_field]) res = fom(tb_out, kwarg_dict[2]) results_trap.append(res) - sample_y_trap.append(res["y_val"]) - err_y_trap.append(res["y_err"]) + sample_y_trap.append(res[out_field]) + err_y_trap.append(res[out_err_field]) log.info(f"{i+1} Finished") @@ -295,23 +250,43 @@ + ker.WhiteKernel(noise_level=0.1, noise_level_bounds=(1e-5, 1e1)) ) - bopt_cusp = om.BayesianOptimizer( - acq_func=opt_dict["acq_func"], batch_size=opt_dict["batch_size"], kernel=kernel + lambda_param = 5 + sampling_rate = tb_data["waveform_presummed"]["dt"][0] + sampling_unit = ureg.Quantity(tb_data["waveform_presummed"]["dt"].attrs["units"]) + waveform_sampling = sampling_rate * sampling_unit + + bopt_cusp = BayesianOptimizer( + acq_func=opt_dict["acq_func"], + batch_size=opt_dict["batch_size"], + kernel=kernel, + sampling_rate=waveform_sampling, + fom_value=out_field, + fom_error=out_err_field, ) - bopt_cusp.lambda_param = 1 - bopt_cusp.add_dimension("cusp", "sigma", 1, 16, 2, "us") - - bopt_zac = om.BayesianOptimizer( - acq_func=opt_dict["acq_func"], batch_size=opt_dict["batch_size"], kernel=kernel + bopt_cusp.lambda_param = lambda_param + bopt_cusp.add_dimension("cusp", "sigma", 0.5, 16, True, "us") + + bopt_zac = BayesianOptimizer( + acq_func=opt_dict["acq_func"], + 
batch_size=opt_dict["batch_size"], + kernel=kernel, + sampling_rate=waveform_sampling, + fom_value=out_field, + fom_error=out_err_field, ) - bopt_zac.lambda_param = 1 - bopt_zac.add_dimension("zac", "sigma", 1, 16, 2, "us") - - bopt_trap = om.BayesianOptimizer( - acq_func=opt_dict["acq_func"], batch_size=opt_dict["batch_size"], kernel=kernel + bopt_zac.lambda_param = lambda_param + bopt_zac.add_dimension("zac", "sigma", 0.5, 16, True, "us") + + bopt_trap = BayesianOptimizer( + acq_func=opt_dict["acq_func"], + batch_size=opt_dict["batch_size"], + kernel=kernel, + sampling_rate=waveform_sampling, + fom_value=out_field, + fom_error=out_err_field, ) - bopt_trap.lambda_param = 1 - bopt_trap.add_dimension("etrap", "rise", 1, 12, 2, "us") + bopt_trap.lambda_param = lambda_param + bopt_trap.add_dimension("etrap", "rise", 1, 12, True, "us") bopt_cusp.add_initial_values(x_init=sample_x, y_init=sample_y_cusp, yerr_init=err_y_cusp) bopt_zac.add_initial_values(x_init=sample_x, y_init=sample_y_zac, yerr_init=err_y_zac) @@ -331,7 +306,7 @@ optimisers = [bopt_cusp, bopt_zac, bopt_trap] - out_param_dict, out_results_list = om.run_optimisation( + out_param_dict, out_results_list = run_bayesian_optimisation( tb_data, dsp_config, [fom], @@ -379,8 +354,10 @@ "expression": "trapEftp*(1+dt_eff*a)", "parameters": {"a": round(bopt_trap.optimal_results["alpha"], 9)}, } - - db_dict.update({"ctc_params": out_alpha_dict}) + if "ctc_params" in db_dict: + db_dict["ctc_params"].update(out_alpha_dict) + else: + db_dict.update({"ctc_params": out_alpha_dict}) pathlib.Path(os.path.dirname(args.qbb_grid_path)).mkdir(parents=True, exist_ok=True) with open(args.qbb_grid_path, "wb") as f: diff --git a/scripts/pars_dsp_event_selection.py b/scripts/pars_dsp_event_selection.py new file mode 100644 index 0000000..9100689 --- /dev/null +++ b/scripts/pars_dsp_event_selection.py @@ -0,0 +1,371 @@ +import argparse +import json +import logging +import os +import pathlib +import time +import warnings + +os.environ["LGDO_CACHE"] = "false" +os.environ["LGDO_BOUNDSCHECK"] = "false" +os.environ["DSPEED_CACHE"] = "false" +os.environ["DSPEED_BOUNDSCHECK"] = "false" +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + +from bisect import bisect_left + +import lgdo +import lgdo.lh5 as lh5 +import numpy as np +import pygama.math.histogram as pgh +import pygama.pargen.energy_cal as pgc +from legendmeta import LegendMetadata +from legendmeta.catalog import Props +from pygama.pargen.data_cleaning import generate_cuts, get_keys, get_tcm_pulser_ids +from pygama.pargen.dsp_optimize import run_one_dsp + +warnings.filterwarnings(action="ignore", category=RuntimeWarning) + + +def get_out_data( + raw_data, + dsp_data, + cut_dict, + e_lower_lim, + e_upper_lim, + ecal_pars, + raw_dict, + peak, + final_cut_field="is_valid_cal", + energy_param="trapTmax", +): + for outname, info in cut_dict.items(): + outcol = dsp_data.eval(info["expression"], info.get("parameters", None)) + dsp_data.add_column(outname, outcol) + + for outname, info in raw_dict.items(): + outcol = raw_data.eval(info["expression"], info.get("parameters", None)) + raw_data.add_column(outname, outcol) + + final_mask = ( + (dsp_data[energy_param].nda > e_lower_lim) + & (dsp_data[energy_param].nda < e_upper_lim) + & (dsp_data[final_cut_field].nda) + ) + + wavefrom_windowed = lgdo.WaveformTable( + t0=raw_data["waveform_windowed"]["t0"].nda[final_mask], + t0_units=raw_data["waveform_windowed"]["t0"].attrs["units"], + 
dt=raw_data["waveform_windowed"]["dt"].nda[final_mask], + dt_units=raw_data["waveform_windowed"]["dt"].attrs["units"], + values=raw_data["waveform_windowed"]["values"].nda[final_mask], + ) + wavefrom_presummed = lgdo.WaveformTable( + t0=raw_data["waveform_presummed"]["t0"].nda[final_mask], + t0_units=raw_data["waveform_presummed"]["t0"].attrs["units"], + dt=raw_data["waveform_presummed"]["dt"].nda[final_mask], + dt_units=raw_data["waveform_presummed"]["dt"].attrs["units"], + values=raw_data["waveform_presummed"]["values"].nda[final_mask], + ) + + out_tbl = lgdo.Table( + col_dict={ + "waveform_presummed": wavefrom_presummed, + "waveform_windowed": wavefrom_windowed, + "presum_rate": lgdo.Array(raw_data["presum_rate"].nda[final_mask]), + "timestamp": lgdo.Array(raw_data["timestamp"].nda[final_mask]), + "baseline": lgdo.Array(raw_data["baseline"].nda[final_mask]), + "daqenergy": lgdo.Array(raw_data["daqenergy"].nda[final_mask]), + "daqenergy_cal": lgdo.Array(raw_data["daqenergy_cal"].nda[final_mask]), + "trapTmax_cal": lgdo.Array(dsp_data["trapTmax"].nda[final_mask] * ecal_pars), + "peak": lgdo.Array(np.full(len(np.where(final_mask)[0]), int(peak))), + } + ) + return out_tbl, len(np.where(final_mask)[0]) + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument("--raw_filelist", help="raw_filelist", type=str) + argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=False) + argparser.add_argument("--pulser_file", help="pulser_file", type=str, required=False) + + argparser.add_argument("--decay_const", help="decay_const", type=str, required=True) + argparser.add_argument("--configs", help="configs", type=str, required=True) + argparser.add_argument("--raw_cal", help="raw_cal", type=str, nargs="*", required=True) + + argparser.add_argument("--log", help="log_file", type=str) + + argparser.add_argument("--datatype", help="Datatype", type=str, required=True) + argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True) + argparser.add_argument("--channel", help="Channel", type=str, required=True) + + argparser.add_argument("--peak_file", help="peak_file", type=str, required=True) + args = argparser.parse_args() + + logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") + logging.getLogger("numba").setLevel(logging.INFO) + logging.getLogger("parse").setLevel(logging.INFO) + logging.getLogger("lgdo").setLevel(logging.INFO) + logging.getLogger("h5py").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + logging.getLogger("legendmeta").setLevel(logging.INFO) + logging.getLogger("dspeed.processing_chain").setLevel(logging.INFO) + + log = logging.getLogger(__name__) + sto = lh5.LH5Store() + t0 = time.time() + + conf = LegendMetadata(path=args.configs) + configs = conf.on(args.timestamp, system=args.datatype) + dsp_config = configs["snakemake_rules"]["pars_dsp_peak_selection"]["inputs"][ + "processing_chain" + ][args.channel] + peak_json = configs["snakemake_rules"]["pars_dsp_peak_selection"]["inputs"]["peak_config"][ + args.channel + ] + + peak_dict = Props.read_from(peak_json) + db_dict = Props.read_from(args.decay_const) + + pathlib.Path(os.path.dirname(args.peak_file)).mkdir(parents=True, exist_ok=True) + if peak_dict.pop("run_selection") is True: + log.debug("Starting peak selection") + rng = np.random.default_rng() + rand_num = f"{rng.integers(0,99999):05d}" + temp_output = f"{args.peak_file}.{rand_num}" + + with open(args.raw_filelist) as f: + files = 
f.read().splitlines() + raw_files = sorted(files) + + if args.pulser_file: + with open(args.pulser_file) as f: + pulser_dict = json.load(f) + mask = np.array(pulser_dict["mask"]) + + elif args.tcm_filelist: + # get pulser mask from tcm files + with open(args.tcm_filelist) as f: + tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, peak_dict["pulser_multiplicity_threshold"] + ) + else: + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) + + raw_dict = Props.read_from(args.raw_cal)[args.channel]["pars"]["operations"] + + peaks_kev = peak_dict["peaks"] + kev_widths = peak_dict["kev_widths"] + cut_parameters = peak_dict["cut_parameters"] + n_events = peak_dict["n_events"] + final_cut_field = peak_dict["final_cut_field"] + energy_parameter = peak_dict.get("energy_parameter", "trapTmax") + + lh5_path = f"{args.channel}/raw" + + if not isinstance(kev_widths, list): + kev_widths = [kev_widths] + + if lh5_path[-1] != "/": + lh5_path += "/" + + raw_fields = [field.replace(lh5_path, "") for field in lh5.ls(raw_files[0], lh5_path)] + + tb = sto.read(lh5_path, raw_files, field_mask=["daqenergy", "t_sat_lo", "timestamp"])[0] + + discharges = tb["t_sat_lo"].nda > 0 + discharge_timestamps = np.where(tb["timestamp"].nda[discharges])[0] + is_recovering = np.full(len(tb), False, dtype=bool) + for tstamp in discharge_timestamps: + is_recovering = is_recovering | np.where( + (((tb["timestamp"].nda - tstamp) < 0.01) & ((tb["timestamp"].nda - tstamp) > 0)), + True, + False, + ) + + for outname, info in raw_dict.items(): + outcol = tb.eval(info["expression"], info.get("parameters", None)) + tb.add_column(outname, outcol) + + rough_energy = tb["daqenergy_cal"].nda + + masks = {} + for peak, kev_width in zip(peaks_kev, kev_widths): + e_mask = ( + (rough_energy > peak - 1.1 * kev_width[0]) + & (rough_energy < peak + 1.1 * kev_width[0]) + & (~mask) + ) + masks[peak] = np.where(e_mask & (~is_recovering))[0] + log.debug(f"{len(masks[peak])} events found in energy range for {peak}") + + input_data = sto.read(f"{lh5_path}", raw_files, n_rows=10000, idx=np.where(~mask)[0])[0] + + if isinstance(dsp_config, str): + dsp_config = Props.read_from(dsp_config) + + dsp_config["outputs"] = [ + *get_keys(dsp_config["outputs"], cut_parameters), + energy_parameter, + ] + + log.debug("Processing data") + tb_data = run_one_dsp(input_data, dsp_config, db_dict=db_dict) + + if cut_parameters is not None: + cut_dict = generate_cuts(tb_data, cut_parameters) + log.debug(f"Cuts are calculated: {json.dumps(cut_dict, indent=2)}") + else: + cut_dict = None + + pk_dicts = {} + for peak, kev_width in zip(peaks_kev, kev_widths): + pk_dicts[peak] = { + "idxs": (masks[peak],), + "n_rows_read": 0, + "obj_buf_start": 0, + "obj_buf": None, + "kev_width": kev_width, + } + + for file in raw_files: + log.debug(os.path.basename(file)) + for peak, peak_dict in pk_dicts.items(): + if peak_dict["idxs"] is not None: + # idx is a long continuous array + n_rows_i = sto.read_n_rows(lh5_path, file) + # find the length of the subset of idx that contains indices + # that are less than n_rows_i + n_rows_to_read_i = bisect_left(peak_dict["idxs"][0], n_rows_i) + # now split idx into idx_i and the remainder + idx_i = (peak_dict["idxs"][0][:n_rows_to_read_i],) + peak_dict["idxs"] = (peak_dict["idxs"][0][n_rows_to_read_i:] - n_rows_i,) + if len(idx_i[0]) > 0: + peak_dict["obj_buf"], n_rows_read_i = sto.read( + lh5_path, + file, + start_row=0, + idx=idx_i, + 
obj_buf=peak_dict["obj_buf"], + obj_buf_start=peak_dict["obj_buf_start"], + ) + + peak_dict["n_rows_read"] += n_rows_read_i + log.debug(f'{peak}: {peak_dict["n_rows_read"]}') + peak_dict["obj_buf_start"] += n_rows_read_i + if peak_dict["n_rows_read"] >= 10000 or file == raw_files[-1]: + if "e_lower_lim" not in peak_dict: + tb_out = run_one_dsp(peak_dict["obj_buf"], dsp_config, db_dict=db_dict) + energy = tb_out[energy_parameter].nda + + init_bin_width = ( + 2 + * (np.nanpercentile(energy, 75) - np.nanpercentile(energy, 25)) + * len(energy) ** (-1 / 3) + ) + + if init_bin_width > 2: + init_bin_width = 2 + + hist, bins, var = pgh.get_hist( + energy, + range=( + np.floor(np.nanpercentile(energy, 1)), + np.ceil(np.nanpercentile(energy, 99)), + ), + dx=init_bin_width, + ) + peak_loc = pgh.get_bin_centers(bins)[np.nanargmax(hist)] + + peak_top_pars = pgc.hpge_fit_energy_peak_tops( + hist, + bins, + var, + [peak_loc], + n_to_fit=7, + )[0][0] + try: + mu = peak_top_pars[0] + except Exception: + mu = np.nan + if mu is None or np.isnan(mu): + log.debug("Fit failed, using max guess") + rough_adc_to_kev = peak / peak_loc + e_lower_lim = ( + peak_loc - (1.5 * peak_dict["kev_width"][0]) / rough_adc_to_kev + ) + e_upper_lim = ( + peak_loc + (1.5 * peak_dict["kev_width"][1]) / rough_adc_to_kev + ) + hist, bins, var = pgh.get_hist( + energy, + range=(int(e_lower_lim), int(e_upper_lim)), + dx=init_bin_width, + ) + mu = pgh.get_bin_centers(bins)[np.nanargmax(hist)] + + updated_adc_to_kev = peak / mu + e_lower_lim = mu - (peak_dict["kev_width"][0]) / updated_adc_to_kev + e_upper_lim = mu + (peak_dict["kev_width"][1]) / updated_adc_to_kev + log.info( + f"{peak}: lower lim is :{e_lower_lim}, upper lim is {e_upper_lim}" + ) + peak_dict["e_lower_lim"] = e_lower_lim + peak_dict["e_upper_lim"] = e_upper_lim + peak_dict["ecal_par"] = updated_adc_to_kev + + out_tbl, n_wfs = get_out_data( + peak_dict["obj_buf"], + tb_out, + cut_dict, + e_lower_lim, + e_upper_lim, + peak_dict["ecal_par"], + raw_dict, + int(peak), + final_cut_field=final_cut_field, + energy_param=energy_parameter, + ) + sto.write(out_tbl, name=lh5_path, lh5_file=temp_output, wo_mode="a") + peak_dict["obj_buf"] = None + peak_dict["obj_buf_start"] = 0 + peak_dict["n_events"] = n_wfs + log.debug(f'found {peak_dict["n_events"]} events for {peak}') + else: + if peak_dict["obj_buf"] is not None and len(peak_dict["obj_buf"]) > 0: + tb_out = run_one_dsp( + peak_dict["obj_buf"], dsp_config, db_dict=db_dict + ) + out_tbl, n_wfs = get_out_data( + peak_dict["obj_buf"], + tb_out, + cut_dict, + peak_dict["e_lower_lim"], + peak_dict["e_upper_lim"], + peak_dict["ecal_par"], + raw_dict, + int(peak), + final_cut_field=final_cut_field, + energy_param=energy_parameter, + ) + peak_dict["n_events"] += n_wfs + sto.write( + out_tbl, name=lh5_path, lh5_file=temp_output, wo_mode="a" + ) + peak_dict["obj_buf"] = None + peak_dict["obj_buf_start"] = 0 + log.debug(f'found {peak_dict["n_events"]} events for {peak}') + if peak_dict["n_events"] >= n_events: + peak_dict["idxs"] = None + log.debug(f"{peak} has reached the required number of events") + + else: + pathlib.Path(temp_output).touch() + + log.debug(f"event selection completed in {time.time()-t0} seconds") + os.rename(temp_output, args.peak_file) diff --git a/scripts/pars_dsp_nopt.py b/scripts/pars_dsp_nopt.py index 1b2e798..c2c393d 100644 --- a/scripts/pars_dsp_nopt.py +++ b/scripts/pars_dsp_nopt.py @@ -10,13 +10,15 @@ os.environ["LGDO_BOUNDSCHECK"] = "false" os.environ["DSPEED_CACHE"] = "false" os.environ["DSPEED_BOUNDSCHECK"] 
= "false" +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" import lgdo.lh5 as lh5 import numpy as np import pygama.pargen.noise_optimization as pno from legendmeta import LegendMetadata from legendmeta.catalog import Props -from pygama.pargen.cuts import generate_cuts, get_cut_indexes +from pygama.pargen.data_cleaning import generate_cuts, get_cut_indexes from pygama.pargen.dsp_optimize import run_one_dsp sto = lh5.LH5Store() @@ -43,7 +45,8 @@ logging.getLogger("parse").setLevel(logging.INFO) logging.getLogger("lgdo").setLevel(logging.INFO) logging.getLogger("h5py._conv").setLevel(logging.INFO) -logging.getLogger("pygama.dsp.processing_chain").setLevel(logging.INFO) +logging.getLogger("dspeed.processing_chain").setLevel(logging.INFO) +logging.getLogger("legendmeta").setLevel(logging.INFO) log = logging.getLogger(__name__) @@ -75,7 +78,7 @@ log.info(f"Select baselines {len(tb_data)}") dsp_data = run_one_dsp(tb_data, dsp_config) - cut_dict = generate_cuts(dsp_data, parameters=opt_dict.pop("cut_pars")) + cut_dict = generate_cuts(dsp_data, cut_dict=opt_dict.pop("cut_pars")) cut_idxs = get_cut_indexes(dsp_data, cut_dict) tb_data = sto.read( f"{args.channel}/raw", raw_files, n_rows=opt_dict.pop("n_events"), idx=idxs[cut_idxs] @@ -87,7 +90,7 @@ if args.plot_path: out_dict, plot_dict = pno.noise_optimization( - tb_data, dsp_config, db_dict, opt_dict, args.channel, display=1 + tb_data, dsp_config, db_dict.copy(), opt_dict, args.channel, display=1 ) else: out_dict = pno.noise_optimization( diff --git a/scripts/pars_dsp_svm.py b/scripts/pars_dsp_svm.py new file mode 100644 index 0000000..40f0a25 --- /dev/null +++ b/scripts/pars_dsp_svm.py @@ -0,0 +1,36 @@ +import argparse +import json +import logging +import os +import pathlib + +argparser = argparse.ArgumentParser() +argparser.add_argument("--log", help="log file", type=str) +argparser.add_argument("--output_file", help="output par file", type=str, required=True) +argparser.add_argument("--input_file", help="input par file", type=str, required=True) +argparser.add_argument("--svm_file", help="svm file", required=True) +args = argparser.parse_args() + + +if args.log is not None: + pathlib.Path(os.path.dirname(args.log)).mkdir(parents=True, exist_ok=True) + logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") +else: + logging.basicConfig(level=logging.DEBUG) + +logging.getLogger("parse").setLevel(logging.INFO) +logging.getLogger("lgdo").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) + +log = logging.getLogger(__name__) + +with open(args.input_file) as r: + par_data = json.load(r) + +file = f"'$_/{os.path.basename(args.svm_file)}'" + +par_data["svm"] = {"model_file": file} + +pathlib.Path(os.path.dirname(args.output_file)).mkdir(parents=True, exist_ok=True) +with open(args.output_file, "w") as w: + json.dump(par_data, w, indent=4) diff --git a/scripts/pars_dsp_tau.py b/scripts/pars_dsp_tau.py index 0a315ff..8064308 100644 --- a/scripts/pars_dsp_tau.py +++ b/scripts/pars_dsp_tau.py @@ -9,13 +9,16 @@ os.environ["LGDO_BOUNDSCHECK"] = "false" os.environ["DSPEED_CACHE"] = "false" os.environ["DSPEED_BOUNDSCHECK"] = "false" +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" import lgdo.lh5 as lh5 import numpy as np from legendmeta import LegendMetadata from legendmeta.catalog import Props -from pygama.pargen.extract_tau import dsp_preprocess_decay_const -from pygama.pargen.utils import get_tcm_pulser_ids +from pygama.pargen.data_cleaning import 
get_cut_indexes, get_tcm_pulser_ids +from pygama.pargen.dsp_optimize import run_one_dsp +from pygama.pargen.extract_tau import ExtractTau argparser = argparse.ArgumentParser() argparser.add_argument("--configs", help="configs path", type=str, required=True) @@ -25,8 +28,11 @@ argparser.add_argument("--channel", help="Channel", type=str, required=True) argparser.add_argument("--plot_path", help="plot path", type=str, required=False) argparser.add_argument("--output_file", help="output file", type=str, required=True) + +argparser.add_argument("--pulser_file", help="pulser file", type=str, required=False) + argparser.add_argument("--raw_files", help="input files", nargs="*", type=str) -argparser.add_argument("--tcm_files", help="tcm_files", nargs="*", type=str) +argparser.add_argument("--tcm_files", help="tcm_files", nargs="*", type=str, required=False) args = argparser.parse_args() logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") @@ -35,8 +41,10 @@ logging.getLogger("lgdo").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("legendmeta").setLevel(logging.INFO) sto = lh5.LH5Store() +log = logging.getLogger(__name__) configs = LegendMetadata(path=args.configs) config_dict = configs.on(args.timestamp, system=args.datatype) @@ -48,6 +56,7 @@ kwarg_dict = Props.read_from(kwarg_dict) if kwarg_dict["run_tau"] is True: + dsp_config = Props.read_from(channel_dict) kwarg_dict.pop("run_tau") if isinstance(args.raw_files, list) and args.raw_files[0].split(".")[-1] == "filelist": input_file = args.raw_files[0] @@ -56,22 +65,38 @@ else: input_file = args.raw_files - if isinstance(args.tcm_files, list) and args.tcm_files[0].split(".")[-1] == "filelist": - tcm_files = args.tcm_files[0] - with open(tcm_files) as f: + if args.pulser_file: + with open(args.pulser_file) as f: + pulser_dict = json.load(f) + mask = np.array(pulser_dict["mask"]) + + elif args.tcm_filelist: + # get pulser mask from tcm files + with open(args.tcm_filelist) as f: tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, kwarg_dict["pulser_multiplicity_threshold"] + ) else: - tcm_files = args.tcm_files - # get pulser mask from tcm files - tcm_files = sorted(np.unique(tcm_files)) - ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") - ) - data = sto.read(f"{args.channel}/raw", input_file, field_mask=["daqenergy", "timestamp"])[ - 0 - ].view_as("pd") + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) + + data = sto.read( + f"{args.channel}/raw", input_file, field_mask=["daqenergy", "timestamp", "t_sat_lo"] + )[0].view_as("pd") threshold = kwarg_dict.pop("threshold") - cuts = np.where((data.daqenergy.to_numpy() > threshold) & (~mask))[0] + + discharges = data["t_sat_lo"] > 0 + discharge_timestamps = np.where(data["timestamp"][discharges])[0] + is_recovering = np.full(len(data), False, dtype=bool) + for tstamp in discharge_timestamps: + is_recovering = is_recovering | np.where( + (((data["timestamp"] - tstamp) < 0.01) & ((data["timestamp"] - tstamp) > 0)), + True, + False, + ) + cuts = np.where((data.daqenergy.to_numpy() > threshold) & (~mask) & (~is_recovering))[0] tb_data = sto.read( f"{args.channel}/raw", @@ -80,12 +105,30 @@ n_rows=kwarg_dict.pop("n_events"), )[0] - out_dict, plot_dict = dsp_preprocess_decay_const( - tb_data, channel_dict, **kwarg_dict, 
display=1 - ) + tb_out = run_one_dsp(tb_data, dsp_config) + log.debug("Processed Data") + cut_parameters = kwarg_dict.get("cut_parameters", None) + if cut_parameters is not None: + idxs = get_cut_indexes(tb_out, cut_parameters=cut_parameters) + log.debug("Applied cuts") + log.debug(f"{len(idxs)} events passed cuts") + else: + idxs = np.full(len(tb_out), True, dtype=bool) + + tau = ExtractTau(dsp_config, kwarg_dict["wf_field"]) + slopes = tb_out["tail_slope"].nda + log.debug("Calculating pz constant") + + tau.get_decay_constant(slopes[idxs], tb_data[kwarg_dict["wf_field"]]) if args.plot_path: pathlib.Path(os.path.dirname(args.plot_path)).mkdir(parents=True, exist_ok=True) + + plot_dict = tau.plot_waveforms_after_correction( + tb_data, "wf_pz", norm_param=kwarg_dict.get("norm_param", "pz_mean") + ) + plot_dict.update(tau.plot_slopes(slopes[idxs])) + with open(args.plot_path, "wb") as f: pkl.dump({"tau": plot_dict}, f, protocol=pkl.HIGHEST_PROTOCOL) else: @@ -93,4 +136,4 @@ pathlib.Path(os.path.dirname(args.output_file)).mkdir(parents=True, exist_ok=True) with open(args.output_file, "w") as f: - json.dump(out_dict, f, indent=4) + json.dump(tau.output_dict, f, indent=4) diff --git a/scripts/pars_hit_aoe.py b/scripts/pars_hit_aoe.py index 2f7167b..afb90a8 100644 --- a/scripts/pars_hit_aoe.py +++ b/scripts/pars_hit_aoe.py @@ -9,18 +9,48 @@ import warnings from typing import Callable +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + import numpy as np import pandas as pd from legendmeta import LegendMetadata from legendmeta.catalog import Props from pygama.pargen.AoE_cal import * # noqa: F403 -from pygama.pargen.AoE_cal import cal_aoe, pol1, sigma_fit, standard_aoe -from pygama.pargen.utils import get_tcm_pulser_ids, load_data +from pygama.pargen.AoE_cal import CalAoE, Pol1, SigmaFit, aoe_peak +from pygama.pargen.data_cleaning import get_tcm_pulser_ids +from pygama.pargen.utils import load_data log = logging.getLogger(__name__) warnings.filterwarnings(action="ignore", category=RuntimeWarning) +def get_results_dict(aoe_class): + return { + "cal_energy_param": aoe_class.cal_energy_param, + "dt_param": aoe_class.dt_param, + "rt_correction": aoe_class.dt_corr, + "1000-1300keV": aoe_class.timecorr_df.to_dict("index"), + "correction_fit_results": aoe_class.energy_corr_res_dict, + "low_cut": aoe_class.low_cut_val, + "high_cut": aoe_class.high_cut_val, + "low_side_sfs": aoe_class.low_side_sfs.to_dict("index"), + "2_side_sfs": aoe_class.two_side_sfs.to_dict("index"), + } + + +def fill_plot_dict(aoe_class, data, plot_options, plot_dict=None): + if plot_dict is not None: + for key, item in plot_options.items(): + if item["options"] is not None: + plot_dict[key] = item["function"](aoe_class, data, **item["options"]) + else: + plot_dict[key] = item["function"](aoe_class, data) + else: + plot_dict = {} + return plot_dict + + def aoe_calibration( data: pd.Dataframe, cal_dicts: dict, @@ -28,36 +58,34 @@ def aoe_calibration( energy_param: str, cal_energy_param: str, eres_func: Callable, - pdf: Callable = standard_aoe, + pdf: Callable = aoe_peak, selection_string: str = "", dt_corr: bool = False, dep_correct: bool = False, dt_cut: dict | None = None, high_cut_val: int = 3, - mean_func: Callable = pol1, - sigma_func: Callable = sigma_fit, - dep_acc: float = 0.9, + mean_func: Callable = Pol1, + sigma_func: Callable = SigmaFit, + # dep_acc: float = 0.9, dt_param: str = "dt_eff", comptBands_width: int = 20, plot_options: dict | None = None, ): data["AoE_Uncorr"] = data[current_param] 
/ data[energy_param] - aoe = cal_aoe( - cal_dicts, - cal_energy_param, - eres_func, - pdf, - selection_string, - dt_corr, - dep_acc, - dep_correct, - dt_cut, - dt_param, - high_cut_val, - mean_func, - sigma_func, - comptBands_width, - plot_options if plot_options is not None else {}, + aoe = CalAoE( + cal_dicts=cal_dicts, + cal_energy_param=cal_energy_param, + eres_func=eres_func, + pdf=pdf, + selection_string=selection_string, + dt_corr=dt_corr, + dep_correct=dep_correct, + dt_cut=dt_cut, + dt_param=dt_param, + high_cut_val=high_cut_val, + mean_func=mean_func, + sigma_func=sigma_func, + compt_bands_width=comptBands_width, ) aoe.update_cal_dicts( @@ -71,12 +99,13 @@ def aoe_calibration( aoe.calibrate(data, "AoE_Uncorr") log.info("Calibrated A/E") - return cal_dicts, aoe.get_results_dict(), aoe.fill_plot_dict(data), aoe + return cal_dicts, get_results_dict(aoe), fill_plot_dict(aoe, data, plot_options), aoe argparser = argparse.ArgumentParser() argparser.add_argument("files", help="files", nargs="*", type=str) -argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=True) +argparser.add_argument("--pulser_file", help="pulser_file", type=str, required=False) +argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=False) argparser.add_argument("--ecal_file", help="ecal_file", type=str, required=True) argparser.add_argument("--eres_file", help="eres_file", type=str, required=True) argparser.add_argument("--inplots", help="in_plot_path", type=str, required=False) @@ -99,6 +128,7 @@ def aoe_calibration( logging.getLogger("lgdo").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("legendmeta").setLevel(logging.INFO) configs = LegendMetadata(path=args.configs) channel_dict = configs.on(args.timestamp, system=args.datatype)["snakemake_rules"][ @@ -118,11 +148,11 @@ def aoe_calibration( if kwarg_dict["run_aoe"] is True: kwarg_dict.pop("run_aoe") - pdf = eval(kwarg_dict.pop("pdf")) if "pdf" in kwarg_dict else standard_aoe + pdf = eval(kwarg_dict.pop("pdf")) if "pdf" in kwarg_dict else aoe_peak - sigma_func = eval(kwarg_dict.pop("sigma_func")) if "sigma_func" in kwarg_dict else sigma_fit + sigma_func = eval(kwarg_dict.pop("sigma_func")) if "sigma_func" in kwarg_dict else SigmaFit - mean_func = eval(kwarg_dict.pop("mean_func")) if "mean_func" in kwarg_dict else pol1 + mean_func = eval(kwarg_dict.pop("mean_func")) if "mean_func" in kwarg_dict else Pol1 if "plot_options" in kwarg_dict: for field, item in kwarg_dict["plot_options"].items(): @@ -172,13 +202,25 @@ def eres_func(x): return_selection_mask=True, ) - # get pulser mask from tcm files - with open(args.tcm_filelist) as f: - tcm_files = f.read().splitlines() - tcm_files = sorted(tcm_files) - ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") - ) + if args.pulser_file: + with open(args.pulser_file) as f: + pulser_dict = json.load(f) + mask = np.array(pulser_dict["mask"]) + if "pulser_multiplicity_threshold" in kwarg_dict: + kwarg_dict.pop("pulser_multiplicity_threshold") + + elif args.tcm_filelist: + # get pulser mask from tcm files + with open(args.tcm_filelist) as f: + tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") + ) + else: + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) + 
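A minimal stand-alone sketch of the pulser-mask plumbing introduced above, assuming the pulser file is a JSON object with a boolean "mask" entry aligned with the raw events and that threshold_mask is the selection returned by load_data (both taken from the diff, not re-checked here); the helper name load_pulser_mask is hypothetical.

import json
import numpy as np

def load_pulser_mask(pulser_file, threshold_mask):
    """Return an is_pulser flag for every event that survived the threshold cut."""
    with open(pulser_file) as f:
        pulser_dict = json.load(f)
    # boolean flag per raw event, as written by the pulser-tagging step
    mask = np.asarray(pulser_dict["mask"], dtype=bool)
    # mirror data["is_pulser"] = mask[threshold_mask]: keep only the entries
    # belonging to events that passed the energy threshold in load_data
    return mask[threshold_mask]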
data["is_pulser"] = mask[threshold_mask] cal_dict, out_dict, plot_dict, obj = aoe_calibration( diff --git a/scripts/pars_hit_ecal.py b/scripts/pars_hit_ecal.py index a7b399e..b324b62 100644 --- a/scripts/pars_hit_ecal.py +++ b/scripts/pars_hit_ecal.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import copy import json import logging import os @@ -9,18 +10,22 @@ import warnings from datetime import datetime +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + import lgdo.lh5 as lh5 import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np -import pandas as pd +import pygama.math.distributions as pgf import pygama.math.histogram as pgh from legendmeta import LegendMetadata from legendmeta.catalog import Props from matplotlib.colors import LogNorm -from pygama.pargen.ecal_th import * # noqa: F403 -from pygama.pargen.ecal_th import apply_cuts, calibrate_parameter -from pygama.pargen.utils import get_tcm_pulser_ids, load_data +from pygama.math.distributions import nb_poly +from pygama.pargen.data_cleaning import get_mode_stdev, get_tcm_pulser_ids +from pygama.pargen.energy_cal import FWHMLinear, FWHMQuadratic, HPGeCalibration +from pygama.pargen.utils import load_data from scipy.stats import binned_statistic log = logging.getLogger(__name__) @@ -28,6 +33,238 @@ sto = lh5.LH5Store() warnings.filterwarnings(action="ignore", category=RuntimeWarning) +warnings.filterwarnings(action="ignore", category=np.RankWarning) + + +def plot_2614_timemap( + data, + cal_energy_param, + selection_string, + figsize=(12, 8), + fontsize=12, + erange=(2580, 2630), + dx=1, + time_dx=180, +): + plt.rcParams["figure.figsize"] = figsize + plt.rcParams["font.size"] = fontsize + + selection = data.query(f"{cal_energy_param}>2560&{cal_energy_param}<2660&{selection_string}") + + fig = plt.figure() + if len(selection) == 0: + pass + else: + time_bins = np.arange( + (np.amin(data["timestamp"]) // time_dx) * time_dx, + ((np.amax(data["timestamp"]) // time_dx) + 2) * time_dx, + time_dx, + ) + + plt.hist2d( + selection["timestamp"], + selection[cal_energy_param], + bins=[time_bins, np.arange(erange[0], erange[1] + dx, dx)], + norm=LogNorm(), + ) + + ticks, labels = plt.xticks() + plt.xlabel(f"Time starting : {datetime.utcfromtimestamp(ticks[0]).strftime('%d/%m/%y %H:%M')}") + plt.ylabel("Energy(keV)") + plt.ylim([erange[0], erange[1]]) + + plt.xticks( + ticks, + [datetime.utcfromtimestamp(tick).strftime("%H:%M") for tick in ticks], + ) + plt.close() + return fig + + +def plot_pulser_timemap( + data, + cal_energy_param, + selection_string, # noqa: ARG001 + pulser_field="is_pulser", + figsize=(12, 8), + fontsize=12, + dx=0.2, + time_dx=180, + n_spread=3, +): + plt.rcParams["figure.figsize"] = figsize + plt.rcParams["font.size"] = fontsize + + time_bins = np.arange( + (np.amin(data["timestamp"]) // time_dx) * time_dx, + ((np.amax(data["timestamp"]) // time_dx) + 2) * time_dx, + time_dx, + ) + + selection = data.query(pulser_field) + fig = plt.figure() + if len(selection) == 0: + pass + + else: + mean = np.nanpercentile(selection[cal_energy_param], 50) + spread = mean - np.nanpercentile(selection[cal_energy_param], 10) + + plt.hist2d( + selection["timestamp"], + selection[cal_energy_param], + bins=[ + time_bins, + np.arange(mean - n_spread * spread, mean + n_spread * spread + dx, dx), + ], + norm=LogNorm(), + ) + plt.ylim([mean - n_spread * spread, mean + n_spread * spread]) + ticks, labels = plt.xticks() + plt.xlabel(f"Time starting : 
{datetime.utcfromtimestamp(ticks[0]).strftime('%d/%m/%y %H:%M')}") + plt.ylabel("Energy(keV)") + + plt.xticks( + ticks, + [datetime.utcfromtimestamp(tick).strftime("%H:%M") for tick in ticks], + ) + plt.close() + return fig + + +def get_median(x): + if len(x[~np.isnan(x)]) >= 10: + return np.nan + else: + return np.nanpercentile(x, 50) + + +def get_err(x): + if len(x[~np.isnan(x)]) >= 10: + return np.nan + else: + return np.nanvar(x) / np.sqrt(len(x)) + + +def bin_pulser_stability( + data, + cal_energy_param, + selection_string, # noqa: ARG001 + pulser_field="is_pulser", + time_slice=180, +): + selection = data.query(pulser_field) + + utime_array = data["timestamp"] + select_energies = selection[cal_energy_param].to_numpy() + + time_bins = np.arange( + (np.amin(utime_array) // time_slice) * time_slice, + ((np.amax(utime_array) // time_slice) + 2) * time_slice, + time_slice, + ) + # bin time values + times_average = (time_bins[:-1] + time_bins[1:]) / 2 + + if len(selection) == 0: + return { + "time": times_average, + "energy": np.full_like(times_average, np.nan), + "spread": np.full_like(times_average, np.nan), + } + + par_average, _, _ = binned_statistic( + selection["timestamp"], select_energies, statistic=get_median, bins=time_bins + ) + par_error, _, _ = binned_statistic( + selection["timestamp"], select_energies, statistic=get_err, bins=time_bins + ) + + return {"time": times_average, "energy": par_average, "spread": par_error} + + +def bin_stability( + data, + cal_energy_param, + selection_string, + time_slice=180, + energy_range=(2585, 2660), +): + selection = data.query( + f"{cal_energy_param}>{energy_range[0]}&{cal_energy_param}<{energy_range[1]}&{selection_string}" + ) + + utime_array = data["timestamp"] + select_energies = selection[cal_energy_param].to_numpy() + + time_bins = np.arange( + (np.amin(utime_array) // time_slice) * time_slice, + ((np.amax(utime_array) // time_slice) + 2) * time_slice, + time_slice, + ) + # bin time values + times_average = (time_bins[:-1] + time_bins[1:]) / 2 + + if len(selection) == 0: + return { + "time": times_average, + "energy": np.full_like(times_average, np.nan), + "spread": np.full_like(times_average, np.nan), + } + + par_average, _, _ = binned_statistic( + selection["timestamp"], select_energies, statistic=get_median, bins=time_bins + ) + par_error, _, _ = binned_statistic( + selection["timestamp"], select_energies, statistic=get_err, bins=time_bins + ) + + return {"time": times_average, "energy": par_average, "spread": par_error} + + +def bin_spectrum( + data, + cal_energy_param, + selection_string, + cut_field="is_valid_cal", + pulser_field="is_pulser", + erange=(0, 3000), + dx=2, +): + bins = np.arange(erange[0], erange[1] + dx, dx) + return { + "bins": pgh.get_bin_centers(bins), + "counts": np.histogram(data.query(selection_string)[cal_energy_param], bins)[0], + "cut_counts": np.histogram( + data.query(f"(~{cut_field})&(~{pulser_field})")[cal_energy_param], + bins, + )[0], + "pulser_counts": np.histogram( + data.query(pulser_field)[cal_energy_param], + bins, + )[0], + } + + +def bin_survival_fraction( + data, + cal_energy_param, + selection_string, + cut_field="is_valid_cal", + pulser_field="is_pulser", + erange=(0, 3000), + dx=6, +): + counts_pass, bins_pass, _ = pgh.get_hist( + data.query(selection_string)[cal_energy_param], + bins=np.arange(erange[0], erange[1] + dx, dx), + ) + counts_fail, bins_fail, _ = pgh.get_hist( + data.query(f"(~{cut_field})&(~{pulser_field})")[cal_energy_param], + bins=np.arange(erange[0], erange[1] + dx, dx), 
+ ) + sf = 100 * (counts_pass + 10 ** (-6)) / (counts_pass + counts_fail + 10 ** (-6)) + return {"bins": pgh.get_bin_centers(bins_pass), "sf": sf} def plot_baseline_timemap( @@ -115,9 +352,7 @@ def baseline_tracking_plots(files, lh5_path, plot_options=None): if plot_options is None: plot_options = {} plot_dict = {} - data = sto.read(lh5_path, files, field_mask=["bl_mean", "baseline", "timestamp"])[0].view_as( - "pd" - ) + data = lh5.read_as(lh5_path, files, "pd", field_mask=["bl_mean", "baseline", "timestamp"]) for key, item in plot_options.items(): if item["options"] is not None: plot_dict[key] = item["function"](data, **item["options"]) @@ -126,66 +361,72 @@ def baseline_tracking_plots(files, lh5_path, plot_options=None): return plot_dict -def energy_cal_th( - data: pd.Dataframe, - energy_params: list[str], - cal_energy_params: list | None = None, - selection_string: str = "", - hit_dict: dict | None = None, - cut_parameters: dict[str, int] | None = None, - plot_options: dict | None = None, - threshold: int = 0, - p_val: float = 0, - n_events: int | None = None, - final_cut_field: str = "is_valid_cal", - simplex: bool = True, - guess_keV: float | None = None, - tail_weight=100, - deg: int = 1, -) -> tuple(dict, dict, dict, dict): - data, hit_dict = apply_cuts( - data, - hit_dict if hit_dict is not None else {}, - cut_parameters if cut_parameters is not None else {}, - final_cut_field, - ) +def monitor_parameters(files, lh5_path, parameters): + data = lh5.read_as(lh5_path, files, "pd", field_mask=parameters) + out_dict = {} + for param in parameters: + mode, stdev = get_mode_stdev(data[param].to_numpy()) + out_dict[param] = {"mode": mode, "stdev": stdev} + return out_dict - if cal_energy_params is None: - cal_energy_params = [energy_param + "_cal" for energy_param in energy_params] - - results_dict = {} - plot_dict = {} - full_object_dict = {} - for energy_param, cal_energy_param in zip(energy_params, cal_energy_params): - full_object_dict[cal_energy_param] = calibrate_parameter( - energy_param, - selection_string, - plot_options, - guess_keV, - threshold, - p_val, - n_events, - simplex, - deg, - tail_weight=tail_weight, - ) - full_object_dict[cal_energy_param].calibrate_parameter(data) - results_dict[cal_energy_param] = full_object_dict[cal_energy_param].get_results_dict(data) - hit_dict.update(full_object_dict[cal_energy_param].hit_dict) - if ~np.isnan(full_object_dict[cal_energy_param].pars).all(): - plot_dict[cal_energy_param] = ( - full_object_dict[cal_energy_param].fill_plot_dict(data).copy() - ) - log.info("Finished all calibrations") - return hit_dict, results_dict, plot_dict, full_object_dict +def get_results_dict(ecal_class, data, cal_energy_param, selection_string): + if np.isnan(ecal_class.pars).all(): + return {} + else: + results_dict = copy.deepcopy(ecal_class.results["hpge_fit_energy_peaks_1"]) + + if "FWHMLinear" in results_dict: + fwhm_linear = results_dict["FWHMLinear"] + fwhm_linear["function"] = fwhm_linear["function"].__name__ + fwhm_linear["parameters"] = fwhm_linear["parameters"].to_dict() + fwhm_linear["uncertainties"] = fwhm_linear["uncertainties"].to_dict() + fwhm_linear["cov"] = fwhm_linear["cov"].tolist() + else: + fwhm_linear = None + + if "FWHMQuadratic" in results_dict: + fwhm_quad = results_dict["FWHMQuadratic"] + fwhm_quad["function"] = fwhm_quad["function"].__name__ + fwhm_quad["parameters"] = fwhm_quad["parameters"].to_dict() + fwhm_quad["uncertainties"] = fwhm_quad["uncertainties"].to_dict() + fwhm_quad["cov"] = fwhm_quad["cov"].tolist() + else: + 
fwhm_quad = None + + pk_dict = results_dict["peak_parameters"] + + for _, dic in pk_dict.items(): + dic["function"] = dic["function"].name + dic["parameters"] = dic["parameters"].to_dict() + dic["uncertainties"] = dic["uncertainties"].to_dict() + dic.pop("covariance") + + return { + "total_fep": len(data.query(f"{cal_energy_param}>2604&{cal_energy_param}<2624")), + "total_dep": len(data.query(f"{cal_energy_param}>1587&{cal_energy_param}<1597")), + "pass_fep": len( + data.query(f"{cal_energy_param}>2604&{cal_energy_param}<2624&{selection_string}") + ), + "pass_dep": len( + data.query(f"{cal_energy_param}>1587&{cal_energy_param}<1597&{selection_string}") + ), + "eres_linear": fwhm_linear, + "eres_quadratic": fwhm_quad, + "fitted_peaks": ecal_class.peaks_kev.tolist(), + "pk_fits": pk_dict, + } if __name__ == "__main__": argparser = argparse.ArgumentParser() - argparser.add_argument("--files", help="files", nargs="*", type=str) - argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=True) + argparser.add_argument("--files", help="filelist", nargs="*", type=str) + argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=False) + argparser.add_argument("--pulser_file", help="pulser_file", type=str, required=False) + argparser.add_argument("--ctc_dict", help="ctc_dict", nargs="*") + argparser.add_argument("--in_hit_dict", help="in_hit_dict", required=False) + argparser.add_argument("--inplot_dict", help="inplot_dict", required=False) argparser.add_argument("--configs", help="config", type=str, required=True) argparser.add_argument("--datatype", help="Datatype", type=str, required=True) @@ -193,6 +434,8 @@ def energy_cal_th( argparser.add_argument("--channel", help="Channel", type=str, required=True) argparser.add_argument("--tier", help="tier", type=str, default="hit") + argparser.add_argument("--metadata", help="metadata path", type=str, required=True) + argparser.add_argument("--log", help="log_file", type=str) argparser.add_argument("--plot_path", help="plot_path", type=str, required=False) @@ -206,10 +449,25 @@ def energy_cal_th( logging.getLogger("lgdo").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO) + logging.getLogger("legendmeta").setLevel(logging.INFO) + + meta = LegendMetadata(path=args.metadata) + chmap = meta.channelmap(args.timestamp) + + det_status = chmap.map("daq.rawid")[int(args.channel[2:])]["analysis"]["usability"] + + if args.in_hit_dict: + hit_dict = Props.read_from(args.in_hit_dict) + + db_files = [ + par_file + for par_file in args.ctc_dict + if os.path.splitext(par_file)[1] == ".json" or os.path.splitext(par_file)[1] == ".yml" + ] - database_dic = Props.read_from(args.ctc_dict) + database_dic = Props.read_from(db_files) - hit_dict = database_dic[args.channel]["ctc_params"] + hit_dict.update(database_dic[args.channel]["ctc_params"]) # get metadata dictionary configs = LegendMetadata(path=args.configs) @@ -233,40 +491,193 @@ def energy_cal_th( bl_plots[field]["function"] = eval(item["function"]) common_plots = kwarg_dict.pop("common_plots") + with open(args.files[0]) as f: + files = f.read().splitlines() + files = sorted(files) + # load data in data, threshold_mask = load_data( - args.files, + files, f"{args.channel}/dsp", hit_dict, - params=kwarg_dict["energy_params"] - + list(kwarg_dict["cut_parameters"]) - + ["timestamp", "trapTmax"], + params=[*kwarg_dict["energy_params"], kwarg_dict["cut_param"], "timestamp", "trapTmax"], 
threshold=kwarg_dict["threshold"], return_selection_mask=True, cal_energy_param="trapTmax", ) - # get pulser mask from tcm files - with open(args.tcm_filelist) as f: - tcm_files = f.read().splitlines() - tcm_files = sorted(np.unique(tcm_files)) - ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") - ) + if args.pulser_file: + with open(args.pulser_file) as f: + pulser_dict = json.load(f) + mask = np.array(pulser_dict["mask"]) + + elif args.tcm_filelist: + # get pulser mask from tcm files + with open(args.tcm_filelist) as f: + tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, kwarg_dict["pulser_multiplicity_threshold"] + ) + else: + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) + data["is_pulser"] = mask[threshold_mask] - # run energy calibration - out_dict, result_dict, plot_dict, ecal_object = energy_cal_th( - data, - hit_dict=hit_dict, - selection_string=f"({kwarg_dict['final_cut_field']})&(~is_pulser)", - **kwarg_dict, - ) + pk_pars = [ + (583.191, (20, 20), pgf.hpge_peak), + (727.330, (30, 30), pgf.hpge_peak), + (860.564, (30, 25), pgf.hpge_peak), + (1592.53, (40, 20), pgf.gauss_on_step), + (1620.50, (20, 40), pgf.gauss_on_step), + (2103.53, (40, 40), pgf.gauss_on_step), + (2614.553, (60, 60), pgf.hpge_peak), + ] + + glines = [pk_par[0] for pk_par in pk_pars] + + if "cal_energy_params" not in kwarg_dict: + cal_energy_params = [energy_param + "_cal" for energy_param in kwarg_dict["energy_params"]] + else: + cal_energy_params = kwarg_dict["cal_energy_params"] + + selection_string = f"~is_pulser&{kwarg_dict['cut_param']}" + + results_dict = {} + plot_dict = {} + full_object_dict = {} + + for energy_param, cal_energy_param in zip(kwarg_dict["energy_params"], cal_energy_params): + e_uncal = data.query(selection_string)[energy_param].to_numpy() + + hist, bins, bar = pgh.get_hist( + e_uncal[ + (e_uncal > np.nanpercentile(e_uncal, 95)) + & (e_uncal < np.nanpercentile(e_uncal, 99.9)) + ], + dx=1, + range=[np.nanpercentile(e_uncal, 95), np.nanpercentile(e_uncal, 99.9)], + ) + + guess = 2614.553 / bins[np.nanargmax(hist)] + full_object_dict[cal_energy_param] = HPGeCalibration( + energy_param, + glines, + guess, + kwarg_dict.get("deg", 0), + ) + full_object_dict[cal_energy_param].hpge_get_energy_peaks( + e_uncal, etol_kev=5 if det_status == "on" else 20 + ) + if 2614.553 not in full_object_dict[cal_energy_param].peaks_kev: + full_object_dict[cal_energy_param].hpge_get_energy_peaks( + e_uncal, peaks_kev=glines, etol_kev=5 if det_status == "on" else 30, n_sigma=2 + ) + got_peaks_kev = full_object_dict[cal_energy_param].peaks_kev.copy() + if det_status != "on": + full_object_dict[cal_energy_param].hpge_cal_energy_peak_tops( + e_uncal, + peaks_kev=got_peaks_kev, + update_cal_pars=True, + allowed_p_val=0, + ) + full_object_dict[cal_energy_param].hpge_fit_energy_peaks( + e_uncal, + peaks_kev=[2614.553], + peak_pars=pk_pars, + tail_weight=kwarg_dict.get("tail_weight", 0), + n_events=kwarg_dict.get("n_events", None), + allowed_p_val=kwarg_dict.get("p_val", 0), + update_cal_pars=bool(det_status == "on"), + bin_width_kev=0.5, + ) + full_object_dict[cal_energy_param].hpge_fit_energy_peaks( + e_uncal, + peaks_kev=got_peaks_kev, + peak_pars=pk_pars, + tail_weight=kwarg_dict.get("tail_weight", 0), + n_events=kwarg_dict.get("n_events", None), + allowed_p_val=kwarg_dict.get("p_val", 0), + update_cal_pars=False, + bin_width_kev=0.5, + ) + + 
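A stand-alone sketch of the rough ADC-to-keV first guess used above: histogram the top few percent of the uncalibrated spectrum and take the most populated bin as the 208Tl full-energy peak at 2614.553 keV. Only numpy is used; the 95th/99.9th percentile window and the unit bin width follow the values in the diff, and the function name is illustrative.

import numpy as np

def rough_calibration_guess(e_uncal, fep_kev=2614.553):
    """Estimate keV per ADC unit from the position of the FEP in the raw spectrum."""
    lo = np.nanpercentile(e_uncal, 95)
    hi = np.nanpercentile(e_uncal, 99.9)
    sel = e_uncal[(e_uncal > lo) & (e_uncal < hi)]
    # 1-ADU bins over the upper tail of the spectrum; assumes sel is non-empty
    hist, bins = np.histogram(sel, bins=np.arange(lo, hi + 1, 1))
    peak_loc = bins[np.argmax(hist)]  # left edge of the most populated bin
    return fep_kev / peak_loc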
full_object_dict[cal_energy_param].get_energy_res_curve( + FWHMLinear, + interp_energy_kev={"Qbb": 2039.0}, + ) + full_object_dict[cal_energy_param].get_energy_res_curve( + FWHMQuadratic, + interp_energy_kev={"Qbb": 2039.0}, + ) + + data[cal_energy_param] = nb_poly( + data[energy_param].to_numpy(), full_object_dict[cal_energy_param].pars + ) + + results_dict[cal_energy_param] = get_results_dict( + full_object_dict[cal_energy_param], data, cal_energy_param, selection_string + ) + + hit_dict.update({cal_energy_param: full_object_dict[cal_energy_param].gen_pars_dict()}) + if args.plot_path: + param_plot_dict = {} + if ~np.isnan(full_object_dict[cal_energy_param].pars).all(): + param_plot_dict["fwhm_fit"] = full_object_dict[cal_energy_param].plot_eres_fit( + e_uncal + ) + param_plot_dict["cal_fit"] = full_object_dict[cal_energy_param].plot_cal_fit( + e_uncal + ) + param_plot_dict["peak_fits"] = full_object_dict[cal_energy_param].plot_fits( + e_uncal + ) + + if "plot_options" in kwarg_dict: + for key, item in kwarg_dict["plot_options"].items(): + if item["options"] is not None: + param_plot_dict[key] = item["function"]( + data, + cal_energy_param, + selection_string, + **item["options"], + ) + else: + param_plot_dict[key] = item["function"]( + data, + cal_energy_param, + selection_string, + ) + plot_dict[cal_energy_param] = param_plot_dict + + for peak_dict in ( + full_object_dict[cal_energy_param] + .results["hpge_fit_energy_peaks_1"]["peak_parameters"] + .values() + ): + peak_dict["function"] = peak_dict["function"].name + peak_dict["parameters"] = peak_dict["parameters"].to_dict() + peak_dict["uncertainties"] = peak_dict["uncertainties"].to_dict() + for peak_dict in ( + full_object_dict[cal_energy_param] + .results["hpge_fit_energy_peaks"]["peak_parameters"] + .values() + ): + peak_dict["function"] = peak_dict["function"].name + peak_dict["parameters"] = peak_dict["parameters"].to_dict() + peak_dict["uncertainties"] = peak_dict["uncertainties"].to_dict() + + if "monitoring_parameters" in kwarg_dict: + monitor_dict = monitor_parameters( + files, f"{args.channel}/dsp", kwarg_dict["monitoring_parameters"] + ) + results_dict.update({"monitoring_parameters": monitor_dict}) # get baseline plots and save all plots to file if args.plot_path: common_dict = baseline_tracking_plots( - sorted(args.files), f"{args.channel}/dsp", plot_options=bl_plots + sorted(files), f"{args.channel}/dsp", plot_options=bl_plots ) for plot in list(common_dict): @@ -274,8 +685,6 @@ def energy_cal_th( plot_item = common_dict.pop(plot) plot_dict.update({plot: plot_item}) - pathlib.Path(os.path.dirname(args.plot_path)).mkdir(parents=True, exist_ok=True) - for key, item in plot_dict.items(): if isinstance(item, dict) and len(item) > 0: param_dict = {} @@ -283,14 +692,26 @@ def energy_cal_th( if plot in item: param_dict.update({plot: item[plot]}) common_dict.update({key: param_dict}) - plot_dict = {"ecal": plot_dict} - plot_dict["common"] = common_dict + if args.inplot_dict: + with open(args.inplot_dict, "rb") as f: + total_plot_dict = pkl.load(f) + else: + total_plot_dict = {} + + if "common" in total_plot_dict: + total_plot_dict["common"].update(common_dict) + else: + total_plot_dict["common"] = common_dict + + total_plot_dict.update({"ecal": plot_dict}) + + pathlib.Path(os.path.dirname(args.plot_path)).mkdir(parents=True, exist_ok=True) with open(args.plot_path, "wb") as f: - pkl.dump(plot_dict, f, protocol=pkl.HIGHEST_PROTOCOL) + pkl.dump(total_plot_dict, f, protocol=pkl.HIGHEST_PROTOCOL) # save output dictionary - 
output_dict = {"pars": out_dict, "results": {"ecal": result_dict}} + output_dict = {"pars": hit_dict, "results": {"ecal": results_dict}} with open(args.save_path, "w") as fp: pathlib.Path(os.path.dirname(args.save_path)).mkdir(parents=True, exist_ok=True) json.dump(output_dict, fp, indent=4) @@ -298,4 +719,4 @@ def energy_cal_th( # save calibration objects with open(args.results_path, "wb") as fp: pathlib.Path(os.path.dirname(args.results_path)).mkdir(parents=True, exist_ok=True) - pkl.dump({"ecal": ecal_object}, fp, protocol=pkl.HIGHEST_PROTOCOL) + pkl.dump({"ecal": full_object_dict}, fp, protocol=pkl.HIGHEST_PROTOCOL) diff --git a/scripts/pars_hit_lq.py b/scripts/pars_hit_lq.py index 3a43a45..ca4cd80 100644 --- a/scripts/pars_hit_lq.py +++ b/scripts/pars_hit_lq.py @@ -8,26 +8,53 @@ import pickle as pkl import warnings +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + import numpy as np import pandas as pd from legendmeta import LegendMetadata from legendmeta.catalog import Props -from pygama.math.peak_fitting import gauss_cdf +from pygama.math.distributions import gaussian +from pygama.pargen.data_cleaning import get_tcm_pulser_ids from pygama.pargen.lq_cal import * # noqa: F403 -from pygama.pargen.lq_cal import cal_lq -from pygama.pargen.utils import get_tcm_pulser_ids, load_data +from pygama.pargen.lq_cal import LQCal +from pygama.pargen.utils import load_data log = logging.getLogger(__name__) warnings.filterwarnings(action="ignore", category=RuntimeWarning) +def get_results_dict(lq_class): + return { + "cal_energy_param": lq_class.cal_energy_param, + "rt_correction": lq_class.dt_fit_pars, + # "cdf": lq_class.cdf.name, + "1590-1596keV": lq_class.timecorr_df.to_dict("index"), + "cut_value": lq_class.cut_val, + "sfs": lq_class.low_side_sf.to_dict("index"), + } + + +def fill_plot_dict(lq_class, data, plot_options, plot_dict=None): + if plot_dict is not None: + for key, item in plot_options.items(): + if item["options"] is not None: + plot_dict[key] = item["function"](lq_class, data, **item["options"]) + else: + plot_dict[key] = item["function"](lq_class, data) + else: + plot_dict = {} + return plot_dict + + def lq_calibration( data: pd.DataFrame, cal_dicts: dict, energy_param: str, cal_energy_param: str, eres_func: callable, - cdf: callable = gauss_cdf, + cdf: callable = gaussian, selection_string: str = "", plot_options: dict | None = None, ): @@ -62,17 +89,16 @@ def lq_calibration( A dict containing the results of the LQ calibration plot_dict: dict A dict containing all the figures specified by the plot options - lq: cal_lq class - The cal_lq object used for the LQ calibration + lq: LQCal class + The LQCal object used for the LQ calibration """ - lq = cal_lq( + lq = LQCal( cal_dicts, cal_energy_param, eres_func, cdf, selection_string, - plot_options, ) data["LQ_Ecorr"] = np.divide(data["lq80"], data[energy_param]) @@ -88,12 +114,14 @@ def lq_calibration( lq.calibrate(data, "LQ_Ecorr") log.info("Calibrated LQ") - return cal_dicts, lq.get_results_dict(), lq.fill_plot_dict(data), lq + return cal_dicts, get_results_dict(lq), fill_plot_dict(lq, data, plot_options), lq argparser = argparse.ArgumentParser() argparser.add_argument("files", help="files", nargs="*", type=str) -argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=True) +argparser.add_argument("--pulser_file", help="pulser_file", type=str, required=False) +argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=False) + 
argparser.add_argument("--ecal_file", help="ecal_file", type=str, required=True) argparser.add_argument("--eres_file", help="eres_file", type=str, required=True) argparser.add_argument("--inplots", help="in_plot_path", type=str, required=False) @@ -126,7 +154,7 @@ def lq_calibration( ecal_dict = Props.read_from(args.ecal_file) cal_dict = ecal_dict["pars"]["operations"] -eres_dict = ecal_dict["results"] +eres_dict = ecal_dict["results"]["ecal"] with open(args.eres_file, "rb") as o: object_dict = pkl.load(o) @@ -134,7 +162,7 @@ def lq_calibration( if kwarg_dict["run_lq"] is True: kwarg_dict.pop("run_lq") - cdf = eval(kwarg_dict.pop("cdf")) if "cdf" in kwarg_dict else gauss_cdf + cdf = eval(kwarg_dict.pop("cdf")) if "cdf" in kwarg_dict else gaussian if "plot_options" in kwarg_dict: for field, item in kwarg_dict["plot_options"].items(): @@ -173,13 +201,25 @@ def eres_func(x): return_selection_mask=True, ) - # get pulser mask from tcm files - with open(args.tcm_filelist) as f: - tcm_files = f.read().splitlines() - tcm_files = sorted(tcm_files) - ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") - ) + if args.pulser_file: + with open(args.pulser_file) as f: + pulser_dict = json.load(f) + mask = np.array(pulser_dict["mask"]) + if "pulser_multiplicity_threshold" in kwarg_dict: + kwarg_dict.pop("pulser_multiplicity_threshold") + + elif args.tcm_filelist: + # get pulser mask from tcm files + with open(args.tcm_filelist) as f: + tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") + ) + else: + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) + data["is_pulser"] = mask[threshold_mask] cal_dict, out_dict, plot_dict, obj = lq_calibration( diff --git a/scripts/pars_hit_qc.py b/scripts/pars_hit_qc.py new file mode 100644 index 0000000..c432d69 --- /dev/null +++ b/scripts/pars_hit_qc.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +import argparse +import json +import logging +import os +import pathlib +import pickle as pkl +import re +import warnings + +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + +import numpy as np +from legendmeta import LegendMetadata +from legendmeta.catalog import Props +from lgdo.lh5 import ls +from pygama.pargen.data_cleaning import ( + generate_cut_classifiers, + get_keys, + get_tcm_pulser_ids, +) +from pygama.pargen.utils import load_data + +log = logging.getLogger(__name__) + +warnings.filterwarnings(action="ignore", category=RuntimeWarning) + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument("--cal_files", help="cal_files", nargs="*", type=str) + argparser.add_argument("--fft_files", help="fft_files", nargs="*", type=str) + argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, required=False) + argparser.add_argument("--pulser_file", help="pulser_file", type=str, required=False) + + argparser.add_argument("--configs", help="config", type=str, required=True) + argparser.add_argument("--datatype", help="Datatype", type=str, required=True) + argparser.add_argument("--timestamp", help="Timestamp", type=str, required=True) + argparser.add_argument("--channel", help="Channel", type=str, required=True) + argparser.add_argument("--tier", help="tier", type=str, default="hit") + + argparser.add_argument("--log", help="log_file", type=str) + + 
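The QC and tau scripts build an "is_recovering" veto inline from t_sat_lo and timestamp; the stand-alone sketch below expresses the same idea, assuming timestamps are in seconds and using the 10 ms recovery window from the diff, with discharge times taken directly from the flagged events. Function and argument names are illustrative.

import numpy as np

def recovery_veto(timestamps, t_sat_lo, window=0.01):
    """Flag events occurring within `window` seconds after a saturated (discharge) event."""
    timestamps = np.asarray(timestamps, dtype=float)
    # events with a low-side saturated waveform are treated as discharges
    discharge_times = timestamps[np.asarray(t_sat_lo) > 0]
    is_recovering = np.zeros(len(timestamps), dtype=bool)
    for t0 in discharge_times:
        dt = timestamps - t0
        # veto anything strictly after the discharge but inside the recovery window
        is_recovering |= (dt > 0) & (dt < window)
    return is_recovering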
argparser.add_argument("--plot_path", help="plot_path", type=str, required=False) + argparser.add_argument("--save_path", help="save_path", type=str) + args = argparser.parse_args() + + logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") + logging.getLogger("numba").setLevel(logging.INFO) + logging.getLogger("parse").setLevel(logging.INFO) + logging.getLogger("lgdo").setLevel(logging.INFO) + logging.getLogger("h5py").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + logging.getLogger("legendmeta").setLevel(logging.INFO) + + # get metadata dictionary + configs = LegendMetadata(path=args.configs) + channel_dict = configs.on(args.timestamp, system=args.datatype)["snakemake_rules"] + channel_dict = channel_dict["pars_hit_qc"]["inputs"]["qc_config"][args.channel] + + kwarg_dict = Props.read_from(channel_dict) + + kwarg_dict_fft = kwarg_dict["fft_fields"] + if len(args.fft_files) > 0: + fft_fields = get_keys( + [ + key.replace(f"{args.channel}/dsp/", "") + for key in ls(args.fft_files[0], f"{args.channel}/dsp/") + ], + kwarg_dict_fft["cut_parameters"], + ) + + fft_data = load_data( + args.fft_files, + f"{args.channel}/dsp", + {}, + [*fft_fields, "timestamp", "trapTmax"], + ) + + discharges = fft_data["t_sat_lo"] > 0 + discharge_timestamps = np.where(fft_data["timestamp"][discharges])[0] + is_recovering = np.full(len(fft_data), False, dtype=bool) + for tstamp in discharge_timestamps: + is_recovering = is_recovering | np.where( + ( + ((fft_data["timestamp"] - tstamp) < 0.01) + & ((fft_data["timestamp"] - tstamp) > 0) + ), + True, + False, + ) + fft_data["is_recovering"] = is_recovering + + hit_dict_fft = {} + plot_dict_fft = {} + cut_data = fft_data.query("is_recovering==0") + log.debug(f"cut_data shape: {len(cut_data)}") + for name, cut in kwarg_dict_fft["cut_parameters"].items(): + cut_dict, cut_plots = generate_cut_classifiers( + cut_data, + {name: cut}, + kwarg_dict.get("rounding", 4), + display=1 if args.plot_path else 0, + ) + hit_dict_fft.update(cut_dict) + plot_dict_fft.update(cut_plots) + + log.debug(f"{name} calculated cut_dict is: {json.dumps(cut_dict, indent=2)}") + + ct_mask = np.full(len(cut_data), True, dtype=bool) + for outname, info in cut_dict.items(): + # convert to pandas eval + exp = info["expression"] + for key in info.get("parameters", None): + exp = re.sub(f"(? 0 + discharge_timestamps = np.where(data["timestamp"][discharges])[0] + is_recovering = np.full(len(data), False, dtype=bool) + for tstamp in discharge_timestamps: + is_recovering = is_recovering | np.where( + (((data["timestamp"] - tstamp) < 0.01) & ((data["timestamp"] - tstamp) > 0)), + True, + False, + ) + data["is_recovering"] = is_recovering + + rng = np.random.default_rng() + mask = np.full(len(data.query("~is_pulser & ~is_recovering")), False, dtype=bool) + mask[rng.choice(len(data.query("~is_pulser & ~is_recovering")), 4000, replace=False)] = True + + if "initial_cal_cuts" in kwarg_dict: + init_cal = kwarg_dict["initial_cal_cuts"] + hit_dict_init_cal, plot_dict_init_cal = generate_cut_classifiers( + data.query("~is_pulser & ~is_recovering")[mask], + init_cal["cut_parameters"], + init_cal.get("rounding", 4), + display=1 if args.plot_path else 0, + ) + ct_mask = np.full(len(data), True, dtype=bool) + for outname, info in hit_dict_init_cal.items(): + # convert to pandas eval + exp = info["expression"] + for key in info.get("parameters", None): + exp = re.sub(f"(? 
500: + data = data.query("is_pulser & ~is_recovering") + else: + data = data.query("~is_pulser & ~is_recovering")[mask] + + hit_dict_cal, plot_dict_cal = generate_cut_classifiers( + data, + kwarg_dict_cal["cut_parameters"], + kwarg_dict.get("rounding", 4), + display=1 if args.plot_path else 0, + ) + + hit_dict = {**hit_dict_fft, **hit_dict_init_cal, **hit_dict_cal} + plot_dict = {**plot_dict_fft, **plot_dict_init_cal, **plot_dict_cal} + + pathlib.Path(os.path.dirname(args.save_path)).mkdir(parents=True, exist_ok=True) + with open(args.save_path, "w") as f: + json.dump(hit_dict, f, indent=4) + + if args.plot_path: + pathlib.Path(os.path.dirname(args.plot_path)).mkdir(parents=True, exist_ok=True) + with open(args.plot_path, "wb") as f: + pkl.dump({"qc": plot_dict}, f, protocol=pkl.HIGHEST_PROTOCOL) diff --git a/scripts/pars_pht_aoecal.py b/scripts/pars_pht_aoecal.py index a646857..30e1a9e 100644 --- a/scripts/pars_pht_aoecal.py +++ b/scripts/pars_pht_aoecal.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import copy import json import logging import os @@ -9,19 +10,54 @@ import warnings from typing import Callable +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + import numpy as np import pandas as pd from legendmeta import LegendMetadata from legendmeta.catalog import Props from pygama.pargen.AoE_cal import * # noqa: F403 -from pygama.pargen.AoE_cal import cal_aoe, pol1, sigma_fit, standard_aoe -from pygama.pargen.utils import get_tcm_pulser_ids, load_data +from pygama.pargen.AoE_cal import CalAoE, Pol1, SigmaFit, aoe_peak +from pygama.pargen.data_cleaning import get_tcm_pulser_ids +from pygama.pargen.utils import load_data from util.FileKey import ChannelProcKey, ProcessingFileKey log = logging.getLogger(__name__) warnings.filterwarnings(action="ignore", category=RuntimeWarning) +def get_results_dict(aoe_class): + result_dict = {} + for tstamp in aoe_class.low_side_sfs_by_run: + result_dict[tstamp] = { + "cal_energy_param": aoe_class.cal_energy_param, + "dt_param": aoe_class.dt_param, + "rt_correction": aoe_class.dt_corr, + "1000-1300keV": aoe_class.timecorr_df.to_dict("index"), + "correction_fit_results": aoe_class.energy_corr_res_dict, + "low_cut": aoe_class.low_cut_val, + "high_cut": aoe_class.high_cut_val, + "low_side_sfs": aoe_class.low_side_sfs.to_dict("index"), + "2_side_sfs": aoe_class.two_side_sfs.to_dict("index"), + "low_side_sfs_by_run": aoe_class.low_side_sfs_by_run[tstamp].to_dict("index"), + "2_side_sfs_by_run": aoe_class.two_side_sfs_by_run[tstamp].to_dict("index"), + } + return result_dict + + +def fill_plot_dict(aoe_class, data, plot_options, plot_dict=None): + if plot_dict is None: + plot_dict = {} + for key, item in plot_options.items(): + if item["options"] is not None: + plot_dict[key] = item["function"](aoe_class, data, **item["options"]) + else: + plot_dict[key] = item["function"](aoe_class, data) + + return plot_dict + + def aoe_calibration( data: pd.Dataframe, cal_dicts: dict, @@ -29,36 +65,34 @@ def aoe_calibration( energy_param: str, cal_energy_param: str, eres_func: Callable, - pdf: Callable = standard_aoe, + pdf: Callable = aoe_peak, selection_string: str = "", dt_corr: bool = False, dep_correct: bool = False, dt_cut: dict | None = None, high_cut_val: int = 3, - mean_func: Callable = pol1, - sigma_func: Callable = sigma_fit, - dep_acc: float = 0.9, + mean_func: Callable = Pol1, + sigma_func: Callable = SigmaFit, + # dep_acc: float = 0.9, dt_param: str = "dt_eff", comptBands_width: int = 20, plot_options: 
dict | None = None, ): data["AoE_Uncorr"] = data[current_param] / data[energy_param] - aoe = cal_aoe( - cal_dicts, - cal_energy_param, - eres_func, - pdf, - selection_string, - dt_corr, - dep_acc, - dep_correct, - dt_cut, - dt_param, - high_cut_val, - mean_func, - sigma_func, - comptBands_width, - plot_options if plot_options is not None else {}, + aoe = CalAoE( + cal_dicts=cal_dicts, + cal_energy_param=cal_energy_param, + eres_func=eres_func, + pdf=pdf, + selection_string=selection_string, + dt_corr=dt_corr, + dep_correct=dep_correct, + dt_cut=dt_cut, + dt_param=dt_param, + high_cut_val=high_cut_val, + mean_func=mean_func, + sigma_func=sigma_func, + compt_bands_width=comptBands_width, ) aoe.update_cal_dicts( { @@ -70,12 +104,13 @@ def aoe_calibration( ) aoe.calibrate(data, "AoE_Uncorr") log.info("Calibrated A/E") - return cal_dicts, aoe.get_results_dict(), aoe.fill_plot_dict(data), aoe + return cal_dicts, get_results_dict(aoe), fill_plot_dict(aoe, data, plot_options), aoe argparser = argparse.ArgumentParser() argparser.add_argument("--input_files", help="files", type=str, nargs="*", required=True) -argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, nargs="*", required=True) +argparser.add_argument("--pulser_files", help="pulser_file", nargs="*", type=str, required=False) +argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, nargs="*", required=False) argparser.add_argument("--ecal_file", help="ecal_file", type=str, nargs="*", required=True) argparser.add_argument("--eres_file", help="eres_file", type=str, nargs="*", required=True) argparser.add_argument("--inplots", help="eres_file", type=str, nargs="*", required=True) @@ -98,6 +133,7 @@ def aoe_calibration( logging.getLogger("lgdo").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("legendmeta").setLevel(logging.INFO) def run_splitter(files): @@ -225,20 +261,28 @@ def run_splitter(files): return_selection_mask=True, ) - # get pulser mask from tcm files - if isinstance(args.tcm_filelist, list): - tcm_files = [] - for file in args.tcm_filelist: + if args.pulser_files: + mask = np.array([], dtype=bool) + for file in args.pulser_files: with open(file) as f: - tcm_files += f.read().splitlines() - else: + pulser_dict = json.load(f) + pulser_mask = np.array(pulser_dict["mask"]) + mask = np.append(mask, pulser_mask) + if "pulser_multiplicity_threshold" in kwarg_dict: + kwarg_dict.pop("pulser_multiplicity_threshold") + + elif args.tcm_filelist: + # get pulser mask from tcm files with open(args.tcm_filelist) as f: tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, kwarg_dict["pulser_multiplicity_threshold"] + ) + else: + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) - tcm_files = sorted(np.unique(tcm_files)) - ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") - ) data["is_pulser"] = mask[threshold_mask] for tstamp in cal_dict: @@ -248,19 +292,18 @@ def run_splitter(files): row = pd.DataFrame(row) data = pd.concat([data, row]) - pdf = eval(kwarg_dict.pop("pdf")) if "pdf" in kwarg_dict else standard_aoe + pdf = eval(kwarg_dict.pop("pdf")) if "pdf" in kwarg_dict else aoe_peak - mean_func = eval(kwarg_dict.pop("mean_func")) if "mean_func" in kwarg_dict else pol1 + mean_func = eval(kwarg_dict.pop("mean_func")) if "mean_func" in kwarg_dict else Pol1 - 
if "sigma_func" in kwarg_dict: - sigma_func = eval(kwarg_dict.pop("sigma_func")) - else: - sigma_func = sigma_fit + sigma_func = eval(kwarg_dict.pop("sigma_func")) if "sigma_func" in kwarg_dict else SigmaFit try: - eres = results_dicts[next(iter(results_dicts))]["partition_ecal"][ - kwarg_dict["cal_energy_param"] - ]["eres_linear"].copy() + eres = copy.deepcopy( + results_dicts[next(iter(results_dicts))]["partition_ecal"][ + kwarg_dict["cal_energy_param"] + ]["eres_linear"] + ) def eres_func(x): return eval(eres["expression"], dict(x=x, **eres["parameters"])) @@ -269,9 +312,11 @@ def eres_func(x): raise RuntimeError except (KeyError, RuntimeError): try: - eres = results_dicts[next(iter(results_dicts))]["ecal"][ - kwarg_dict["cal_energy_param"] - ]["eres_linear"].copy() + eres = copy.deepcopy( + results_dicts[next(iter(results_dicts))]["ecal"][kwarg_dict["cal_energy_param"]][ + "eres_linear" + ] + ) def eres_func(x): return eval(eres["expression"], dict(x=x, **eres["parameters"])) @@ -291,16 +336,16 @@ def eres_func(x): sigma_func=sigma_func, **kwarg_dict, ) - + aoe_obj.pdf = aoe_obj.pdf.name # need to change eres func as can't pickle lambdas try: aoe_obj.eres_func = results_dicts[next(iter(results_dicts))]["partition_ecal"][ kwarg_dict["cal_energy_param"] - ]["eres_linear"].copy() + ]["eres_linear"] except KeyError: aoe_obj.eres_func = {} else: - out_dict = {} + out_dict = {tstamp: None for tstamp in cal_dict} plot_dict = {} aoe_obj = None @@ -345,7 +390,7 @@ def eres_func(x): "pars": {"operations": cal_dict[fk.timestamp]}, "results": dict( **results_dicts[fk.timestamp], - aoe=out_dict, + aoe=out_dict[fk.timestamp], ), } pathlib.Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True) diff --git a/scripts/pars_pht_lqcal.py b/scripts/pars_pht_lqcal.py index 2d1bc06..c5ba80b 100644 --- a/scripts/pars_pht_lqcal.py +++ b/scripts/pars_pht_lqcal.py @@ -8,27 +8,54 @@ import pickle as pkl import warnings +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + import numpy as np import pandas as pd from legendmeta import LegendMetadata from legendmeta.catalog import Props -from pygama.math.peak_fitting import gauss_cdf +from pygama.math.distributions import gaussian +from pygama.pargen.data_cleaning import get_tcm_pulser_ids from pygama.pargen.lq_cal import * # noqa: F403 -from pygama.pargen.lq_cal import cal_lq -from pygama.pargen.utils import get_tcm_pulser_ids, load_data +from pygama.pargen.lq_cal import LQCal +from pygama.pargen.utils import load_data from util.FileKey import ChannelProcKey, ProcessingFileKey log = logging.getLogger(__name__) warnings.filterwarnings(action="ignore", category=RuntimeWarning) +def get_results_dict(lq_class): + return { + "cal_energy_param": lq_class.cal_energy_param, + "rt_correction": lq_class.dt_fit_pars, + # "cdf": lq_class.cdf.name, + "1590-1596keV": lq_class.timecorr_df.to_dict("index"), + "cut_value": lq_class.cut_val, + "sfs": lq_class.low_side_sf.to_dict("index"), + } + + +def fill_plot_dict(lq_class, data, plot_options, plot_dict=None): + if plot_dict is None: + plot_dict = {} + for key, item in plot_options.items(): + if item["options"] is not None: + plot_dict[key] = item["function"](lq_class, data, **item["options"]) + else: + plot_dict[key] = item["function"](lq_class, data) + + return plot_dict + + def lq_calibration( data: pd.DataFrame, cal_dicts: dict, energy_param: str, cal_energy_param: str, eres_func: callable, - cdf: callable = gauss_cdf, + cdf: callable = gaussian, selection_string: str = "", plot_options: 
dict | None = None, ): @@ -66,13 +93,12 @@ def lq_calibration( The cal_lq object used for the LQ calibration """ - lq = cal_lq( + lq = LQCal( cal_dicts, cal_energy_param, eres_func, cdf, selection_string, - plot_options, ) data["LQ_Ecorr"] = np.divide(data["lq80"], data[energy_param]) @@ -88,12 +114,13 @@ def lq_calibration( lq.calibrate(data, "LQ_Ecorr") log.info("Calibrated LQ") - return cal_dicts, lq.get_results_dict(), lq.fill_plot_dict(data), lq + return cal_dicts, get_results_dict(lq), fill_plot_dict(lq, data, plot_options), lq argparser = argparse.ArgumentParser() argparser.add_argument("--input_files", help="files", type=str, nargs="*", required=True) -argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, nargs="*", required=True) +argparser.add_argument("--pulser_files", help="pulser_file", type=str, nargs="*", required=False) +argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, nargs="*", required=False) argparser.add_argument("--ecal_file", help="ecal_file", type=str, nargs="*", required=True) argparser.add_argument("--eres_file", help="eres_file", type=str, nargs="*", required=True) argparser.add_argument("--inplots", help="eres_file", type=str, nargs="*", required=True) @@ -116,6 +143,7 @@ def lq_calibration( logging.getLogger("lgdo").setLevel(logging.INFO) logging.getLogger("h5py").setLevel(logging.INFO) logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("legendmeta").setLevel(logging.INFO) def run_splitter(files): @@ -232,20 +260,28 @@ def run_splitter(files): return_selection_mask=True, ) - # get pulser mask from tcm files - if isinstance(args.tcm_filelist, list): - tcm_files = [] - for file in args.tcm_filelist: + if args.pulser_files: + mask = np.array([], dtype=bool) + for file in args.pulser_files: with open(file) as f: - tcm_files += f.read().splitlines() - else: + pulser_dict = json.load(f) + pulser_mask = np.array(pulser_dict["mask"]) + mask = np.append(mask, pulser_mask) + if "pulser_multiplicity_threshold" in kwarg_dict: + kwarg_dict.pop("pulser_multiplicity_threshold") + + elif args.tcm_filelist: + # get pulser mask from tcm files with open(args.tcm_filelist) as f: tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, kwarg_dict["pulser_multiplicity_threshold"] + ) + else: + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) - tcm_files = sorted(np.unique(tcm_files)) - ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") - ) data["is_pulser"] = mask[threshold_mask] for tstamp in cal_dict: @@ -255,7 +291,7 @@ def run_splitter(files): row = pd.DataFrame(row) data = pd.concat([data, row]) - cdf = eval(kwarg_dict.pop("cdf")) if "cdf" in kwarg_dict else gauss_cdf + cdf = eval(kwarg_dict.pop("cdf")) if "cdf" in kwarg_dict else gaussian try: eres = results_dicts[next(iter(results_dicts))]["partition_ecal"][ diff --git a/scripts/pars_pht_partcal.py b/scripts/pars_pht_partcal.py index a148946..21a2654 100644 --- a/scripts/pars_pht_partcal.py +++ b/scripts/pars_pht_partcal.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import copy import json import logging import os @@ -9,95 +10,25 @@ import re import warnings +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + import numpy as np import pandas as pd +import pygama.math.distributions as pgf +import pygama.math.histogram as pgh from legendmeta import 
LegendMetadata from legendmeta.catalog import Props -from pygama.pargen.ecal_th import * # noqa: F403 -from pygama.pargen.ecal_th import high_stats_fitting -from pygama.pargen.utils import get_tcm_pulser_ids, load_data +from pygama.math.distributions import nb_poly +from pygama.pargen.data_cleaning import get_tcm_pulser_ids +from pygama.pargen.energy_cal import FWHMLinear, FWHMQuadratic, HPGeCalibration +from pygama.pargen.utils import load_data from util.FileKey import ChannelProcKey, ProcessingFileKey log = logging.getLogger(__name__) warnings.filterwarnings(action="ignore", category=RuntimeWarning) -def update_cal_dicts(cal_dicts, update_dict): - if re.match(r"(\d{8})T(\d{6})Z", next(iter(cal_dicts))): - for tstamp in cal_dicts: - if tstamp in update_dict: - cal_dicts[tstamp].update(update_dict[tstamp]) - else: - cal_dicts[tstamp].update(update_dict) - else: - cal_dicts.update(update_dict) - return cal_dicts - - -def partition_energy_cal_th( - data: pd.Datframe, - hit_dicts: dict, - energy_params: list[str], - selection_string: str = "", - threshold: int = 0, - p_val: float = 0, - plot_options: dict | None = None, - simplex: bool = True, - tail_weight: int = 20, - # cal_energy_params: list = None, - # deg:int=2, -) -> tuple(dict, dict, dict, dict): - results_dict = {} - plot_dict = {} - full_object_dict = {} - # if cal_energy_params is None: - # cal_energy_params = [energy_param + "_cal" for energy_param in energy_params] - for energy_param in energy_params: - full_object_dict[energy_param] = high_stats_fitting( - energy_param=energy_param, - selection_string=selection_string, - threshold=threshold, - p_val=p_val, - plot_options=plot_options, - simplex=simplex, - tail_weight=tail_weight, - ) - full_object_dict[energy_param].fit_peaks(data) - results_dict[energy_param] = full_object_dict[energy_param].get_results_dict(data) - if full_object_dict[energy_param].results: - plot_dict[energy_param] = full_object_dict[energy_param].fill_plot_dict(data).copy() - - log.info("Finished all calibrations") - return hit_dicts, results_dict, plot_dict, full_object_dict - - -argparser = argparse.ArgumentParser() -argparser.add_argument("--input_files", help="files", type=str, nargs="*", required=True) -argparser.add_argument("--tcm_filelist", help="tcm_filelist", type=str, nargs="*", required=True) -argparser.add_argument("--ecal_file", help="ecal_file", type=str, nargs="*", required=True) -argparser.add_argument("--eres_file", help="eres_file", type=str, nargs="*", required=True) -argparser.add_argument("--inplots", help="eres_file", type=str, nargs="*", required=True) - -argparser.add_argument("--configs", help="configs", type=str, required=True) -argparser.add_argument("--timestamp", help="Datatype", type=str, required=True) -argparser.add_argument("--datatype", help="Datatype", type=str, required=True) -argparser.add_argument("--channel", help="Channel", type=str, required=True) - -argparser.add_argument("--log", help="log_file", type=str) - -argparser.add_argument("--plot_file", help="plot_file", type=str, nargs="*", required=False) -argparser.add_argument("--hit_pars", help="hit_pars", nargs="*", type=str) -argparser.add_argument("--fit_results", help="fit_results", nargs="*", type=str) -args = argparser.parse_args() - -logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") -logging.getLogger("numba").setLevel(logging.INFO) -logging.getLogger("parse").setLevel(logging.INFO) -logging.getLogger("lgdo").setLevel(logging.INFO) -logging.getLogger("h5py").setLevel(logging.INFO) 
-logging.getLogger("matplotlib").setLevel(logging.INFO) - - def run_splitter(files): """ Returns list containing lists of each run @@ -116,186 +47,419 @@ def run_splitter(files): return run_files -configs = LegendMetadata(path=args.configs) -channel_dict = configs.on(args.timestamp, system=args.datatype)["snakemake_rules"][ - "pars_pht_partcal" -]["inputs"]["pars_pht_partcal_config"][args.channel] +def update_cal_dicts(cal_dicts, update_dict): + if re.match(r"(\d{8})T(\d{6})Z", next(iter(cal_dicts))): + for tstamp in cal_dicts: + if tstamp in update_dict: + cal_dicts[tstamp].update(update_dict[tstamp]) + else: + cal_dicts[tstamp].update(update_dict) + else: + cal_dicts.update(update_dict) + return cal_dicts + + +def bin_spectrum( + data, + cal_energy_param, + selection_string, + cut_field="is_valid_cal", + pulser_field="is_pulser", + erange=(0, 3000), + dx=2, +): + bins = np.arange(erange[0], erange[1] + dx, dx) + return { + "bins": pgh.get_bin_centers(bins), + "counts": np.histogram(data.query(selection_string)[cal_energy_param], bins)[0], + "cut_counts": np.histogram( + data.query(f"(~{cut_field})&(~{pulser_field})")[cal_energy_param], + bins, + )[0], + "pulser_counts": np.histogram( + data.query(pulser_field)[cal_energy_param], + bins, + )[0], + } -kwarg_dict = Props.read_from(channel_dict) -cal_dict = {} -results_dicts = {} -if isinstance(args.ecal_file, list): - for ecal in args.ecal_file: - cal = Props.read_from(ecal) +def get_results_dict(ecal_class, data, cal_energy_param, selection_string): + if np.isnan(ecal_class.pars).all(): + return {} + else: + results_dict = copy.deepcopy(ecal_class.results["hpge_fit_energy_peaks"]) + + if "FWHMLinear" in results_dict: + fwhm_linear = results_dict["FWHMLinear"] + fwhm_linear["function"] = fwhm_linear["function"].__name__ + fwhm_linear["parameters"] = fwhm_linear["parameters"].to_dict() + fwhm_linear["uncertainties"] = fwhm_linear["uncertainties"].to_dict() + fwhm_linear["cov"] = fwhm_linear["cov"].tolist() + else: + fwhm_linear = None + + if "FWHMQuadratic" in results_dict: + fwhm_quad = results_dict["FWHMQuadratic"] + fwhm_quad["function"] = fwhm_quad["function"].__name__ + fwhm_quad["parameters"] = fwhm_quad["parameters"].to_dict() + fwhm_quad["uncertainties"] = fwhm_quad["uncertainties"].to_dict() + fwhm_quad["cov"] = fwhm_quad["cov"].tolist() + else: + fwhm_quad = None + + pk_dict = results_dict["peak_parameters"] + + for _, dic in pk_dict.items(): + dic["function"] = dic["function"].name + dic["parameters"] = dic["parameters"].to_dict() + dic["uncertainties"] = dic["uncertainties"].to_dict() + dic.pop("covariance") + + return { + "total_fep": len(data.query(f"{cal_energy_param}>2604&{cal_energy_param}<2624")), + "total_dep": len(data.query(f"{cal_energy_param}>1587&{cal_energy_param}<1597")), + "pass_fep": len( + data.query(f"{cal_energy_param}>2604&{cal_energy_param}<2624&{selection_string}") + ), + "pass_dep": len( + data.query(f"{cal_energy_param}>1587&{cal_energy_param}<1597&{selection_string}") + ), + "eres_linear": fwhm_linear, + "eres_quadratic": fwhm_quad, + "fitted_peaks": ecal_class.peaks_kev.tolist(), + "pk_fits": pk_dict, + "peak_param": results_dict["peak_param"], + } + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument("--input_files", help="files", type=str, nargs="*", required=True) + argparser.add_argument( + "--pulser_files", help="pulser_file", nargs="*", type=str, required=False + ) + argparser.add_argument( + "--tcm_filelist", help="tcm_filelist", type=str, nargs="*", 
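# Minimal sketch of the serialization pattern used in get_results_dict above:
# callables are stored by name, Series-like parameter sets via .to_dict(), and
# covariance matrices via .tolist(), so the result survives json.dump. It
# assumes the same field layout as the patch ("function", "parameters",
# "uncertainties", "cov").
def jsonify_fwhm_result(fwhm_result):
    out = dict(fwhm_result)
    out["function"] = out["function"].__name__
    out["parameters"] = out["parameters"].to_dict()
    out["uncertainties"] = out["uncertainties"].to_dict()
    out["cov"] = out["cov"].tolist()
    return out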
required=False + ) + argparser.add_argument("--ecal_file", help="ecal_file", type=str, nargs="*", required=True) + argparser.add_argument("--eres_file", help="eres_file", type=str, nargs="*", required=True) + argparser.add_argument("--inplots", help="eres_file", type=str, nargs="*", required=True) + + argparser.add_argument("--configs", help="configs", type=str, required=True) + argparser.add_argument("--timestamp", help="Datatype", type=str, required=True) + argparser.add_argument("--datatype", help="Datatype", type=str, required=True) + argparser.add_argument("--channel", help="Channel", type=str, required=True) + + argparser.add_argument("--log", help="log_file", type=str) + argparser.add_argument("--metadata", help="metadata path", type=str, required=True) + + argparser.add_argument("--plot_file", help="plot_file", type=str, nargs="*", required=False) + argparser.add_argument("--hit_pars", help="hit_pars", nargs="*", type=str) + argparser.add_argument("--fit_results", help="fit_results", nargs="*", type=str) + args = argparser.parse_args() + + logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") + logging.getLogger("numba").setLevel(logging.INFO) + logging.getLogger("parse").setLevel(logging.INFO) + logging.getLogger("lgdo").setLevel(logging.INFO) + logging.getLogger("h5py").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + logging.getLogger("legendmeta").setLevel(logging.INFO) + + meta = LegendMetadata(path=args.metadata) + chmap = meta.channelmap(args.timestamp) + + det_status = chmap.map("daq.rawid")[int(args.channel[2:])]["analysis"]["usability"] + + configs = LegendMetadata(path=args.configs) + channel_dict = configs.on(args.timestamp, system=args.datatype)["snakemake_rules"][ + "pars_pht_partcal" + ]["inputs"]["pars_pht_partcal_config"][args.channel] + + kwarg_dict = Props.read_from(channel_dict) + + cal_dict = {} + results_dicts = {} + if isinstance(args.ecal_file, list): + for ecal in args.ecal_file: + cal = Props.read_from(ecal) - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(ecal)) + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(ecal)) + cal_dict[fk.timestamp] = cal["pars"] + results_dicts[fk.timestamp] = cal["results"] + else: + cal = Props.read_from(args.ecal_file) + + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.ecal_file)) cal_dict[fk.timestamp] = cal["pars"] results_dicts[fk.timestamp] = cal["results"] -else: - cal = Props.read_from(args.ecal_file) - - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.ecal_file)) - cal_dict[fk.timestamp] = cal["pars"] - results_dicts[fk.timestamp] = cal["results"] -object_dict = {} -if isinstance(args.eres_file, list): - for ecal in args.eres_file: - with open(ecal, "rb") as o: - cal = pkl.load(o) - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(ecal)) - object_dict[fk.timestamp] = cal -else: - with open(args.eres_file, "rb") as o: - cal = pkl.load(o) - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.eres_file)) - object_dict[fk.timestamp] = cal - -inplots_dict = {} -if args.inplots: - if isinstance(args.inplots, list): - for ecal in args.inplots: + object_dict = {} + if isinstance(args.eres_file, list): + for ecal in args.eres_file: with open(ecal, "rb") as o: cal = pkl.load(o) fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(ecal)) - inplots_dict[fk.timestamp] = cal + object_dict[fk.timestamp] = cal else: - with open(args.inplots, "rb") as o: + with 
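# Sketch of the usability lookup added above: the channel wildcard has the
# form "chXXXXXXX", so stripping the "ch" prefix yields the daq rawid used to
# index the channel map; whether the detector is "on" then decides if the
# partition fit may update the calibration parameters. The metadata path,
# timestamp and channel name below are placeholders.
from legendmeta import LegendMetadata

meta = LegendMetadata(path="/path/to/legend-metadata")   # hypothetical path
chmap = meta.channelmap("20230101T000000Z")              # hypothetical timestamp
channel = "ch1104000"                                     # hypothetical channel wildcard
rawid = int(channel[2:])                                  # -> 1104000
det_status = chmap.map("daq.rawid")[rawid]["analysis"]["usability"]
update_cal_pars = det_status == "on"                      # as passed to hpge_fit_energy_peaks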
open(args.eres_file, "rb") as o: cal = pkl.load(o) - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.inplots)) - inplots_dict[fk.timestamp] = cal - - -if "plot_options" in kwarg_dict: - for field, item in kwarg_dict["plot_options"].items(): - kwarg_dict["plot_options"][field]["function"] = eval(item["function"]) - - -# sort files in dictionary where keys are first timestamp from run -if isinstance(args.input_files, list): - files = [] - for file in args.input_files: - with open(file) as f: - files += f.read().splitlines() -else: - with open(args.input_files) as f: - files = f.read().splitlines() - -files = sorted( - np.unique(files) -) # need this as sometimes files get double counted as it somehow puts in the p%-* filelist and individual runs also - -final_dict = {} -all_file = run_splitter(sorted(files)) -for filelist in all_file: - fk = ProcessingFileKey.get_filekey_from_pattern(os.path.basename(sorted(filelist)[0])) - timestamp = fk.timestamp - final_dict[timestamp] = sorted(filelist) - -params = [ - kwarg_dict["final_cut_field"], - "timestamp", -] -params += kwarg_dict["energy_params"] - -# load data in -data, threshold_mask = load_data( - final_dict, - f"{args.channel}/dsp", - cal_dict, - params=params, - threshold=kwarg_dict["threshold"], - return_selection_mask=True, - cal_energy_param=kwarg_dict["energy_params"][0], -) - -# get pulser mask from tcm files -if isinstance(args.tcm_filelist, list): - tcm_files = [] - for file in args.tcm_filelist: - with open(file) as f: - tcm_files += f.read().splitlines() -else: - with open(args.tcm_filelist) as f: - tcm_files = f.read().splitlines() - -tcm_files = sorted(np.unique(tcm_files)) -ids, mask = get_tcm_pulser_ids( - tcm_files, args.channel, kwarg_dict.pop("pulser_multiplicity_threshold") -) -data["is_pulser"] = mask[threshold_mask] - -for tstamp in cal_dict: - if tstamp not in np.unique(data["run_timestamp"]): - row = {key: [False] if data.dtypes[key] == "bool" else [np.nan] for key in data} - row["run_timestamp"] = tstamp - row = pd.DataFrame(row) - data = pd.concat([data, row]) - -# run energy supercal -hit_dicts, ecal_results, plot_dict, ecal_obj = partition_energy_cal_th( - data, - cal_dict, - selection_string=f"{kwarg_dict.pop('final_cut_field')}&(~is_pulser)", - **kwarg_dict, -) + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.eres_file)) + object_dict[fk.timestamp] = cal + + inplots_dict = {} + if args.inplots: + if isinstance(args.inplots, list): + for ecal in args.inplots: + with open(ecal, "rb") as o: + cal = pkl.load(o) + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(ecal)) + inplots_dict[fk.timestamp] = cal + else: + with open(args.inplots, "rb") as o: + cal = pkl.load(o) + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.inplots)) + inplots_dict[fk.timestamp] = cal + + if "plot_options" in kwarg_dict: + for field, item in kwarg_dict["plot_options"].items(): + kwarg_dict["plot_options"][field]["function"] = eval(item["function"]) + + # sort files in dictionary where keys are first timestamp from run + if isinstance(args.input_files, list): + files = [] + for file in args.input_files: + with open(file) as f: + files += f.read().splitlines() + else: + with open(args.input_files) as f: + files = f.read().splitlines() + + files = sorted( + np.unique(files) + ) # need this as sometimes files get double counted as it somehow puts in the p%-* filelist and individual runs also + + final_dict = {} + all_file = run_splitter(sorted(files)) + for filelist in 
all_file: + fk = ProcessingFileKey.get_filekey_from_pattern(os.path.basename(sorted(filelist)[0])) + timestamp = fk.timestamp + final_dict[timestamp] = sorted(filelist) + + params = [ + kwarg_dict["final_cut_field"], + "timestamp", + ] + params += kwarg_dict["energy_params"] + + # load data in + data, threshold_mask = load_data( + final_dict, + f"{args.channel}/dsp", + cal_dict, + params=params, + threshold=kwarg_dict["threshold"], + return_selection_mask=True, + cal_energy_param=kwarg_dict["energy_params"][0], + ) + + if args.pulser_files: + mask = np.array([], dtype=bool) + for file in args.pulser_files: + with open(file) as f: + pulser_dict = json.load(f) + pulser_mask = np.array(pulser_dict["mask"]) + mask = np.append(mask, pulser_mask) + if "pulser_multiplicity_threshold" in kwarg_dict: + kwarg_dict.pop("pulser_multiplicity_threshold") + + elif args.tcm_filelist: + # get pulser mask from tcm files + with open(args.tcm_filelist) as f: + tcm_files = f.read().splitlines() + tcm_files = sorted(np.unique(tcm_files)) + ids, mask = get_tcm_pulser_ids( + tcm_files, args.channel, kwarg_dict["pulser_multiplicity_threshold"] + ) + else: + msg = "No pulser file or tcm filelist provided" + raise ValueError(msg) + + data["is_pulser"] = mask[threshold_mask] + + for tstamp in cal_dict: + if tstamp not in np.unique(data["run_timestamp"]): + row = {key: [False] if data.dtypes[key] == "bool" else [np.nan] for key in data} + row["run_timestamp"] = tstamp + row = pd.DataFrame(row) + data = pd.concat([data, row]) + + pk_pars = [ + (238.632, (10, 10), pgf.gauss_on_step), + (511, (30, 30), pgf.gauss_on_step), + (583.191, (30, 30), pgf.hpge_peak), + (727.330, (30, 30), pgf.hpge_peak), + (763, (30, 15), pgf.gauss_on_step), + (785, (15, 30), pgf.gauss_on_step), + (860.564, (30, 25), pgf.hpge_peak), + (893, (25, 30), pgf.gauss_on_step), + (1079, (30, 30), pgf.gauss_on_step), + (1513, (30, 30), pgf.gauss_on_step), + (1592.53, (30, 20), pgf.hpge_peak), + (1620.50, (20, 30), pgf.hpge_peak), + (2103.53, (30, 30), pgf.hpge_peak), + (2614.553, (30, 30), pgf.hpge_peak), + (3125, (30, 30), pgf.gauss_on_step), + (3198, (30, 30), pgf.gauss_on_step), + (3474, (30, 30), pgf.gauss_on_step), + ] + + glines = [pk_par[0] for pk_par in pk_pars] + + if "cal_energy_params" not in kwarg_dict: + cal_energy_params = [energy_param + "_cal" for energy_param in kwarg_dict["energy_params"]] + else: + cal_energy_params = kwarg_dict["cal_energy_params"] + + selection_string = f"~is_pulser&{kwarg_dict['final_cut_field']}" + + ecal_results = {} + plot_dict = {} + full_object_dict = {} + + for energy_param, cal_energy_param in zip(kwarg_dict["energy_params"], cal_energy_params): + energy = data.query(selection_string)[energy_param].to_numpy() + full_object_dict[cal_energy_param] = HPGeCalibration( + energy_param, glines, 1, kwarg_dict.get("deg", 0) # , fixed={1: 1} + ) + full_object_dict[cal_energy_param].hpge_get_energy_peaks( + energy, etol_kev=5 if det_status == "on" else 10 + ) -if args.plot_file: - common_dict = plot_dict.pop("common") if "common" in list(plot_dict) else None + if det_status != "on": + full_object_dict[cal_energy_param].hpge_cal_energy_peak_tops( + energy, + update_cal_pars=True, + allowed_p_val=0, + ) + + full_object_dict[cal_energy_param].hpge_fit_energy_peaks( + energy, + peak_pars=pk_pars, + tail_weight=kwarg_dict.get("tail_weight", 0), + n_events=kwarg_dict.get("n_events", None), + allowed_p_val=kwarg_dict.get("p_val", 0), + update_cal_pars=bool(det_status == "on"), + bin_width_kev=0.25, + ) - if 
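# Sketch of the peak-table convention used above: each pk_pars entry is
# (literature energy, (left, right) fit window, pygama pdf), with energies and
# windows assumed to be in keV. Only two illustrative entries are shown; the
# glines derivation and the cal_energy_params default mirror the patch, while
# the energy_params value is a placeholder.
import pygama.math.distributions as pgf

pk_pars = [
    (583.191, (30, 30), pgf.hpge_peak),
    (2614.553, (30, 30), pgf.hpge_peak),
]
glines = [pk_par[0] for pk_par in pk_pars]        # peak positions handed to HPGeCalibration

energy_params = ["cuspEmax_ctc"]                  # hypothetical config value
cal_energy_params = [p + "_cal" for p in energy_params]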
isinstance(args.plot_file, list): - for plot_file in args.plot_file: - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(plot_file)) + full_object_dict[cal_energy_param].get_energy_res_curve( + FWHMLinear, + interp_energy_kev={"Qbb": 2039.0}, + ) + full_object_dict[cal_energy_param].get_energy_res_curve( + FWHMQuadratic, + interp_energy_kev={"Qbb": 2039.0}, + ) + + data[cal_energy_param] = nb_poly( + data[energy_param].to_numpy(), full_object_dict[cal_energy_param].pars + ) + + ecal_results[cal_energy_param] = get_results_dict( + full_object_dict[cal_energy_param], data, cal_energy_param, selection_string + ) + cal_dict = update_cal_dicts( + cal_dict, {cal_energy_param: full_object_dict[cal_energy_param].gen_pars_dict()} + ) + + if args.plot_file: + param_plot_dict = {} + if ~np.isnan(full_object_dict[cal_energy_param].pars).all(): + param_plot_dict["fwhm_fit"] = full_object_dict[cal_energy_param].plot_eres_fit( + energy + ) + param_plot_dict["cal_fit"] = full_object_dict[cal_energy_param].plot_cal_fit( + energy + ) + param_plot_dict["peak_fits"] = full_object_dict[cal_energy_param].plot_fits( + energy, ncols=4, nrows=5 + ) + + if "plot_options" in kwarg_dict: + for key, item in kwarg_dict["plot_options"].items(): + if item["options"] is not None: + param_plot_dict[key] = item["function"]( + data, + cal_energy_param, + selection_string, + **item["options"], + ) + else: + param_plot_dict[key] = item["function"]( + data, + cal_energy_param, + selection_string, + ) + plot_dict[cal_energy_param] = param_plot_dict + + for peak_dict in ( + full_object_dict[cal_energy_param] + .results["hpge_fit_energy_peaks"]["peak_parameters"] + .values() + ): + peak_dict["function"] = peak_dict["function"].name + peak_dict["parameters"] = peak_dict["parameters"].to_dict() + peak_dict["uncertainties"] = peak_dict["uncertainties"].to_dict() + + if args.plot_file: + common_dict = plot_dict.pop("common") if "common" in list(plot_dict) else None + + if isinstance(args.plot_file, list): + for plot_file in args.plot_file: + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(plot_file)) + if args.inplots: + out_plot_dict = inplots_dict[fk.timestamp] + out_plot_dict.update({"partition_ecal": plot_dict}) + else: + out_plot_dict = {"partition_ecal": plot_dict} + + if "common" in list(out_plot_dict) and common_dict is not None: + out_plot_dict["common"].update(common_dict) + elif common_dict is not None: + out_plot_dict["common"] = common_dict + + pathlib.Path(os.path.dirname(plot_file)).mkdir(parents=True, exist_ok=True) + with open(plot_file, "wb") as w: + pkl.dump(out_plot_dict, w, protocol=pkl.HIGHEST_PROTOCOL) + else: if args.inplots: + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.plot_file)) out_plot_dict = inplots_dict[fk.timestamp] out_plot_dict.update({"partition_ecal": plot_dict}) else: out_plot_dict = {"partition_ecal": plot_dict} - if "common" in list(out_plot_dict) and common_dict is not None: out_plot_dict["common"].update(common_dict) elif common_dict is not None: out_plot_dict["common"] = common_dict - - pathlib.Path(os.path.dirname(plot_file)).mkdir(parents=True, exist_ok=True) - with open(plot_file, "wb") as w: + pathlib.Path(os.path.dirname(args.plot_file)).mkdir(parents=True, exist_ok=True) + with open(args.plot_file, "wb") as w: pkl.dump(out_plot_dict, w, protocol=pkl.HIGHEST_PROTOCOL) - else: - if args.inplots: - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(args.plot_file)) - out_plot_dict = inplots_dict[fk.timestamp] - 
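# Minimal sketch of how the fitted calibration is applied above: nb_poly (as
# imported from pygama.math.distributions in this patch) evaluates the
# calibration polynomial on the raw energy estimator. The coefficients below
# are placeholders, and their ordering is whatever convention nb_poly expects
# for HPGeCalibration.pars.
import numpy as np
from pygama.math.distributions import nb_poly

raw_energy = np.array([1000.0, 2000.0, 3000.0])   # hypothetical uncalibrated values
pars = np.array([0.0, 0.52])                       # hypothetical calibration coefficients
cal_energy = nb_poly(raw_energy, pars)             # calibrated energies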
out_plot_dict.update({"partition_ecal": plot_dict}) - else: - out_plot_dict = {"partition_ecal": plot_dict} - if "common" in list(out_plot_dict) and common_dict is not None: - out_plot_dict["common"].update(common_dict) - elif common_dict is not None: - out_plot_dict["common"] = common_dict - pathlib.Path(os.path.dirname(args.plot_file)).mkdir(parents=True, exist_ok=True) - with open(args.plot_file, "wb") as w: - pkl.dump(out_plot_dict, w, protocol=pkl.HIGHEST_PROTOCOL) - - -for out in sorted(args.hit_pars): - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(out)) - final_hit_dict = { - "pars": hit_dicts[fk.timestamp], - "results": { - "ecal": results_dicts[fk.timestamp], - "partition_ecal": ecal_results, - }, - } - pathlib.Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True) - with open(out, "w") as w: - json.dump(final_hit_dict, w, indent=4) - -for out in args.fit_results: - fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(out)) - final_object_dict = { - "ecal": object_dict[fk.timestamp], - "partition_ecal": ecal_obj, - } - pathlib.Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True) - with open(out, "wb") as w: - pkl.dump(final_object_dict, w, protocol=pkl.HIGHEST_PROTOCOL) + + for out in sorted(args.hit_pars): + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(out)) + final_hit_dict = { + "pars": cal_dict[fk.timestamp], + "results": dict(**results_dicts[fk.timestamp], partition_ecal=ecal_results), + } + pathlib.Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True) + with open(out, "w") as w: + json.dump(final_hit_dict, w, indent=4) + + for out in args.fit_results: + fk = ChannelProcKey.get_filekey_from_pattern(os.path.basename(out)) + final_object_dict = dict(**object_dict[fk.timestamp], partition_ecal=full_object_dict) + pathlib.Path(os.path.dirname(out)).mkdir(parents=True, exist_ok=True) + with open(out, "wb") as w: + pkl.dump(final_object_dict, w, protocol=pkl.HIGHEST_PROTOCOL) diff --git a/scripts/pars_pht_qc.py b/scripts/pars_pht_qc.py new file mode 100644 index 0000000..18ff865 --- /dev/null +++ b/scripts/pars_pht_qc.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import argparse +import json +import logging +import os +import pathlib +import pickle as pkl +import re +import warnings + +os.environ["PYGAMA_PARALLEL"] = "false" +os.environ["PYGAMA_FASTMATH"] = "false" + +import numpy as np +from legendmeta import LegendMetadata +from legendmeta.catalog import Props +from lgdo.lh5 import ls +from pygama.pargen.data_cleaning import ( + generate_cut_classifiers, + get_keys, + get_tcm_pulser_ids, +) +from pygama.pargen.utils import load_data + +log = logging.getLogger(__name__) + +warnings.filterwarnings(action="ignore", category=RuntimeWarning) + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument("--cal_files", help="cal_files", nargs="*", type=str) + argparser.add_argument("--fft_files", help="fft_files", nargs="*", type=str) + argparser.add_argument( + "--tcm_filelist", help="tcm_filelist", nargs="*", type=str, required=False + ) + argparser.add_argument( + "--pulser_files", help="pulser_file", nargs="*", type=str, required=False + ) + argparser.add_argument( + "--overwrite_files", help="overwrite_files", nargs="*", type=str, required=False + ) + + argparser.add_argument("--configs", help="config", type=str, required=True) + argparser.add_argument("--datatype", help="Datatype", type=str, required=True) + argparser.add_argument("--timestamp", help="Timestamp", 
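# Sketch of the per-timestamp hit_pars payload written above, with placeholder
# values: each run keeps its own ecal results while the shared partition-level
# results are merged in under "partition_ecal".
run_results = {"ecal": {}}                  # per-run results (results_dicts[fk.timestamp])
partition_results = {}                      # partition-level results (ecal_results)
final_hit_dict = {
    "pars": {},                             # cal_dict[fk.timestamp] in the patch
    "results": dict(**run_results, partition_ecal=partition_results),
}
# -> {"pars": {}, "results": {"ecal": {}, "partition_ecal": {}}}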
type=str, required=True) + argparser.add_argument("--channel", help="Channel", type=str, required=True) + + argparser.add_argument("--log", help="log_file", type=str) + + argparser.add_argument("--plot_path", help="plot_path", type=str, nargs="*", required=False) + argparser.add_argument( + "--save_path", + help="save_path", + type=str, + nargs="*", + ) + args = argparser.parse_args() + + logging.basicConfig(level=logging.DEBUG, filename=args.log, filemode="w") + logging.getLogger("numba").setLevel(logging.INFO) + logging.getLogger("parse").setLevel(logging.INFO) + logging.getLogger("lgdo").setLevel(logging.INFO) + logging.getLogger("h5py").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + logging.getLogger("legendmeta").setLevel(logging.INFO) + + # get metadata dictionary + configs = LegendMetadata(path=args.configs) + channel_dict = configs.on(args.timestamp, system=args.datatype)["snakemake_rules"] + channel_dict = channel_dict["pars_pht_qc"]["inputs"]["qc_config"][args.channel] + + # sort files in dictionary where keys are first timestamp from run + if isinstance(args.cal_files, list): + cal_files = [] + for file in args.cal_files: + with open(file) as f: + cal_files += f.read().splitlines() + else: + with open(args.cal_files) as f: + cal_files = f.read().splitlines() + + cal_files = sorted( + np.unique(cal_files) + ) # need this as sometimes files get double counted as it somehow puts in the p%-* filelist and individual runs also + + kwarg_dict = Props.read_from(channel_dict) + + if args.overwrite_files: + overwrite = Props.read_from(args.overwrite_files)[args.channel]["pars"]["operations"] + else: + overwrite = None + + kwarg_dict_fft = kwarg_dict["fft_fields"] + if len(args.fft_files) > 0: + # sort files in dictionary where keys are first timestamp from run + if isinstance(args.fft_files, list): + fft_files = [] + for file in args.fft_files: + with open(file) as f: + fft_files += f.read().splitlines() + else: + with open(args.fft_files) as f: + fft_files = f.read().splitlines() + + fft_files = sorted( + np.unique(fft_files) + ) # need this as sometimes files get double counted as it somehow puts in the p%-* filelist and individual runs also + + if len(fft_files) > 0: + fft_fields = get_keys( + [ + key.replace(f"{args.channel}/dsp/", "") + for key in ls(fft_files[0], f"{args.channel}/dsp/") + ], + kwarg_dict_fft["cut_parameters"], + ) + + fft_data = load_data( + fft_files, + f"{args.channel}/dsp", + {}, + [*fft_fields, "timestamp", "trapTmax", "t_sat_lo"], + ) + + discharges = fft_data["t_sat_lo"] > 0 + discharge_timestamps = np.where(fft_data["timestamp"][discharges])[0] + is_recovering = np.full(len(fft_data), False, dtype=bool) + for tstamp in discharge_timestamps: + is_recovering = is_recovering | np.where( + ( + ((fft_data["timestamp"] - tstamp) < 0.01) + & ((fft_data["timestamp"] - tstamp) > 0) + ), + True, + False, + ) + fft_data["is_recovering"] = is_recovering + + hit_dict_fft = {} + plot_dict_fft = {} + cut_data = fft_data.query("is_recovering==0") + log.debug(f"cut_data shape: {len(cut_data)}") + for name, cut in kwarg_dict_fft["cut_parameters"].items(): + cut_dict, cut_plots = generate_cut_classifiers( + cut_data, + {name: cut}, + kwarg_dict.get("rounding", 4), + display=1 if args.plot_path else 0, + ) + hit_dict_fft.update(cut_dict) + plot_dict_fft.update(cut_plots) + + log.debug(f"{name} calculated cut_dict is: {json.dumps(cut_dict, indent=2)}") + + ct_mask = np.full(len(cut_data), True, dtype=bool) + for outname, info in 
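# Hedged sketch of the "is_recovering" veto built above: events falling within
# a short window (0.01, in the units of the timestamp column) after a
# discharge (t_sat_lo > 0) are flagged so they can be excluded from the cut
# tuning. This is a simplified reading of the patch's loop; the dataframe is
# illustrative only.
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "timestamp": [0.000, 0.003, 0.020, 0.500],   # hypothetical timestamps
    "t_sat_lo":  [5,     0,     0,     0],       # > 0 marks a discharge
})

discharge_timestamps = df["timestamp"][df["t_sat_lo"] > 0].to_numpy()
is_recovering = np.zeros(len(df), dtype=bool)
for tstamp in discharge_timestamps:
    dt = df["timestamp"] - tstamp
    is_recovering |= ((dt > 0) & (dt < 0.01)).to_numpy()
df["is_recovering"] = is_recovering               # [False, True, False, False]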
cut_dict.items(): + # convert to pandas eval + exp = info["expression"] + for key in info.get("parameters", None): + exp = re.sub(f"(? 0 + discharge_timestamps = np.where(data["timestamp"][discharges])[0] + is_recovering = np.full(len(data), False, dtype=bool) + for tstamp in discharge_timestamps: + is_recovering = is_recovering | np.where( + (((data["timestamp"] - tstamp) < 0.01) & ((data["timestamp"] - tstamp) > 0)), + True, + False, + ) + data["is_recovering"] = is_recovering + + rng = np.random.default_rng() + mask = np.full(len(data.query("~is_pulser & ~is_recovering")), False, dtype=bool) + mask[ + rng.choice( + len(data.query("~is_pulser & ~is_recovering")), + 2000 * len(args.cal_files), + replace=False, + ) + ] = True + + if "initial_cal_cuts" in kwarg_dict: + init_cal = kwarg_dict["initial_cal_cuts"] + hit_dict_init_cal, plot_dict_init_cal = generate_cut_classifiers( + data.query("~is_pulser")[mask], + init_cal["cut_parameters"], + init_cal.get("rounding", 4), + display=1 if args.plot_path else 0, + ) + ct_mask = np.full(len(data), True, dtype=bool) + for outname, info in hit_dict_init_cal.items(): + # convert to pandas eval + exp = info["expression"] + for key in info.get("parameters", None): + exp = re.sub(f"(? 1000) & ( + puls["trapTmax"].nda < 200 + ) + bl_mask = np.append(bl_mask, bl_idxs) + else: + with open(args.phy_files) as f: + phy_files = f.read().splitlines() + phy_files = sorted(np.unique(phy_files)) + bls = sto.read("ch1027200/dsp/", phy_files, field_mask=["wf_max", "bl_mean"])[0] + puls = sto.read("ch1027201/dsp/", phy_files, field_mask=["trapTmax"])[0] + bl_mask = ((bls["wf_max"].nda - bls["bl_mean"].nda) > 1000) & (puls["trapTmax"].nda < 200) + + kwarg_dict = Props.read_from(channel_dict) + kwarg_dict_fft = kwarg_dict["fft_fields"] + + cut_fields = get_keys( + [ + key.replace(f"{args.channel}/dsp/", "") + for key in ls(phy_files[0], f"{args.channel}/dsp/") + ], + kwarg_dict_fft["cut_parameters"], + ) + + data = sto.read( + f"{args.channel}/dsp/", + phy_files, + field_mask=[*cut_fields, "daqenergy", "t_sat_lo", "timestamp"], + idx=np.where(bl_mask)[0], + )[0].view_as("pd") + + discharges = data["t_sat_lo"] > 0 + discharge_timestamps = np.where(data["timestamp"][discharges])[0] + is_recovering = np.full(len(data), False, dtype=bool) + for tstamp in discharge_timestamps: + is_recovering = is_recovering | np.where( + (((data["timestamp"] - tstamp) < 0.01) & ((data["timestamp"] - tstamp) > 0)), + True, + False, + ) + data["is_recovering"] = is_recovering + + log.debug(f"{len(discharge_timestamps)} discharges found in {len(data)} events") + + hit_dict = {} + plot_dict = {} + cut_data = data.query("is_recovering==0") + log.debug(f"cut_data shape: {len(cut_data)}") + for name, cut in kwarg_dict_fft["cut_parameters"].items(): + cut_dict, cut_plots = generate_cut_classifiers( + cut_data, + {name: cut}, + kwarg_dict.get("rounding", 4), + display=1 if args.plot_path else 0, + ) + hit_dict.update(cut_dict) + plot_dict.update(cut_plots) + + log.debug(f"{name} calculated cut_dict is: {json.dumps(cut_dict, indent=2)}") + + ct_mask = np.full(len(cut_data), True, dtype=bool) + for outname, info in cut_dict.items(): + # convert to pandas eval + exp = info["expression"] + for key in info.get("parameters", None): + exp = re.sub(f"(?
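# Sketch of the forced-trigger event selection used for the physics-data QC
# above: an event is kept when the baseline channel (rawid 1027200) shows a
# waveform excursion well above its baseline mean while the pulser channel
# (rawid 1027201) stays below threshold. Thresholds and channel names follow
# the patch; the file list is a placeholder and LH5Store is assumed to come
# from lgdo.lh5.
import numpy as np
from lgdo.lh5 import LH5Store

sto = LH5Store()
phy_files = ["path/to/phy_dsp_file.lh5"]          # hypothetical dsp files

bls = sto.read("ch1027200/dsp/", phy_files, field_mask=["wf_max", "bl_mean"])[0]
puls = sto.read("ch1027201/dsp/", phy_files, field_mask=["trapTmax"])[0]
bl_mask = ((bls["wf_max"].nda - bls["bl_mean"].nda) > 1000) & (puls["trapTmax"].nda < 200)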