Source code for ctlearn_manager.model_manager

"""
CTLearnModelManager Class.

This class is designed to manage CTLearn models, providing functionalities for initializing, saving, loading, training, and updating model parameters. It also includes methods for handling training data, testing data, DL2 data, and IRF data.
"""

import ast
from pathlib import Path

import astropy.units as u
import numpy as np
from astropy.table import QTable

from ctlearn_manager.utils.utils import (
    ClusterConfiguration,
    Cuts,
    DataSample,
    IRFType,
    ParticleType,
    get_irf_type_from_config,
    remove_row_from_table_utils,
    set_mpl_style,
    CTLMDirectories,
    get_color,
)

# from ctlearn_manager.utils.index_tables import IndexTables

# Public names exported by ``from ctlearn_manager.model_manager import *``.
# ``DataSample`` is re-exported from ``ctlearn_manager.utils.utils``.
# ``ModelRangeOfValidity`` is not imported above; presumably it is defined
# further down in this module — TODO confirm.
__all__ = [
    "CTLearnModelManager",
    "DataSample",
    "ModelRangeOfValidity",
]


[docs] class CTLearnModelManager: """ CTLearnModelManager class for managing CTLearn models. This class provides methods for initializing, saving, loading, and training CTLearn models. It also includes methods for updating and retrieving model parameters, training data, testing data, DL2 data, and IRF data. Attributes ---------- model_index_file : str Path to the model index file. model_nickname : str Nickname of the model. model_parameters_table : astropy.table.Table Table containing model parameters. validity : ModelRangeOfValidity Range of validity for the model. stereo : bool Indicates if the model uses stereo mode. telescope_ids : list List of telescope IDs used in the model. telescope_names : list List of telescope names used in the model. cluster_configuration : ClusterConfiguration Configuration for the cluster. Methods ------- __init__(model_parameters, MODEL_INDEX_FILE, load=False, cluster_configuration=ClusterConfiguration()) Initializes the ModelManager instance. save_to_index(model_parameters) Save model parameters and training samples to an HDF5 index file. launch_training(n_epochs, transfer_learning_model_cpk=None, frozen_backbone=False, config_file=None) Launches the training process for the model. get_n_epoch_trained() Calculate the total number of epochs trained by summing the lengths of all training logs. plot_loss() Plots the training and validation loss over epochs. update_model_manager_parameters_in_index(parameters) Update the model manager parameters in the HDF5 index file. update_model_manager_testing_data(testing_gamma_dirs, testing_proton_dirs, testing_gamma_zenith_distances, testing_gamma_azimuths, testing_proton_zenith_distances, testing_proton_azimuths, testing_gamma_patterns, testing_proton_patterns) Update the model manager's testing data for gamma and proton events. 
update_model_manager_DL2_MC_files(testing_DL2_gamma_files, testing_DL2_proton_files, testing_DL2_gamma_zenith_distances, testing_DL2_gamma_azimuths, testing_DL2_proton_zenith_distances, testing_DL2_proton_azimuths) Update the DL2 MC files for gamma and proton testing data in the model manager. update_model_manager_DL2_data_files(DL2_files, DL2_zenith_distances, DL2_azimuths) Update the DL2 data files for the model manager. update_merged_DL2_MC_files(testing_DL2_zenith_distance, testing_DL2_azimuth, testing_DL2_gamma_merged_file=None, testing_DL2_proton_merged_file=None) Update the merged DL2 MC files for gamma and proton data. update_model_manager_IRF_data(config, cuts_file, irf_file, bencmark_file, zenith, azimuth) Update the IRF (Instrument Response Function) data for the model manager. get_IRF_data(zenith, azimuth) Retrieve the Instrument Response Function (IRF) data for a given zenith and azimuth. get_closest_IRF_data(zenith, azimuth) Retrieve the closest Instrument Response Function (IRF) data based on the given zenith and azimuth angles. get_DL2_MC_files(zenith, azimuth) Retrieve DL2 Monte Carlo (MC) files for given zenith and azimuth angles. plot_zenith_azimuth_ranges(ax=None) Plot the zenith and azimuth ranges on a polar plot. plot_training_nodes() Plot the training nodes for gamma and proton events on a polar plot. """
def __init__(
    self,
    model_parameters,
    project_directories: CTLMDirectories,
    load=False,
    cluster_configuration=None,
):
    """
    Initialize the model manager and load its parameters from the model index.

    Parameters
    ----------
    model_parameters : dict
        Parameters of the model. The key ``"model_nickname"`` selects the
        model (default ``"new_model"``). When ``load`` is False the whole
        dictionary is forwarded to :meth:`save_to_index`.
    project_directories : CTLMDirectories
        Project directories. ``project_directories.model_index_file`` is the
        HDF5 model index the parameters and training tables are read from.
    load : bool, optional
        If True, the model is only loaded from the index file. If False
        (default), :meth:`save_to_index` is called first to register it.
    cluster_configuration : ClusterConfiguration, optional
        Configuration object for the cluster. If None (default), a new
        ``ClusterConfiguration()`` is created for this instance, so managers
        never share one default object.

    Raises
    ------
    ValueError
        If the model is of type "type" and the training patterns for gamma
        diffuse or proton are missing.
    ValueError
        If stereo mode is enabled but fewer than 2 telescopes are provided.
    ValueError
        If the reconstruction is "cameradirection" and stereo mode is
        enabled, as this combination is not supported.
    """
    from astropy.io.misc.hdf5 import read_table_hdf5

    self.model_nickname = model_parameters.get("model_nickname", "new_model")
    self.project_directories = project_directories
    if not load:
        self.save_to_index(model_parameters)
    print(f"🧠 Model name: {self.model_nickname}")

    index_file = f"{self.project_directories.model_index_file}"
    self.model_parameters_table = read_table_hdf5(
        index_file,
        path=IndexTables(self).PARAMETERS.table_path,
    )
    self.validity = ModelRangeOfValidity(self)

    # Telescope lists are stored in the index as the ``str()`` of a list.
    self.telescope_ids = ast.literal_eval(
        self.model_parameters_table["telescope_ids"][0]
    )
    self.telescope_names = ast.literal_eval(
        self.model_parameters_table["telescope_names"][0]
    )
    try:
        self.min_telescopes = self.model_parameters_table["min_telescopes"][0]
    except KeyError:
        # Index without a "min_telescopes" column: require every telescope.
        self.min_telescopes = len(self.telescope_ids)
    self.stereo = self.model_parameters_table["stereo"][0]

    reco = self.model_parameters_table["reco"][0]
    # The gamma-diffuse training table is read for every reconstruction
    # type, so a model without it fails here rather than at training time.
    training_table_gamma = read_table_hdf5(
        index_file,
        path=IndexTables(self, ParticleType.GAMMA_DIFFUSE).TRAINING.table_path,
    )
    if reco == "type":
        training_table_proton = read_table_hdf5(
            index_file,
            path=IndexTables(self, ParticleType.PROTON).TRAINING.table_path,
        )
        if (
            len(training_table_proton["training_proton_patterns"]) == 0
            or len(training_table_gamma["training_gamma_diffuse_patterns"]) == 0
        ):
            raise ValueError(
                "For reco type, training_proton_patterns and training_gamma_diffuse_patterns are required"
            )
    if self.stereo and len(self.telescope_ids) < 2:
        raise ValueError("For stereo mode, at least 2 telescopes are required")
    if reco == "cameradirection" and self.stereo:
        raise ValueError(
            "For reco cameradirection, stereo mode is not supported, use skydirection instead."
        )

    # NOTE: two pieces of previously commented-out code were dropped here:
    # the per-particle check that pattern/zenith/azimuth lists have equal
    # length, and the legacy migration that moved "<nickname>_v*"
    # directories into a per-model sub-directory.
    self.cluster_configuration = (
        ClusterConfiguration()
        if cluster_configuration is None
        else cluster_configuration
    )
