# Source code for ctlearn_manager.tri_model_collection

"""Plotting and predicting with a collection of CTLearnTriModelManager models."""

import os

import astropy.units as u
import ctadata
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from . import CTLearnTriModelManager

from .utils.utils import (
    ClusterConfiguration,
    Cuts,
    CutType,
    DefaultCuts,
    ParticleType,
    angular_distance,
    get_files_cscs,
    get_files_LST_cluster,
    set_mpl_style,
    CTLMDirectories,
    get_color
)

__all__ = ["TriModelCollection"]


class TriModelCollection:
    """
    Manage a collection of CTLearnTriModelManager models.

    The collection predicts data with the tri-model whose (zenith, azimuth)
    validity node is closest to the pointing of the data, and compares the
    DL2 performance of its members (energy/angular resolution, cuts).

    Attributes
    ----------
    tri_models : list[CTLearnTriModelManager]
        The tri-models managed by the collection.
    cluster_configuration : ClusterConfiguration
        Cluster configuration assigned to every tri-model of the collection.
    model_labels : list[str]
        Labels of the tri-models, used in plot legends.
    project_directories : CTLMDirectories
        Project directories of the first tri-model of the collection.
    tri_model_nicknames : list[str]
        Nicknames of the tri-models.
    allow_muliple_projects : bool
        Whether tri-models from different project directories may be mixed.

    Methods
    -------
    predict_lstchain_run(run, output_dir, DL1_data_dir=None, overwrite=False, plot=False, batch_size=64)
        Predict every subrun of an LST run and save the DL2 files.
    predict_lstchain_data(input_file, output_file, pointing_table=..., config_dir=None, overwrite=False, run=None, subrun=None, plot=False, batch_size=64)
        Predict one lstchain DL1 file with the closest tri-model.
    predict_data(input_file, output_file, pointing_table=..., config_dir=None, overwrite=False, plot=False)
        Predict one file with the closest tri-model.
    find_closest_model_to(input_file, pointing_table, plot=False, alt_key="alt_tel", az_key="az_tel", verbose=True)
        Return the tri-model closest (angular distance) to the file's mean pointing.
    plot_zenith_azimuth_ranges(plot_testing_nodes=True)
        Plot the zenith/azimuth validity ranges of all tri-models.
    plot_energy_resolution_DL2(cuts=..., zenith=None, azimuth=None, ylim=None, particle_type=..., figsize=None, plot_RF=False, compare_with=None, output_file=None)
        Plot the DL2 energy resolution, optionally relative to a reference.
    plot_angular_resolution_DL2(cuts=..., zenith=None, azimuth=None, ylim=None, particle_type=..., figsize=None, plot_RF=False, compare_with=None, output_file=None)
        Plot the DL2 angular resolution, optionally relative to a reference.
    plot_cuts(cuts=...)
        Plot the cuts of all tri-models.
    plot_everything_dl2(dl2_files, gammaness_cut=0.9, edep_cuts=False, pointing_table=...)
        Group DL2 files per closest tri-model and run ``plot_dl2`` for each group.
    get_tri_model_by_nickname(tri_model_nickname)
        Return the tri-model with the given nickname, or None.
    """

    # Per-metric settings shared by the energy and angular resolution plots:
    # tri-model getter/plotter method names, the sub-model providing the
    # fallback legend nickname, the RF benchmark HDU/column, whether the
    # theta efficiency is hidden from the cuts info box, and the savefig dpi.
    _RESOLUTION_METRICS = {
        "energy": {
            "getter": "get_energy_resolution_DL2",
            "plotter": "plot_energy_resolution_DL2",
            "sub_model": "energy_model",
            "rf_hdu": "ENERGY_BIAS_RESOLUTION",
            "rf_column": "resolution",
            "hide_theta_efficiency": False,
            "dpi": 300,
        },
        "angular": {
            "getter": "get_angular_resolution_DL2",
            "plotter": "plot_angular_resolution_DL2",
            "sub_model": "direction_model",
            "rf_hdu": "ANGULAR_RESOLUTION",
            "rf_column": "angular_resolution",
            "hide_theta_efficiency": True,
            "dpi": None,  # matplotlib's savefig default
        },
    }

    def __init__(
        self,
        tri_models: list[CTLearnTriModelManager],
        cluster_configuration: ClusterConfiguration = None,
        model_labels: list[str] = None,
        allow_muliple_projects=False,
    ):
        """
        Initialize the TriModelCollection object.

        Parameters
        ----------
        tri_models : list[CTLearnTriModelManager]
            Tri-models to include in the collection (at least one).
        cluster_configuration : ClusterConfiguration, optional
            Cluster configuration assigned to every tri-model of the
            collection. A new ``ClusterConfiguration()`` is created per
            collection when not given.
        model_labels : list of str, optional
            Labels for the tri-models. Defaults to "Model_{j}".
        allow_muliple_projects : bool, optional
            If False (default), all tri-models must share the same
            ``project_directories.project_directory``.

        Raises
        ------
        AssertionError
            If `tri_models` is empty, if the tri-models belong to different
            projects while `allow_muliple_projects` is False, if the length of
            `model_labels` does not match the number of tri-models, or if the
            tri-models do not all have the same ``stereo`` setting.
        """
        if cluster_configuration is None:
            # Create one per collection: a default instance built in the
            # signature would be shared (and mutated) by every collection.
            cluster_configuration = ClusterConfiguration()
        assert len(tri_models) > 0, "A TriModelCollection needs at least one tri model."

        self.tri_models: list[CTLearnTriModelManager] = tri_models
        self.cluster_configuration = cluster_configuration
        self.allow_muliple_projects = allow_muliple_projects

        if not self.allow_muliple_projects:
            project_directories = {
                tri_model.project_directories.project_directory
                for tri_model in self.tri_models
            }
            assert len(project_directories) == 1, (
                "All tri_models must be part of the same project_directory."
            )
        self.project_directories: CTLMDirectories = self.tri_models[0].project_directories

        for tri_model in self.tri_models:
            tri_model.cluster_configuration = cluster_configuration

        if model_labels is not None:
            assert len(model_labels) == len(self.tri_models), (
                "Model labels must be the same length as the number of tri models."
            )
            self.model_labels = model_labels
        else:
            self.model_labels = [f"Model_{j}" for j in range(len(self.tri_models))]

        stereos = {tri_model.stereo for tri_model in self.tri_models}
        assert len(stereos) == 1, "All stereos in the collection must be the same."

        self.tri_model_nicknames = [
            tri_model.project_directories.tri_model_nickname
            for tri_model in self.tri_models
        ]

    def predict_lstchain_run(
        self,
        run: int,
        output_dir: str,
        DL1_data_dir=None,
        overwrite=False,
        plot=False,
        batch_size=64,
    ):
        """
        Predict DL2 data for a given LST run using the cluster configuration.

        Parameters
        ----------
        run : int
            The run number to process.
        output_dir : str
            Directory where the output DL2 files are saved (created if needed).
        DL1_data_dir : str, optional
            Directory containing the DL1 input data. If None, a default path
            is used based on the cluster configuration.
        overwrite : bool, default=False
            Whether to overwrite existing DL2 files.
        plot : bool, default=False
            Whether to plot the pointing vs. model ranges when selecting the
            closest model for each subrun.
        batch_size : int, default=64
            Batch size to use during prediction.

        Raises
        ------
        ValueError
            If the cluster is neither 'cscs' nor 'lst-cluster'.
        RuntimeError
            On 'cscs', if the ``SCRATCH`` environment variable is not set.

        Notes
        -----
        - On 'cscs', DL1 files are fetched from dCache into the current
          directory with ``ctadata`` and moved to a scratch directory before
          prediction; subruns whose DL2 file already exists are skipped
          (unless `overwrite`) before any download.
        - On 'lst-cluster', DL1 files are read directly from `DL1_data_dir`.
        - One DL2 file ``LST-1.Run{run:05d}.{subrun:04d}.dl2.h5`` is written
          per subrun.
        """
        import shutil

        os.makedirs(output_dir, exist_ok=True)
        cluster = self.cluster_configuration.cluster

        if cluster == "cscs":
            if DL1_data_dir is None:
                DL1_data_dir = "/pnfs/cta.cscs.ch/lst/DL1/"
            dcache_files, version = get_files_cscs(run, DL1_data_dir)
            scratch_dir = os.getenv("SCRATCH")
            if scratch_dir is None:
                raise RuntimeError(
                    "The SCRATCH environment variable must be set to stage DL1 files from dCache on cscs."
                )
            scratch_dl1_dir = f"{scratch_dir}/ctlearn_manager_dl1_from_dcache/{run:05d}/{version}/tailcut84/"
            os.makedirs(scratch_dl1_dir, exist_ok=True)
            # ctadata saves fetched files in the current working directory.
            current_directory = os.getcwd()
            print(f"DL1 files will be copied to {scratch_dl1_dir}\n")
            for dcache_file in dcache_files:
                file_name = os.path.basename(dcache_file)
                input_file = f"{scratch_dl1_dir}/{file_name}"
                subrun = int(input_file.split(".")[-2])
                output_file = f"{output_dir}/LST-1.Run{run:05d}.{subrun:04d}.dl2.h5"
                # Check the output first so existing subruns are never downloaded.
                if os.path.exists(output_file) and not overwrite:
                    print(
                        f"⚠️ Output file already exists and overwrite is set to False : {output_file}"
                    )
                    continue
                if not os.path.exists(input_file):
                    print(f"⌛ Copying {dcache_file} to {scratch_dl1_dir}")
                    ctadata.fetch_and_save_file_or_dir(dcache_file)
                    shutil.move(os.path.join(current_directory, file_name), input_file)
                self.predict_lstchain_data(
                    input_file,
                    output_file,
                    config_dir=output_dir,
                    overwrite=overwrite,
                    run=run,
                    subrun=subrun,
                    plot=plot,
                    batch_size=batch_size,
                )
        elif cluster == "lst-cluster":
            if DL1_data_dir is None:
                DL1_data_dir = "/fefs/aswg/data/real/DL1/"
            for input_file in get_files_LST_cluster(run, DL1_data_dir):
                print(f"Predicting {input_file}")
                subrun = int(input_file.split(".")[-2])
                output_file = f"{output_dir}/LST-1.Run{run:05d}.{subrun:04d}.dl2.h5"
                self.predict_lstchain_data(
                    input_file,
                    output_file,
                    config_dir=output_dir,
                    overwrite=overwrite,
                    run=run,
                    subrun=subrun,
                    plot=plot,
                    batch_size=batch_size,
                )
        else:
            raise ValueError(
                f"To predict LST data run-wise, the cluster must be either 'cscs' or 'lst-cluster'. Current cluster : {self.cluster_configuration.cluster}"
            )

    def predict_lstchain_data(
        self,
        input_file,
        output_file,
        pointing_table="/dl1/event/telescope/parameters/LST_LSTCam",
        config_dir=None,
        overwrite=False,
        run=None,
        subrun=None,
        plot=False,
        batch_size=64,
    ):
        """
        Predict an lstchain DL1 file with the closest tri-model.

        Parameters
        ----------
        input_file : str
            Path to the input lstchain DL1 file.
        output_file : str
            Path of the DL2 file to write.
        pointing_table : str, optional
            Path of the pointing table inside the input file.
        config_dir : str, optional
            Directory for configuration files, forwarded to the tri-model.
        overwrite : bool, optional
            Whether to overwrite an existing output file. Default is False.
        run, subrun : int, optional
            Run and subrun numbers forwarded to the tri-model.
        plot : bool, optional
            Whether to plot the pointing vs. model ranges during model
            selection. Default is False.
        batch_size : int, optional
            Batch size to use during prediction. Default is 64.

        Returns
        -------
        None
            Returns early (without predicting) if the output exists and
            `overwrite` is False, or if no model could be selected.
        """
        # Check the output before reading the input's pointing: cheaper.
        if os.path.exists(output_file) and not overwrite:
            print(
                f"⚠️ Output file already exists and overwrite is set to False : {output_file}"
            )
            return
        closest_tri_model = self.find_closest_model_to(
            input_file, pointing_table, plot=plot
        )
        if closest_tri_model is None:
            return
        closest_tri_model.predict_lstchain_data(
            input_file,
            output_file,
            config_dir=config_dir,
            overwrite=overwrite,
            run=run,
            subrun=subrun,
            pointing_table=pointing_table,
            batch_size=batch_size,
        )

    def predict_data(
        self,
        input_file,
        output_file,
        pointing_table="dl0/monitoring/subarray/pointing",
        config_dir=None,
        overwrite=False,
        plot=False,
    ):
        """
        Predict a file with the closest tri-model.

        Parameters
        ----------
        input_file : str
            Path to the input file.
        output_file : str
            Path of the file where predictions are saved.
        pointing_table : str, optional
            Path of the pointing table inside the input file.
        config_dir : str, optional
            Directory for configuration files, forwarded to the tri-model.
        overwrite : bool, optional
            Whether to overwrite the output file, forwarded to the tri-model.
        plot : bool, optional
            Whether to plot the pointing vs. model ranges during model
            selection.

        Returns
        -------
        None
            Returns early if no model could be selected.
        """
        closest_tri_model = self.find_closest_model_to(
            input_file, pointing_table, plot=plot
        )
        if closest_tri_model is None:
            return
        closest_tri_model.predict_data(
            input_file,
            output_file,
            config_dir=config_dir,
            overwrite=overwrite,
            pointing_table=pointing_table,
        )

    def find_closest_model_to(
        self,
        input_file,
        pointing_table,
        plot=False,
        alt_key="alt_tel",
        az_key="az_tel",
        verbose=True,
    ):
        """
        Find the tri-model closest to the average pointing of a file.

        The average (zenith, azimuth) of the file is compared, by angular
        distance, with the centre of each tri-model's direction-model
        validity ranges.

        Parameters
        ----------
        input_file : str
            Path to the input file containing pointing data.
        pointing_table : str
            Path of the pointing table inside the input file.
        plot : bool, optional
            If True, plot the closest model's ranges and the average pointing
            on a polar plot. Default is False.
        alt_key, az_key : str, optional
            Column names of altitude and azimuth in the pointing table.
        verbose : bool, optional
            If True, print the file pointing and the selected model node.

        Returns
        -------
        CTLearnTriModelManager or None
            The closest tri-model, or None if the pointing could not be read
            (a warning is printed in that case).
        """
        from ctlearn_manager.utils.utils import get_avg_pointing

        try:
            avg_data_ze, avg_data_az = get_avg_pointing(
                input_file,
                pointing_table=pointing_table,
                alt_key=alt_key,
                az_key=az_key,
            )
        except Exception:
            print(f"⚠️ Pointing not found at {pointing_table}, skipping : {input_file}")
            return None

        # NOTE(review): the plain mean of an azimuth range is wrong for ranges
        # crossing 0°/360° — confirm how validity ranges are stored.
        avg_model_azs = np.array([
            np.mean(tri_model.direction_model.validity.azimuth_range).to_value(u.deg)
            for tri_model in self.tri_models
        ]) * u.deg
        avg_model_zes = np.array([
            np.mean(tri_model.direction_model.validity.zenith_range).to_value(u.deg)
            for tri_model in self.tri_models
        ]) * u.deg

        distances = angular_distance(avg_data_ze, avg_data_az, avg_model_zes, avg_model_azs)
        closest_model_index = int(np.argmin(distances))
        closest_model = self.tri_models[closest_model_index]

        if verbose:
            # All angles printed in degrees.
            print(
                f"📁 File : {os.path.basename(input_file)} 📡 Pointing : ({avg_data_ze.to_value(u.deg):.3f}, {avg_data_az.to_value(u.deg):.3f}) 🧠 Closest Model : ({avg_model_zes[closest_model_index].value:.3f}, {avg_model_azs[closest_model_index].value:.3f})"
            )
        if plot:
            fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
            closest_model.direction_model.plot_zenith_azimuth_ranges(ax)
            ax.scatter(
                avg_data_az.to(u.rad),
                avg_data_ze,
                label="Average pointing",
                color=get_color("ctlearn_highlight"),
            )
            ax.legend()
            plt.show()
        return closest_model

    def plot_zenith_azimuth_ranges(self, plot_testing_nodes=True):
        """
        Plot the zenith and azimuth ranges of all tri-models on one polar plot.

        Parameters
        ----------
        plot_testing_nodes : bool, optional
            Forwarded to each direction model; if True, testing nodes are
            included. Default is True.
        """
        fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
        for tri_model in self.tri_models:
            tri_model.direction_model.plot_zenith_azimuth_ranges(
                ax, plot_testing_nodes=plot_testing_nodes
            )
        plt.show()

    @u.quantity_input(zenith=u.deg, azimuth=u.deg)
    def plot_energy_resolution_DL2(
        self,
        cuts: Cuts = DefaultCuts.GH_0_9.value,
        zenith: u.Quantity = None,
        azimuth: u.Quantity = None,
        ylim=None,
        particle_type: ParticleType = ParticleType.GAMMA_POINT,
        figsize=None,
        plot_RF=False,
        compare_with: str = None,
        output_file=None,
    ):
        """
        Plot the DL2 energy resolution of every tri-model.

        Parameters
        ----------
        cuts : Cuts, optional
            Cuts to apply. Defaults to ``DefaultCuts.GH_0_9.value``.
        zenith, azimuth : astropy.units.Quantity, optional
            Pointing (angle quantities). When both are given, the resolution is
            computed at that pointing; they are required with `compare_with`.
        ylim : tuple, optional
            y-axis limits, forwarded to each tri-model.
        particle_type : ParticleType, optional
            Particle type to analyze. Defaults to ``ParticleType.GAMMA_POINT``.
        figsize : tuple, optional
            Figure size.
        plot_RF : bool, optional
            Add the packaged LST-1 Random Forest benchmark (only with
            efficiency-optimized cuts and a given `zenith`). Default False.
        compare_with : str, optional
            Label of the reference model (or "RF") for a relative-improvement
            panel below the main plot.
        output_file : str, optional
            If given, the figure is saved there (dpi=300) and closed;
            otherwise it is shown.

        Raises
        ------
        ValueError
            If `compare_with` is given without both `zenith` and `azimuth`, or
            does not name a model label nor an available RF benchmark.
        """
        self._plot_resolution_DL2(
            "energy",
            cuts=cuts,
            zenith=zenith,
            azimuth=azimuth,
            ylim=ylim,
            particle_type=particle_type,
            figsize=figsize,
            plot_RF=plot_RF,
            compare_with=compare_with,
            output_file=output_file,
        )

    def plot_angular_resolution_DL2(
        self,
        cuts: Cuts = DefaultCuts.GH_0_9.value,
        zenith: u.Quantity = None,
        azimuth: u.Quantity = None,
        ylim=None,
        particle_type: ParticleType = ParticleType.GAMMA_POINT,
        figsize=None,
        plot_RF=False,
        compare_with: str = None,
        output_file=None,
    ):
        """
        Plot the DL2 angular resolution of every tri-model.

        Parameters are the same as for `plot_energy_resolution_DL2`. The
        ``efficiency_theta`` of `cuts` is temporarily cleared while the cuts
        info box is drawn. When `output_file` is given, the figure is saved
        with matplotlib's default dpi and closed.

        Raises
        ------
        ValueError
            If `compare_with` is given without both `zenith` and `azimuth`, or
            does not name a model label nor an available RF benchmark.
        """
        self._plot_resolution_DL2(
            "angular",
            cuts=cuts,
            zenith=zenith,
            azimuth=azimuth,
            ylim=ylim,
            particle_type=particle_type,
            figsize=figsize,
            plot_RF=plot_RF,
            compare_with=compare_with,
            output_file=output_file,
        )

    def _plot_resolution_DL2(
        self,
        metric,
        cuts,
        zenith,
        azimuth,
        ylim,
        particle_type,
        figsize,
        plot_RF,
        compare_with,
        output_file,
    ):
        """Shared implementation of the energy/angular resolution plots; `metric` keys `_RESOLUTION_METRICS`."""
        settings = self._RESOLUTION_METRICS[metric]
        has_pointing = zenith is not None and azimuth is not None
        rf_requested = (
            plot_RF
            and cuts.cut_type == CutType.EFFICIENCY_OPTIMIZED
            and zenith is not None
        )

        # Validate before creating the figure so no half-drawn figure is left.
        if compare_with is not None:
            if not has_pointing:
                raise ValueError(
                    "If you want to compare with another model, you need to provide zenith and azimuth."
                )
            if compare_with not in self.model_labels and not (
                compare_with == "RF" and rf_requested
            ):
                raise ValueError(
                    f"Cannot compare with '{compare_with}': it must be one of the model labels "
                    f"{self.model_labels}, or 'RF' with plot_RF=True and efficiency-optimized cuts."
                )

        fig, ax, ax_rel = self._make_resolution_figure(
            with_relative_panel=compare_with is not None, figsize=figsize
        )
        self._plot_cuts_info(
            cuts, ax, hide_theta_efficiency=settings["hide_theta_efficiency"]
        )

        rf_curve = None
        if rf_requested:
            rf_energy, rf_values, rf_label = self._load_rf_benchmark(
                cuts, zenith, settings["rf_hdu"], settings["rf_column"]
            )
            ax.plot(rf_energy, rf_values, label=rf_label, color=get_color("on_background"), zorder=0)
            rf_curve = (rf_energy, rf_values)

        if has_pointing:
            zeniths = np.array([zenith.value]) * zenith.unit
            azimuths = np.array([azimuth.value]) * azimuth.unit
            self._annotate_pointing(ax, zenith, azimuth)
            if compare_with is not None:
                self._plot_relative_improvement(
                    fig, ax, ax_rel, metric, compare_with, rf_curve,
                    zenith, azimuth, cuts, particle_type,
                )
        else:
            zeniths = None
            azimuths = None

        for tri_model, label in tqdm(
            zip(self.tri_models, self.model_labels),
            desc=f"Plotting {metric} resolution",
            unit="model",
            total=len(self.tri_models),
        ):
            if label is None:
                label = getattr(tri_model, settings["sub_model"]).model_nickname
            getattr(tri_model, settings["plotter"])(
                zeniths=zeniths,
                azimuths=azimuths,
                cuts=[cuts],
                ylim=ylim,
                particle_type=particle_type,
                ax=ax,
                figsize=figsize,
                label=label,
            )

        if ax_rel is not None:
            ax_rel.set_xlim(ax.get_xlim())
            ax_rel.set_ylim(bottom=0)
        ax.legend()
        plt.tight_layout()
        plt.subplots_adjust(hspace=0.0)
        if output_file is not None:
            plt.savefig(output_file, dpi=settings["dpi"])
            plt.close(fig)  # release the figure: callers never receive it
        else:
            plt.show()

    def _plot_relative_improvement(
        self, fig, ax, ax_rel, metric, compare_with, rf_curve,
        zenith, azimuth, cuts, particle_type,
    ):
        """Plot on `ax_rel` the improvement (%) of every tri-model over the reference `compare_with`."""
        getter = self._RESOLUTION_METRICS[metric]["getter"]
        # Hide the main panel's x axis: the relative panel below carries it.
        ax.set_xticks([])
        ax.set_xlabel("")
        fig.subplots_adjust(hspace=0)

        if compare_with in self.model_labels:
            reference = self.tri_models[self.model_labels.index(compare_with)]
            ref_bins, ref_res_err = getattr(reference, getter)(
                zenith=zenith, azimuth=azimuth, cuts=cuts, particle_type=particle_type
            )
            ref_e = self._bin_centers(ref_bins)
            ref_res = self._as_plain_array([res[0] for res in ref_res_err])
        else:
            # compare_with == "RF" with the benchmark loaded (checked by caller).
            ref_e = self._as_plain_array(rf_curve[0])
            ref_res = self._as_plain_array(rf_curve[1])

        ax_rel.plot(
            ref_e,
            [0] * len(ref_e),
            label=f"{compare_with} vs {compare_with}",
            color=get_color("on_background"),
            zorder=0,
        )
        for tri_model, label in tqdm(
            zip(self.tri_models, self.model_labels),
            desc=f"Plotting {metric} resolution improvment",
            unit="model",
            total=len(self.tri_models),
        ):
            try:
                e_bins, res_err = getattr(tri_model, getter)(
                    zenith=zenith, azimuth=azimuth, cuts=cuts, particle_type=particle_type
                )
            except Exception as err:
                print(f"⚠️ Skipping {label} in the relative improvement panel: {err}")
                continue
            e = self._bin_centers(e_bins)
            res = self._as_plain_array([r[0] for r in res_err])
            # Interpolate the reference onto this model's energies if needed.
            ref_interp = ref_res if np.array_equal(e, ref_e) else np.interp(e, ref_e, ref_res)
            relative_improvement = 100 * (ref_interp - res) / ref_interp
            ax_rel.plot(e, relative_improvement, label=f"{label} vs {compare_with}")

    @staticmethod
    def _make_resolution_figure(with_relative_panel, figsize=None):
        """Create the figure; returns (fig, ax, ax_rel) with ax_rel None without a relative panel."""
        if not with_relative_panel:
            fig, ax = plt.subplots(figsize=figsize)
            return fig, ax, None
        fig, (ax, ax_rel) = plt.subplots(
            2, 1, gridspec_kw={"height_ratios": [3, 1]}, figsize=figsize
        )
        ax_rel.set_xlabel("True Energy (TeV)")
        ax_rel.set_ylabel("Rel. Impr. (%)")
        ax_rel.grid(True, linestyle="--", alpha=0.5)
        ax_rel.set_xscale("log")
        ax_rel.set_ymargin(0.05)
        ax_rel.set_yticks([0, 10, 20, 30, 40, 50])
        return fig, ax, ax_rel

    @staticmethod
    def _plot_cuts_info(cuts, ax, hide_theta_efficiency=False):
        """Draw the cuts info box, optionally with `efficiency_theta` cleared (restored afterwards)."""
        if not hide_theta_efficiency:
            cuts.plot_cuts_info_plt(ax)
            return
        stored_efficiency_theta = cuts.efficiency_theta
        cuts.efficiency_theta = None
        try:
            cuts.plot_cuts_info_plt(ax)
        finally:
            # `cuts` may be a shared default instance: always restore it.
            cuts.efficiency_theta = stored_efficiency_theta

    @staticmethod
    def _annotate_pointing(ax, zenith, azimuth):
        """Write the pointing (degrees) in the lower-left corner of `ax`."""
        ax.text(
            0.02,
            0.02,
            f"Pointing: ({zenith.to_value(u.deg):.1f}, {azimuth.to_value(u.deg):.1f})°",
            transform=ax.transAxes,
            fontsize=9,
            color=get_color("ctlearn_accent_2"),
            verticalalignment="bottom",
            horizontalalignment="left",
            bbox=dict(
                boxstyle="round,pad=0.3",
                edgecolor="none",
                facecolor=get_color("ctlearn_accent_1"),
                alpha=0.2,
            ),
        )

    @staticmethod
    def _load_rf_benchmark(cuts, zenith, hdu_name, column):
        """
        Load an LST-1 Random Forest benchmark curve packaged with ctlearn_manager.

        The IRF file of the tabulated zenith closest to `zenith` and matching
        ``cuts.efficiency_gammaness`` is read. Returns
        ``(true_energy_center, values, label)``; the label is "RF" when the
        zenith matches a tabulated one, else "RF <zenith>°".
        """
        import importlib
        import importlib.resources as pkg_resources

        from astropy.io import fits

        rf_benchmark = importlib.import_module("ctlearn_manager.resources.irfs.LST1")
        available_zeniths = [10.00, 23.63, 32.06, 43.20]
        zenith_deg = zenith.to_value(u.deg)
        closest_zenith = min(available_zeniths, key=lambda z: abs(z - zenith_deg))
        irf_name = f"irfs_zen_{closest_zenith:.2f}_gh-eff_{cuts.efficiency_gammaness}.fits.gz"
        with pkg_resources.path(rf_benchmark, irf_name) as irf_file, fits.open(irf_file) as hdul:
            # Copy the columns so they remain valid after the file is closed.
            rf_energy = np.array(hdul[hdu_name].data["true_energy_center"])
            rf_values = np.array(hdul[hdu_name].data[column])
        if f"{zenith_deg:.2f}" == f"{closest_zenith:.2f}":
            label = "RF"
        else:
            label = f"RF {closest_zenith:.1f}°"
        return rf_energy, rf_values, label

    @staticmethod
    def _as_plain_array(values):
        """Return `values` as a float ndarray, dropping astropy units if present."""
        if isinstance(values, u.Quantity):
            return np.asarray(values.value, dtype=float)
        return np.asarray([getattr(v, "value", v) for v in values], dtype=float)

    @classmethod
    def _bin_centers(cls, bin_edges):
        """Return the (unitless) centres of consecutive bin edges."""
        edges = cls._as_plain_array(bin_edges)
        return (edges[:-1] + edges[1:]) / 2

    def plot_cuts(self, cuts: Cuts = DefaultCuts.EFF_70.value):
        """
        Plot the cuts of every tri-model on a two-panel figure and show it.

        Parameters
        ----------
        cuts : Cuts, optional
            Cuts to plot. Defaults to ``DefaultCuts.EFF_70.value``.
        """
        fig, axs = plt.subplots(1, 2, figsize=(10, 4))
        for ax in axs:
            cuts.plot_cuts_info_plt(ax)
        for tri_model, label in tqdm(
            zip(self.tri_models, self.model_labels),
            desc="Plotting cuts",
            unit="model",
            total=len(self.tri_models),
        ):
            if label is None:
                label = tri_model.direction_model.model_nickname
            tri_model.plot_cuts(cuts=[cuts], axs=axs, label=label)
        for ax in axs:
            ax.legend()
        plt.tight_layout()
        plt.show()

    def plot_everything_dl2(
        self,
        dl2_files: list[str],
        gammaness_cut: float = 0.9,
        edep_cuts: bool = False,
        pointing_table: str = "/dl1/monitoring/telescope/pointing/tel_001",
    ):
        """
        Produce DL2 performance plots per tri-model with the ``plot_dl2`` tool.

        Each DL2 file is assigned to its closest tri-model (pointing columns
        "altitude"/"azimuth"); files whose pointing cannot be read are
        skipped. For every tri-model with files, the tri-model is pickled into
        its ``dl2_post_processed_data_directory`` and ``plot_dl2`` is run there,
        through ``sbatch`` if the cluster configuration uses a cluster, locally
        otherwise. With more than one tri-model, a bar chart of files per model
        is saved to ``dl2_files_per_model.png`` and shown.

        Parameters
        ----------
        dl2_files : list[str]
            DL2 files to process.
        gammaness_cut : float, optional
            Gammaness cut passed to ``plot_dl2``. Default is 0.9.
        edep_cuts : bool, optional
            Value passed to ``plot_dl2 --edep_cuts``. Default is False.
        pointing_table : str, optional
            Path of the pointing table inside the DL2 files.

        Returns
        -------
        None
        """
        import concurrent.futures
        import pickle
        import shlex
        import subprocess

        def assign_model(dl2_file):
            closest_tri_model = self.find_closest_model_to(
                dl2_file,
                pointing_table=pointing_table,
                plot=False,
                alt_key="altitude",
                az_key="azimuth",
                verbose=False,
            )
            return closest_tri_model, dl2_file

        # Reading the pointing is I/O bound: overlap it with threads.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            results = list(tqdm(
                executor.map(assign_model, dl2_files),
                total=len(dl2_files),
                desc="Grouping DL2 files per model",
                unit="file",
            ))

        grouped_files = {tri_model: [] for tri_model in self.tri_models}
        for closest_tri_model, dl2_file in results:
            if closest_tri_model is not None:
                grouped_files[closest_tri_model].append(dl2_file)
        grouped_files = {model: files for model, files in grouped_files.items() if files}

        file_counts = []
        nicknames = []
        for tri_model, files in grouped_files.items():
            output_directory = tri_model.project_directories.dl2_post_processed_data_directory
            file_counts.append(len(files))
            nicknames.append(tri_model.project_directories.tri_model_nickname)
            print(
                f"Processing {len(files)} files 🧠🧠🧠 CTLearnTriModelManager ▮ {tri_model.direction_model.model_nickname}{tri_model.energy_model.model_nickname}{tri_model.type_model.model_nickname} ▮"
            )
            tri_model_file = f"{output_directory}/tri_model_{tri_model.direction_model.model_nickname}.pkl"
            tri_model.dl2_data_files = files
            os.makedirs(output_directory, exist_ok=True)

            # Pickle with use_cluster disabled: DL2 files not yet processed are
            # then processed inside the plotting job instead of submitting new
            # jobs. Always restore the (shared) configuration afterwards.
            use_cluster = tri_model.cluster_configuration.use_cluster
            tri_model.cluster_configuration.use_cluster = False
            try:
                with open(tri_model_file, "wb") as f:
                    pickle.dump(tri_model, f)
            finally:
                tri_model.cluster_configuration.use_cluster = use_cluster

            plot_args = [
                "plot_dl2",
                "--stereo_tri_model", tri_model_file,
                "--output_directory", output_directory,
                "--gammaness_cut", str(gammaness_cut),
                f"--edep_cuts={edep_cuts}",
            ]
            # Shell-quoted form for the sbatch script and the log.
            cmd = shlex.join(plot_args)
            print(cmd)
            sbatch_file = tri_model.cluster_configuration.write_sbatch_script(
                f"dl2_plots_{tri_model.direction_model.model_nickname}",
                cmd,
                output_directory,
                use_gpu_cscs=False,
            )
            if self.cluster_configuration.use_cluster:
                subprocess.run(["sbatch", str(sbatch_file)], check=False)
            else:
                subprocess.run(plot_args, check=False)

        if len(self.tri_models) > 1:
            plt.figure()
            plt.bar(nicknames, file_counts)
            plt.xlabel("CTLearn TriModel")
            plt.ylabel("Number of DL2 files")
            plt.xticks(rotation=45, ha='right')
            plt.tight_layout()
            plt.savefig("dl2_files_per_model.png")
            plt.show()

    def get_tri_model_by_nickname(self, tri_model_nickname):
        """
        Get a tri-model by its nickname.

        Parameters
        ----------
        tri_model_nickname : str
            Nickname (``project_directories.tri_model_nickname``) to look up.

        Returns
        -------
        CTLearnTriModelManager or None
            The first tri-model with that nickname, or None if not found.
        """
        return next(
            (
                tri_model
                for tri_model in self.tri_models
                if tri_model.project_directories.tri_model_nickname == tri_model_nickname
            ),
            None,
        )