def save_to_index(self, model_parameters):
    """
    Save model parameters and training samples to the HDF5 index file.

    If a parameters table for this model nickname can already be read from
    the index, a message is printed and nothing is written.

    Parameters
    ----------
    model_parameters : dict
        Model parameters and training sample details. Recognized keys:

        - "model_dir" (str): Absolute path to the model directory. The
          model nickname is appended to it before storing.
        - "reco" (str): One of ['type', 'energy', 'cameradirection',
          'skydirection']. The fallback "default_reco" is rejected by the
          validation, so the key is effectively required.
        - "channels" (list of str, optional): Defaults to
          ["cleaned_image", "cleaned_relative_peak_time"].
        - "telescope_names" (list of str, optional): Defaults to [].
        - "telescope_ids" (list of int, optional): Defaults to [].
        - "notes" (str, optional): Defaults to "".
        - "max_training_epochs" (int, optional): Defaults to 10.
        - "min_telescopes" (int, optional): Defaults to 1.
        - "stereo" (bool, optional): Defaults to ``min_telescopes >= 2``.
        - "training_samples" (list, optional): Objects exposing
          ``particle_type``, ``directory``, ``pattern``,
          ``zenith_distance``, ``azimuth``, ``energy_range`` and
          ``nsb_range``. The minimum and maximum of each range are stored.

    Raises
    ------
    ValueError
        If "model_dir" is not an absolute path.
    AssertionError
        If telescope names and IDs differ in length; if ``telescope_ids``,
        ``telescope_names`` or ``channels`` is not 1-dimensional; if
        ``reco`` is not one of the allowed values; or if
        ``max_training_epochs`` or ``min_telescopes`` is not an ``int``.

    Notes
    -----
    Training samples are appended to a per-particle-type training table,
    which is created from its default table when it cannot be read.
    """
    from astropy.io.misc.hdf5 import read_table_hdf5, write_table_hdf5

    def _require(condition, message):
        # Explicit raise instead of ``assert`` so validation survives
        # ``python -O``; the exception type callers see is unchanged.
        if not condition:
            raise AssertionError(message)

    index_file = self.project_directories.model_index_file
    parameters_index_table = IndexTables(self).PARAMETERS

    try:
        QTable.read(
            index_file,
            format="hdf5",
            path=parameters_index_table.table_path,
        )
    except (OSError, KeyError):
        # Missing index file or missing table path: not registered yet.
        pass
    else:
        print(f"❌ Model nickname {self.model_nickname} already in table")
        return

    model_table = parameters_index_table.default_table

    notes = model_parameters.get("notes", "")
    base_dir = model_parameters.get("model_dir", "")
    if not Path(base_dir).is_absolute():
        raise ValueError("The 'model_dir' must be an absolute path.")
    model_dir = f"{base_dir}/{self.model_nickname}"
    reco = model_parameters.get("reco", "default_reco")
    telescope_names = model_parameters.get("telescope_names", [])
    telescope_ids = model_parameters.get("telescope_ids", [])
    channels = model_parameters.get(
        "channels", ["cleaned_image", "cleaned_relative_peak_time"]
    )
    max_training_epochs = model_parameters.get("max_training_epochs", 10)
    min_telescopes = model_parameters.get("min_telescopes", 1)

    _require(
        len(telescope_names) == len(telescope_ids),
        "Telescope names and IDs must have the same length",
    )
    _require(
        np.ndim(telescope_ids) == 1,
        "telescope_ids must be a 1-dimensional array",
    )
    _require(
        np.ndim(telescope_names) == 1,
        "telescope_names must be a 1-dimensional array",
    )
    _require(np.ndim(channels) == 1, "channels must be a 1-dimensional array")
    _require(
        reco in ["type", "energy", "cameradirection", "skydirection"],
        "reco must be one of ['type', 'energy', 'cameradirection', 'skydirection']",
    )
    _require(
        type(max_training_epochs) is int,
        "max_training_epochs must be an integer",
    )
    _require(type(min_telescopes) is int, "min_telescopes must be an integer")

    # Derived only after min_telescopes is known to be an int, so a bad
    # value yields the validation error and not a TypeError from ``>=``.
    stereo = model_parameters.get("stereo", min_telescopes >= 2)

    model_table.add_row(
        [
            self.model_nickname,
            model_dir,
            reco,
            str(channels),
            str(telescope_names),
            str(telescope_ids),
            notes,
            max_training_epochs,
            min_telescopes,
            stereo,
        ]
    )
    write_table_hdf5(
        model_table,
        index_file,
        path=parameters_index_table.table_path,
        append=True,
        overwrite=True,
    )

    for training_sample in model_parameters.get("training_samples", []):
        training_index_table = IndexTables(
            self, training_sample.particle_type
        ).TRAINING
        try:
            training_table = read_table_hdf5(
                index_file,
                path=training_index_table.table_path,
            )
        except (OSError, KeyError):
            # No training table for this particle type yet.
            training_table = training_index_table.default_table
        training_table.add_row(
            [
                training_sample.directory,
                training_sample.pattern,
                training_sample.zenith_distance,
                training_sample.azimuth,
                min(training_sample.energy_range),
                max(training_sample.energy_range),
                min(training_sample.nsb_range),
                max(training_sample.nsb_range),
            ]
        )
        write_table_hdf5(
            training_table,
            index_file,
            path=training_index_table.table_path,
            append=True,
            overwrite=True,
            serialize_meta=True,
        )
    print(f"✅ Model nickname {self.model_nickname} added to table")
def launch_training(
    self,
    n_epochs,
    save_best_validation_only=None,
    transfer_learning_model_cpk=None,
    trainable_backbone=True,
    force_dl1_lookup=False,
    config_file=None,
    batch_size=64,
):
    """
    Launch the training process for the model.

    Parameters
    ----------
    n_epochs : int
        Number of epochs to train. If 0, nothing is done. If it exceeds the
        remaining budget (``max_training_epochs`` minus the epochs already
        trained), ``max_training_epochs`` is overwritten with ``n_epochs``
        in the index and the run is shortened to
        ``n_epochs - epochs already trained``.
    save_best_validation_only : bool, optional
        Value written to the generated configuration. None (default) means
        True.
    transfer_learning_model_cpk : str, optional
        Checkpoint to start from. Only used when no previous version of
        this model is loaded.
    trainable_backbone : bool, default=True
        Written to ``LoadedModel.trainable_backbone`` in the configuration.
    force_dl1_lookup : bool, default=False
        Written to ``DLImageReader.force_dl1_lookup`` in the configuration.
    config_file : str, optional
        Path to an existing configuration file. If None, one is generated
        in the model base directory.
    batch_size : int, default=64
        Passed as ``--TrainCTLearnModel.batch_size``.

    Returns
    -------
    None

    Notes
    -----
    - Training is skipped (and the loss plotted) when the epochs already
      trained reach ``max_training_epochs``.
    - Model versions live in ``<model_dir>/<nickname>_v<N>``. The latest
      version is the one with the highest numeric ``N``. A new version is
      started when the latest directory holds more than 1e6 bytes;
      otherwise that directory is reused.
    - With a cluster configured the command is submitted through
      ``sbatch``; otherwise it is printed and run locally.
    """
    import glob
    import json
    import os
    import re

    from astropy.io.misc.hdf5 import read_table_hdf5

    self.cluster_configuration.info()
    if n_epochs == 0:
        print("🛑 Number of epochs set to 0. Will not train the model.")
        return

    n_epoch_trained = self.get_n_epoch_trained()
    max_training_epochs = self.model_parameters_table["max_training_epochs"][0]
    base_model_dir = self.model_parameters_table["model_dir"][0]
    reco = self.model_parameters_table["reco"][0]

    if n_epochs > max_training_epochs - n_epoch_trained:
        print(
            f"⚠️ Number of epochs increased from {max_training_epochs} to {n_epochs}"
        )
        self.update_model_manager_parameters_in_index(
            {"max_training_epochs": n_epochs}
        )
        max_training_epochs = n_epochs
        n_epochs = max_training_epochs - n_epoch_trained
    if n_epoch_trained >= max_training_epochs:
        print(
            f"🛑 Model already trained for {n_epoch_trained} epochs. Will not train further."
        )
        self.plot_loss()
        return

    # Text progress bar: epochs done, epochs left, and the span of this run.
    trained_string = "―" * (n_epoch_trained - 1)
    trained_spaces = " " * (n_epoch_trained - 1)
    remaining_string = "·" * (max_training_epochs - n_epoch_trained)
    to_train_string = "―" * (n_epochs - len(str(n_epoch_trained)))
    print(
        f"{trained_string}o{remaining_string} | {n_epoch_trained}/{max_training_epochs} epochs"
    )
    print(
        f"{trained_spaces}{n_epoch_trained}{to_train_string}> 🚀 Training for {n_epochs} epochs"
    )

    # Order version directories by their numeric suffix. A plain string
    # sort puts "_v10" before "_v2" and would pick the wrong latest version.
    version_suffix = re.compile(r"_v(\d+)$")
    versioned_dirs = []
    for candidate in glob.glob(f"{base_model_dir}/{self.model_nickname}_v*"):
        found = version_suffix.search(candidate)
        if found:
            versioned_dirs.append((int(found.group(1)), candidate))
    versioned_dirs.sort()

    model_to_load = None
    if versioned_dirs:
        model_version, last_model_dir = versioned_dirs[-1]
        size = sum(
            f.stat().st_size
            for f in Path(last_model_dir).glob("**/*")
            if f.is_file()
        )
        if size > 1e6:
            # The latest version holds real output: start a new one.
            model_version += 1
        if model_version > 0:
            model_to_load = f"{base_model_dir}/{self.model_nickname}_v{model_version - 1}/ctlearn_model.cpk"
    else:
        model_version = 0
        os.makedirs(base_model_dir, exist_ok=True)

    model_dir = f"{base_model_dir}/{self.model_nickname}_v{model_version}/"
    if model_to_load is not None:
        print(
            f"➡️ Model already exists: will continue training and create {model_dir}"
        )
    else:
        print(f"🆕 Model does not exist: will create {model_dir}")

    _save_best_validation_only = (
        True if save_best_validation_only is None else save_best_validation_only
    )

    # A previous version of this model takes precedence over the
    # transfer-learning checkpoint.
    checkpoint = (
        model_to_load if model_to_load is not None else transfer_learning_model_cpk
    )
    load_model_args = (
        []
        if checkpoint is None
        else [
            "--TrainCTLearnModel.model_type=LoadedModel",
            f"--LoadedModel.load_model_from={checkpoint}",
        ]
    )

    training_gamma_table = read_table_hdf5(
        self.project_directories.model_index_file,
        path=IndexTables(self, ParticleType.GAMMA_DIFFUSE).TRAINING.table_path,
    )
    signal_dir = training_gamma_table["training_gamma_diffuse_dir"][0]
    signal_args = [
        f'--pattern-signal "{pattern}"'
        for pattern in training_gamma_table["training_gamma_diffuse_patterns"]
    ]

    # Background (proton) inputs are only passed for the "type" task.
    background_args = []
    if reco == "type":
        training_proton_table = read_table_hdf5(
            self.project_directories.model_index_file,
            path=IndexTables(self, ParticleType.PROTON).TRAINING.table_path,
        )
        background_args.append(
            f"--background {training_proton_table['training_proton_dir'][0]}"
        )
        background_args.extend(
            f'--pattern-background "{pattern}"'
            for pattern in training_proton_table["training_proton_patterns"]
        )

    channels = ast.literal_eval(self.model_parameters_table["channels"][0])
    stereo_mode = "stereo" if self.stereo else "mono"
    stack_telescope_images = bool(self.stereo)
    allowed_tels = ast.literal_eval(self.model_parameters_table["telescope_ids"][0])

    if config_file is None:
        config = {
            "TrainCTLearnModel": {
                "save_best_validation_only": _save_best_validation_only,
                "n_epochs": int(n_epochs),
                "stack_telescope_images": stack_telescope_images,
                "reco_tasks": [reco],
                "output_dir": model_dir,
                "DLImageReader": {
                    "allowed_tels": allowed_tels,
                    "min_telescopes": int(self.min_telescopes),
                    "force_dl1_lookup": force_dl1_lookup,
                    "mode": stereo_mode,
                    "channels": channels,
                },
            },
            "LoadedModel": {"trainable_backbone": trainable_backbone},
        }
        config_file = f"{base_model_dir}/train_config{self.model_nickname}_v{model_version}.json"
        with open(config_file, "w") as file:
            json.dump(config, file)
        print(f"Configuration saved to {config_file}")

    cmd = " ".join(
        [
            "ctlearn-train-model",
            *load_model_args,
            f"--TrainCTLearnModel.batch_size={batch_size}",
            f"--signal {signal_dir}",
            *signal_args,
            *background_args,
            f"--output {model_dir}",
            f"--config {config_file}",
            "--overwrite",
            "--verbose",
        ]
    )

    # NOTE(review): both commands go through the shell as strings built from
    # index contents; paths with spaces or shell metacharacters would break.
    if self.cluster_configuration.use_cluster:
        sbatch_file = self.cluster_configuration.write_sbatch_script(
            f"train_{self.model_nickname}_v{model_version}", cmd, base_model_dir
        )
        os.system(f"sbatch {sbatch_file}")
    else:
        print(cmd)
        os.system(cmd)
def get_n_epoch_trained(self):
    """
    Count the epochs already trained for this model.

    Every ``training_log.csv`` found under
    ``<model_dir>/<model_nickname>*/`` is read, and the number of rows of
    each log is added to the total.

    Returns
    -------
    int
        Total number of rows over all training logs; 0 when none exist.
    """
    import glob

    import pandas as pd

    base_dir = self.model_parameters_table["model_dir"][0]
    log_pattern = f"{base_dir}/{self.model_nickname}*/training_log.csv"
    return sum(
        len(pd.read_csv(log_path)) for log_path in sorted(glob.glob(log_pattern))
    )
def plot_loss(self):
    """
    Plot the training and validation loss over epochs.

    Training logs are read from
    ``<model_dir>/<model_nickname>*/training_log.csv`` and their "loss" and
    "val_loss" columns are concatenated. If no epoch is found, a message is
    printed and nothing is plotted.

    Returns
    -------
    None

    Notes
    -----
    - Logs are concatenated in natural order, so ``_v10`` follows ``_v9``
      instead of landing between ``_v1`` and ``_v2``.
    - With a single epoch the losses are drawn as scatter points.
    """
    import glob
    import re

    import matplotlib.pyplot as plt
    import pandas as pd

    def _natural_key(path):
        # re.split with a capturing group alternates text / digit runs;
        # odd positions are always the digit runs.
        return [
            int(token) if position % 2 else token
            for position, token in enumerate(re.split(r"(\d+)", path))
        ]

    training_logs = sorted(
        glob.glob(
            f"{self.model_parameters_table['model_dir'][0]}/{self.model_nickname}*/training_log.csv"
        ),
        key=_natural_key,
    )

    losses_train = []
    losses_val = []
    for training_log in training_logs:
        df = pd.read_csv(training_log)
        losses_train = np.concatenate((losses_train, df["loss"].to_numpy()))
        losses_val = np.concatenate((losses_val, df["val_loss"].to_numpy()))

    epochs = np.arange(1, len(losses_train) + 1)
    if len(epochs) == 0:
        print(
            f"❌ No training logs found for {self.model_nickname}, start the training to see the loss."
        )
        return

    if len(epochs) > 1:
        plt.plot(epochs, losses_train, label="Training", lw=2)
        plt.plot(epochs, losses_val, label="Validation", ls="--")
    else:
        # A line through a single point is invisible; draw markers.
        plt.scatter(epochs, losses_train, label="Training", lw=2)
        plt.scatter(epochs, losses_val, label="Validation", ls="--")
    plt.title(f"{self.model_parameters_table['reco'][0]} training".title())
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.xticks(np.arange(1, len(losses_train) + 1, 2))
    plt.legend()
    plt.show()
def update_model_manager_parameters_in_index(self, parameters: dict):
    """
    Update the model manager parameters in the HDF5 index file.

    Parameters
    ----------
    parameters : dict
        Maps column names of the model parameters table to their new
        values.

    Raises
    ------
    KeyError
        If a key is not a column of the model parameters table. The check
        runs before anything is modified, so a bad key leaves both the
        index file and this instance untouched.

    Notes
    -----
    - Unicode string columns are converted to fixed-length byte strings
      ('S256') before assignment so longer values fit.
    - Each value is also stored as an instance attribute of the same name.
    - After writing, ``self.model_parameters_table`` is replaced by the
      updated table, so later reads on this instance see the new values.
    """
    from astropy.io.misc.hdf5 import read_table_hdf5, write_table_hdf5

    index_file = self.project_directories.model_index_file
    table_path = IndexTables(self).PARAMETERS.table_path
    model_table = read_table_hdf5(index_file, path=table_path)

    unknown_keys = [key for key in parameters if key not in model_table.colnames]
    if unknown_keys:
        raise KeyError(f"Unknown model parameter(s): {unknown_keys}")

    print(f"💾 Model {self.model_nickname} index update:")
    for key, value in parameters.items():
        if model_table[key].dtype.kind == "U":
            # Widen the column so a longer string is not truncated.
            model_table[key] = model_table[key].astype("S256")
        model_table[key][0] = value
        self.__dict__[key] = value
        print(f"\t➡️ {key} updated to {value}")

    write_table_hdf5(
        model_table,
        index_file,
        path=table_path,
        append=True,
        overwrite=True,
        serialize_meta=True,
    )
    # Keep the cached copy in sync with what was just written.
    self.model_parameters_table = model_table
def update_model_manager_testing_data(self, testing_data_sample: DataSample):
    """
    Add or update a testing data sample in the model index.

    Parameters
    ----------
    testing_data_sample : DataSample
        Sample to register. Its ``directory``, ``pattern``,
        ``zenith_distance``, ``azimuth`` and ``particle_type`` are used.

    Notes
    -----
    - The testing table of the sample's particle type is read from the
      index. If the index file or table path is missing, or the table is
      empty, the default table is used instead.
    - If a row with exactly the same zenith distance and azimuth exists,
      the directory and pattern of the first such row are overwritten;
      otherwise a new row is added.
    - The table is written back to the HDF5 index file. Errors raised while
      writing are not caught.
    """
    from astropy.io.misc.hdf5 import read_table_hdf5, write_table_hdf5

    particle_type = testing_data_sample.particle_type
    column_prefix = f"testing_{particle_type.value}"
    index_file = self.project_directories.model_index_file
    testing_index_table = IndexTables(self, particle_type).TESTING

    try:
        testing_table = read_table_hdf5(
            index_file,
            path=testing_index_table.table_path,
        )
    except (OSError, KeyError):
        # No testing table stored for this particle type yet.
        testing_table = testing_index_table.default_table
    if len(testing_table) == 0:
        testing_table = testing_index_table.default_table

    same_pointing = np.where(
        (
            testing_table[f"{column_prefix}_zenith_distances"]
            == testing_data_sample.zenith_distance
        )
        & (testing_table[f"{column_prefix}_azimuths"] == testing_data_sample.azimuth)
    )[0]

    if len(same_pointing) > 0:
        row = same_pointing[0]
        # NOTE(review): in-place assignment into a fixed-width string column
        # may truncate a longer directory or pattern — confirm column widths.
        testing_table[f"{column_prefix}_dirs"][row] = testing_data_sample.directory
        testing_table[f"{column_prefix}_patterns"][row] = testing_data_sample.pattern
    else:
        testing_table.add_row(
            [
                testing_data_sample.directory,
                testing_data_sample.zenith_distance,
                testing_data_sample.azimuth,
                testing_data_sample.pattern,
            ]
        )

    write_table_hdf5(
        testing_table,
        index_file,
        path=testing_index_table.table_path,
        append=True,
        overwrite=True,
        serialize_meta=True,
    )
def plot_zenith_azimuth_ranges(self, ax=None, plot_testing_nodes=True):
    """
    Plot the zenith and azimuth validity ranges on a polar plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Polar axis to draw on. If None, a new polar figure is created and
        shown at the end.
    plot_testing_nodes : bool, optional
        Whether to overlay the testing DL1 and DL2 nodes (default True).

    Notes
    -----
    - The range comes from ``self.validity.zenith_range`` (converted to
      degrees) and ``self.validity.azimuth_range`` (converted to radians).
    - An azimuth range that is None or NaN on both ends is drawn as a full
      circle. Otherwise an arc is drawn, or a single point when both
      azimuth bounds and both zenith bounds coincide.
    - Gamma-diffuse training nodes are always plotted.
    - Testing nodes are best-effort: if their tables cannot be read they
      are skipped silently. Nodes with a NaN zenith or azimuth are skipped.

    Raises
    ------
    Exception
        Whatever ``read_table_hdf5`` raises if the gamma-diffuse training
        table cannot be read.
    """
    import astropy.units as u
    import matplotlib.pyplot as plt
    from astropy.io.misc.hdf5 import read_table_hdf5

    # Remember this before ``ax`` is replaced; it decides the final show().
    created_axes = ax is None
    if created_axes:
        fig, ax = plt.subplots(subplot_kw={"projection": "polar"})

    index_file = self.project_directories.model_index_file
    zenith_min, zenith_max = self.validity.zenith_range.to(u.deg)
    azimuth_range = self.validity.azimuth_range
    if azimuth_range is None:
        # No azimuth restriction: same drawing as the NaN/NaN case.
        azimuth_min = azimuth_max = None
        full_azimuth = True
    else:
        azimuth_min, azimuth_max = azimuth_range.to(u.rad)
        full_azimuth = bool(np.isnan(azimuth_min) and np.isnan(azimuth_max))

    if zenith_min == zenith_max:
        if full_azimuth:
            # Single zenith, any azimuth: a circle.
            theta = np.linspace(0, 2 * np.pi, 100) * u.rad
            r = np.full_like(theta, zenith_min).to(u.deg)
            ax.plot(theta, r, lw=3, zorder=0)
        elif azimuth_min == azimuth_max:
            # Single pointing: one marker.
            ax.scatter(
                azimuth_min,
                zenith_min,
                s=100,
                zorder=0,
                label="Training",
                color=get_color("ctlearn_1"),
            )
        else:
            # Single zenith over an azimuth interval: an arc.
            theta = np.linspace(azimuth_min, azimuth_max, 100)
            r = np.full_like(theta, zenith_min).to(u.deg)
            ax.plot(theta, r, lw=3, zorder=0)
    else:
        if full_azimuth:
            theta = np.linspace(0, 2 * np.pi, 100) * u.rad
        else:
            theta = np.linspace(azimuth_min, azimuth_max, 100)
        r1 = np.full_like(theta, zenith_min).to(u.deg)
        r2 = np.full_like(theta, zenith_max).to(u.deg)
        ax.fill_between(
            theta.value,
            r1.value,
            r2.value,
            alpha=0.3,
            zorder=0,
            color=get_color("ctlearn_highlight"),
        )
        if full_azimuth:
            # Annulus: outline the inner and outer circle.
            ax.plot(theta, r1, lw=3, color=get_color("ctlearn_2"), zorder=0)
            ax.plot(theta, r2, lw=3, color=get_color("ctlearn_2"), zorder=0)
        else:
            # Annular sector: both arcs plus the two radial edges.
            theta, r1, r2 = theta.value, r1.value, r2.value
            ax.plot(theta, r1, lw=3, color=get_color("ctlearn_2"), zorder=0)
            ax.plot(theta, r2, lw=3, color=get_color("ctlearn_2"), zorder=0)
            ax.plot(
                (theta[0], theta[0]),
                (r1[0], r2[0]),
                lw=3,
                color=get_color("ctlearn_2"),
                zorder=0,
            )
            ax.plot(
                (theta[-1], theta[-1]),
                (r1[-1], r2[-1]),
                lw=3,
                color=get_color("ctlearn_2"),
                zorder=0,
            )

    training_gamma_table = read_table_hdf5(
        index_file,
        path=IndexTables(self, ParticleType.GAMMA_DIFFUSE).TRAINING.table_path,
    )
    zeniths = training_gamma_table["training_gamma_diffuse_zenith_distances"]
    azimuths = training_gamma_table["training_gamma_diffuse_azimuths"].to(u.rad)
    for zenith, azimuth in zip(zeniths, azimuths):
        ax.scatter(
            azimuth,
            zenith,
            s=50,
            color=get_color("ctlearn_1"),
            label="Training",
        )

    if plot_testing_nodes:
        try:
            testing_dl1_table = read_table_hdf5(
                index_file,
                path=IndexTables(self, ParticleType.GAMMA_POINT).TESTING.table_path,
            )
            zeniths = testing_dl1_table["testing_gamma_point_zenith_distances"]
            azimuths = testing_dl1_table["testing_gamma_point_azimuths"].to(u.rad)
            for zenith, azimuth in zip(zeniths, azimuths):
                # NaN never compares equal to itself; test with isnan.
                if np.isnan(zenith) or np.isnan(azimuth):
                    continue
                ax.scatter(
                    azimuth,
                    zenith,
                    s=50,
                    facecolors="none",
                    edgecolors=get_color("ctlearn_accent_1"),
                    label="Testing DL1",
                    zorder=3,
                )
        except Exception:
            # Best-effort overlay: the testing table may not exist yet.
            pass

        try:
            mc_dl2_table = read_table_hdf5(
                index_file,
                path=IndexTables(self, ParticleType.GAMMA_POINT).DL2_MC.table_path,
            )
            zeniths = mc_dl2_table["testing_DL2_gamma_point_zenith_distances"]
            azimuths = mc_dl2_table["testing_DL2_gamma_point_azimuths"].to(u.rad)
            for zenith, azimuth in zip(zeniths, azimuths):
                if np.isnan(zenith) or np.isnan(azimuth):
                    continue
                ax.scatter(
                    azimuth,
                    zenith,
                    s=50,
                    color=get_color("ctlearn_accent_2"),
                    label="Testing DL2",
                    zorder=2,
                )
        except Exception:
            # Best-effort overlay: the DL2 MC table may not exist yet.
            pass

    # One legend entry per label, however many points were drawn.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys())
    ax.set_theta_zero_location("E")
    ax.set_theta_direction(-1)
    ax.set_rlabel_position(-30)
    ax.set_yticks(np.arange(10, 61, 10))
    ax.set_yticklabels(["", "", "30°", "", "", "60°"], fontsize=10)
    ax.set_xlabel("Azimuth [deg]", fontsize=10)
    ax.set_title("Zenith and Azimuth Ranges", pad=30)
    plt.tight_layout()
    if created_axes:
        plt.show()
def plot_training_nodes(self):
    """
    Plot the training nodes for gamma and proton events in a polar coordinate system.

    Each training node is drawn at (azimuth, zenith distance) on a polar
    axis. Gamma-diffuse nodes are always drawn; proton nodes are drawn only
    when the model parameter ``reco`` equals ``"type"``.

    Notes
    -----
    - The tables are read from ``self.project_directories.model_index_file``
      at the ``TRAINING`` table path given by ``IndexTables`` for
      ``ParticleType.GAMMA_DIFFUSE`` and ``ParticleType.PROTON``.
    - Nodes whose zenith or azimuth is NaN are skipped.
    - If a particle type has no drawable node at all, a message is printed
      saying that its training nodes cannot be shown.
    - The figure is displayed with ``plt.show()``; nothing is returned.

    See Also
    --------
    astropy.io.misc.hdf5.read_table_hdf5 : Used to read the training data tables.
    matplotlib.pyplot.subplots : Used to create the polar plot.
    """
    import astropy.units as u
    import matplotlib.pyplot as plt
    from astropy.io.misc.hdf5 import read_table_hdf5

    _, ax = plt.subplots(subplot_kw={"projection": "polar"})

    def _scatter_nodes(particle_type, column_prefix, particle_name, **scatter_kwargs):
        """Scatter every defined node of one particle type; report when none is drawable."""
        table = read_table_hdf5(
            self.project_directories.model_index_file,
            path=IndexTables(self, particle_type).TRAINING.table_path,
        )
        zeniths = table[f"{column_prefix}_zenith_distances"]
        azimuths = table[f"{column_prefix}_azimuths"].to(u.rad)
        n_plotted = 0
        for zenith, azimuth in zip(zeniths, azimuths):
            # NaN never compares equal to anything (not even itself), so an
            # `== np.nan` test can never detect it; np.isnan is required.
            if np.isnan(zenith) or np.isnan(azimuth):
                continue
            ax.scatter(azimuth, zenith, **scatter_kwargs)
            n_plotted += 1
        if n_plotted == 0:
            print(
                f"Training nodes for {particle_name} cannot be shown because the zenith or azimuth are not defined."
            )

    _scatter_nodes(
        ParticleType.GAMMA_DIFFUSE,
        "training_gamma_diffuse",
        "gammas",
        s=50,
        color=get_color("ctlearn_1"),
        label="Gammas",
        zorder=10,
    )

    if self.model_parameters_table["reco"][0] == "type":
        _scatter_nodes(
            ParticleType.PROTON,
            "training_proton",
            "protons",
            label="Protons",
            edgecolor=get_color("ctlearn_accent_1"),
            facecolors="w",
            zorder=1,
            s=100,
            lw=2,
        )

    ax.set_theta_zero_location("E")
    ax.set_theta_direction(-1)
    ax.set_rlabel_position(-30)
    ax.set_ylim(0, 60)
    ax.set_yticks(np.arange(10, 61, 10))
    ax.set_yticklabels(["", "", "30°", "", "", "60°"], fontsize=10)
    ax.set_xlabel("Azimuth [deg]", fontsize=10)

    # Every node is scattered with the same label; keep one legend entry per label.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys())

    ax.set_title("Training nodes", pad=30)
    plt.tight_layout()
    plt.show()
def remove_row_from_table(self, table_path: str, row_index: int):
    """
    Delegate a row removal on the model index file to ``remove_row_from_table_utils``.

    Parameters
    ----------
    table_path : str
        Path of the table inside the model index file, forwarded unchanged.
    row_index : int
        Index of the row to remove, forwarded unchanged.
    """
    index_file = self.project_directories.model_index_file
    remove_row_from_table_utils(index_file, table_path, row_index)
class ModelRangeOfValidity:
    """
    Class to represent the range of validity for a CTLearn model.

    The ranges of zenith, azimuth, energy and NSB are extracted from the
    gamma-diffuse training table of a CTLearn model, and ``matches`` checks
    whether given values fall inside them.

    Parameters
    ----------
    model_manager : CTLearnModelManager
        Manager whose ``project_directories.model_index_file`` holds the
        gamma-diffuse training table.

    Attributes
    ----------
    zenith_range : astropy.units.Quantity
        ``[min, max]`` of the training zenith distances.
    azimuth_range : astropy.units.Quantity
        ``[min, max]`` of the training azimuths.
    energy_range : astropy.units.Quantity
        ``[min, max]`` over the training energy lower and upper bounds.
    nsb_range : astropy.units.Quantity
        ``[min, max]`` over the training NSB lower and upper bounds.

    Methods
    -------
    matches(**kwargs)
        Check if the given parameters fall within the model's range of validity.
    """

    # Keyword names understood by `matches`; each maps to the `<key>_range` attribute.
    _CHECKED_KEYS = ("zenith", "azimuth", "energy", "nsb")

    def __init__(self, model_manager: "CTLearnModelManager"):
        """
        Read the gamma-diffuse training table and derive the validity ranges.

        Parameters
        ----------
        model_manager : CTLearnModelManager
            Manager providing the model index file and the table path (via
            ``IndexTables``) of the gamma-diffuse training table.

        Raises
        ------
        ValueError
            If the training table has no rows (no minimum/maximum exists).
        """
        from astropy.io.misc.hdf5 import read_table_hdf5

        training_table = read_table_hdf5(
            model_manager.project_directories.model_index_file,
            path=IndexTables(model_manager, ParticleType.GAMMA_DIFFUSE).TRAINING.table_path,
        )
        prefix = "training_gamma_diffuse"

        self.zenith_range = self._span(training_table[f"{prefix}_zenith_distances"])
        self.azimuth_range = self._span(training_table[f"{prefix}_azimuths"])
        # Energy and NSB are stored as per-node (min, max) pairs; the overall
        # range is the envelope of both columns.
        self.energy_range = self._span(
            np.concatenate(
                (
                    training_table[f"{prefix}_energy_min"],
                    training_table[f"{prefix}_energy_max"],
                )
            )
        )
        self.nsb_range = self._span(
            np.concatenate(
                (
                    training_table[f"{prefix}_nsb_min"],
                    training_table[f"{prefix}_nsb_max"],
                )
            )
        )

    @staticmethod
    def _span(values):
        """Return ``[min, max]`` of *values* in its own unit, ignoring NaN entries."""
        raw = values.value
        # The built-in min()/max() give order-dependent results when NaN is
        # present; the nan-aware reductions skip undefined entries instead.
        return [np.nanmin(raw), np.nanmax(raw)] * values.unit

    @staticmethod
    def _contains(bounds, value):
        """Return True when *value* lies within the closed interval *bounds*."""
        unit = getattr(bounds, "unit", None)
        if unit is not None and not hasattr(value, "unit"):
            # A bare number cannot be compared with a Quantity; interpret it
            # in the unit of the stored range.
            value = value * unit
        return bool(bounds[0] <= value <= bounds[1])

    def matches(self, **kwargs):
        """
        Check if the given parameters match the defined ranges.

        Parameters
        ----------
        **kwargs : dict
            Values to check. Supported keys are ``"zenith"``, ``"azimuth"``,
            ``"energy"`` and ``"nsb"``. A value may be a quantity or a plain
            number; a plain number is interpreted in the unit of the
            corresponding range. Any other key is ignored.

        Returns
        -------
        bool
            True if all provided parameters fall within their respective
            ranges (bounds included), False otherwise.

        Notes
        -----
        - If a range attribute is ``None``, the corresponding parameter is
          not checked.
        - Calling ``matches()`` without arguments returns True.
        """
        for key, value in kwargs.items():
            if key not in self._CHECKED_KEYS:
                continue
            bounds = getattr(self, f"{key}_range")
            if bounds is None:
                continue
            if not self._contains(bounds, value):
                return False
        return True
class IndexTables:
    """
    Empty table schemas and storage paths of the index tables of one model.

    Every attribute is an ``IndexTable`` pairing an empty ``QTable`` (column
    names, dtypes and, where given, units) with the path of that table,
    which always starts with the model nickname.

    ``PARAMETERS``, ``IRF`` and ``DL2_DATA`` are always defined. ``DL2_MC``,
    ``TRAINING`` and ``TESTING`` depend on the particle type and are only
    defined when ``particle_type`` is not ``None``.
    """

    class IndexTable:
        """Pair an empty default table with the path it is stored under."""

        def __init__(self, default_table: QTable, table_path: str):
            self.default_table = default_table
            self.table_path = table_path

    @staticmethod
    def _empty_table(columns, units=None):
        """Build an empty QTable from ``(name, dtype)`` pairs and optional units."""
        names = [name for name, _ in columns]
        dtypes = [dtype for _, dtype in columns]
        if units is None:
            return QTable(names=names, dtype=dtypes)
        return QTable(names=names, dtype=dtypes, units=units)

    def __init__(self, model_manager: CTLearnModelManager, particle_type: ParticleType=None):
        self.model_manager = model_manager
        self.particle_type = particle_type

        text = "S256"  # dtype used for every string column
        root = f"{self.model_manager.model_nickname}"

        if self.particle_type is not None:
            particle = self.particle_type.value

            self.DL2_MC = self.IndexTable(
                self._empty_table(
                    [
                        (f"testing_DL2_{particle}_files", text),
                        (f"testing_DL2_{particle}_zenith_distances", float),
                        (f"testing_DL2_{particle}_azimuths", float),
                        ("merged", bool),
                    ],
                    units=[None, "deg", "deg", None],
                ),
                f"{root}/DL2/MC/{particle}",
            )

            self.TRAINING = self.IndexTable(
                self._empty_table(
                    [
                        (f"training_{particle}_dir", text),
                        (f"training_{particle}_patterns", text),
                        (f"training_{particle}_zenith_distances", float),
                        (f"training_{particle}_azimuths", float),
                        (f"training_{particle}_energy_min", float),
                        (f"training_{particle}_energy_max", float),
                        (f"training_{particle}_nsb_min", float),
                        (f"training_{particle}_nsb_max", float),
                    ],
                    units=[None, None, "deg", "deg", "TeV", "TeV", "Hz", "Hz"],
                ),
                f"{root}/training/{particle}",
            )

            self.TESTING = self.IndexTable(
                self._empty_table(
                    [
                        (f"testing_{particle}_dirs", text),
                        (f"testing_{particle}_zenith_distances", float),
                        (f"testing_{particle}_azimuths", float),
                        (f"testing_{particle}_patterns", text),
                    ],
                    units=[None, "deg", "deg", None],
                ),
                f"{root}/testing/{particle}",
            )

        self.PARAMETERS = self.IndexTable(
            self._empty_table(
                [
                    ("model_nickname", text),
                    ("model_dir", text),
                    ("reco", text),
                    ("channels", text),
                    ("telescope_names", text),
                    ("telescope_ids", text),
                    ("notes", text),
                    ("max_training_epochs", int),
                    ("min_telescopes", int),
                    ("stereo", bool),
                ]
            ),
            f"{root}/parameters",
        )

        self.IRF = self.IndexTable(
            self._empty_table(
                [
                    ("config", text),
                    ("cuts_file", text),
                    ("irf_file", text),
                    # NOTE: misspelled column name is kept as-is; it is the stored schema.
                    ("benckmark_file", text),
                    ("zenith", float),
                    ("azimuth", float),
                ],
                units=[None, None, None, None, "deg", "deg"],
            ),
            f"{root}/IRF",
        )

        self.DL2_DATA = self.IndexTable(
            self._empty_table(
                [
                    ("DL2_files", text),
                    ("DL2_zenith_distances", float),
                    ("DL2_azimuths", float),
                ]
            ),
            f"{root}/DL2/Data",
        )