ctlearn_manager.model_manager
CTLearnModelManager Class.
This class is designed to manage CTLearn models, providing functionalities for initializing, saving, loading, training, and updating model parameters. It also includes methods for handling training data, testing data, DL2 data, and IRF data.
- class ctlearn_manager.model_manager.CTLearnModelManager(model_parameters, project_directories, load=False, cluster_configuration=ClusterConfiguration())[source]
Bases: object
CTLearnModelManager class for managing CTLearn models.
This class provides methods for initializing, saving, loading, and training CTLearn models. It also includes methods for updating and retrieving model parameters, training data, testing data, DL2 data, and IRF data.
- model_index_file
Path to the model index file.
- Type:
str
- model_nickname
Nickname of the model.
- Type:
str
- model_parameters_table
Table containing model parameters.
- Type:
astropy.table.Table
- validity
Range of validity for the model.
- Type:
ModelRangeOfValidity
- stereo
Indicates if the model uses stereo mode.
- Type:
bool
- telescope_ids
List of telescope IDs used in the model.
- Type:
list
- telescope_names
List of telescope names used in the model.
- Type:
list
- cluster_configuration
Configuration for the cluster.
- Type:
ClusterConfiguration
- __init__(model_parameters, project_directories, load=False, cluster_configuration=ClusterConfiguration())[source]
Initializes the CTLearnModelManager instance.
- save_to_index(model_parameters)[source]
Save model parameters and training samples to an HDF5 index file.
- launch_training(n_epochs, save_best_validation_only=None, transfer_learning_model_cpk=None, trainable_backbone=True, force_dl1_lookup=False, config_file=None, batch_size=64)[source]
Launches the training process for the model.
- get_n_epoch_trained()[source]
Calculate the total number of epochs trained by summing the lengths of all training logs.
- update_model_manager_parameters_in_index(parameters)[source]
Update the model manager parameters in the HDF5 index file.
- Parameters:
parameters (dict)
- update_model_manager_testing_data(testing_data_sample)[source]
Update the model manager’s testing data with a new gamma or proton data sample.
- Parameters:
testing_data_sample (DataSample)
- update_model_manager_DL2_MC_files(testing_DL2_gamma_files, testing_DL2_proton_files, testing_DL2_gamma_zenith_distances, testing_DL2_gamma_azimuths, testing_DL2_proton_zenith_distances, testing_DL2_proton_azimuths)
Update the DL2 MC files for gamma and proton testing data in the model manager.
- update_model_manager_DL2_data_files(DL2_files, DL2_zenith_distances, DL2_azimuths)
Update the DL2 data files for the model manager.
- update_merged_DL2_MC_files(testing_DL2_zenith_distance, testing_DL2_azimuth, testing_DL2_gamma_merged_file=None, testing_DL2_proton_merged_file=None)
Update the merged DL2 MC files for gamma and proton data.
- update_model_manager_IRF_data(config, cuts_file, irf_file, bencmark_file, zenith, azimuth)
Update the IRF (Instrument Response Function) data for the model manager.
- get_IRF_data(zenith, azimuth)
Retrieve the Instrument Response Function (IRF) data for a given zenith and azimuth.
- get_closest_IRF_data(zenith, azimuth)
Retrieve the closest Instrument Response Function (IRF) data based on the given zenith and azimuth angles.
- get_DL2_MC_files(zenith, azimuth)
Retrieve DL2 Monte Carlo (MC) files for given zenith and azimuth angles.
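get_closest_IRF_data returns the IRF node nearest to a requested pointing. This page does not document the distance metric. The sketch below is therefore only an illustration. It assumes nodes are compared by the great-circle angle between pointing directions given as (zenith, azimuth) in degrees. The helpers `angular_separation_deg` and `closest_node` are hypothetical and are not part of ctlearn_manager.

```python
import math


def angular_separation_deg(zen1, az1, zen2, az2):
    """Angle in degrees between two pointings given as (zenith, azimuth) in degrees."""
    z1, z2 = math.radians(zen1), math.radians(zen2)
    d_az = math.radians(az1 - az2)
    # Dot product of the two unit pointing vectors in horizontal coordinates
    cos_sep = math.cos(z1) * math.cos(z2) + math.sin(z1) * math.sin(z2) * math.cos(d_az)
    # Clamp against floating-point round-off before calling acos
    return math.degrees(math.acos(max(-1.0, min(1.0, cos_sep))))


def closest_node(nodes, zenith, azimuth):
    """Hypothetical helper: pick the (zenith, azimuth) node nearest to a pointing."""
    if not nodes:
        raise ValueError("no IRF nodes available")
    return min(nodes, key=lambda n: angular_separation_deg(n[0], n[1], zenith, azimuth))


nodes = [(10.0, 0.0), (30.0, 180.0), (50.0, 90.0)]
print(closest_node(nodes, 28.0, 170.0))  # (30.0, 180.0)
```

An angular metric handles azimuth wrap-around: 359° and 1° come out close together. It also treats azimuth differences as small near the zenith. A naive Euclidean distance on (zenith, azimuth) would get both of these wrong.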
Initialize the CTLearnModelManager instance.
- Parameters:
model_parameters (dict) – Dictionary containing the parameters for the model. Must include at least the key “model_nickname” if a specific nickname is desired.
MODEL_INDEX_FILE (str) – Path to the HDF5 file containing the model index.
load (bool, optional) – If True, the model is loaded from the index file. If False, the model parameters are saved to the index file. Default is False.
cluster_configuration (ClusterConfiguration, optional) – Configuration object for the cluster. Default is an instance of ClusterConfiguration.
project_directories (CTLMDirectories)
- Raises:
ValueError – If reco is “type” and the required training patterns for gamma diffuse or proton events are missing.
ValueError – If stereo mode is enabled but fewer than 2 telescopes are provided.
ValueError – If reco is “cameradirection” and stereo mode is enabled; this combination is not supported.
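The three ValueError conditions above can be restated as a small standalone check. This is a sketch, not the constructor's actual code. `validate_model_config` is hypothetical, and the check order is assumed. The dictionary keys "gamma_diffuse" and "proton" are borrowed from the HDF5 training paths described under plot_training_nodes.

```python
def validate_model_config(reco, stereo, telescope_ids, training_patterns):
    """Hypothetical re-statement of the documented ValueError conditions."""
    # Particle classification needs both gamma diffuse and proton training patterns
    if reco == "type":
        for particle in ("gamma_diffuse", "proton"):
            if not training_patterns.get(particle):
                raise ValueError(f"reco='type' needs training patterns for {particle}")
    # Stereo reconstruction needs at least two telescopes
    if stereo and len(telescope_ids) < 2:
        raise ValueError("stereo mode needs at least 2 telescopes")
    # Camera-direction reconstruction is documented as unsupported in stereo
    if reco == "cameradirection" and stereo:
        raise ValueError("reco='cameradirection' is not supported in stereo mode")


validate_model_config("type", False, [1], {"gamma_diffuse": "gamma_*.h5", "proton": "proton_*.h5"})
```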
- save_to_index(model_parameters)[source]
Save model parameters and training samples to an HDF5 index file.
- Parameters:
model_parameters (dict) –
A dictionary containing model parameters and training sample details. Expected keys include:
- “model_dir” (str): Absolute path to the model directory.
- “reco” (str, optional): Reconstruction type, one of [‘type’, ‘energy’, ‘cameradirection’, ‘skydirection’]. Defaults to “default_reco”.
- “channels” (list of str, optional): List of channel names. Defaults to [“cleaned_image”, “cleaned_relative_peak_time”].
- “telescope_names” (list of str, optional): Names of telescopes. Defaults to an empty list.
- “telescope_ids” (list of int, optional): IDs of telescopes. Defaults to an empty list.
- “notes” (str, optional): Notes about the model. Defaults to an empty string.
- “max_training_epochs” (int, optional): Maximum number of training epochs. Defaults to 10.
- “min_telescopes” (int, optional): Minimum number of telescopes. Defaults to 1.
- “stereo” (bool, optional): Whether the model uses stereo mode. Defaults to True if min_telescopes >= 2, otherwise False.
- “training_samples” (list, optional): List of training sample objects. Each object must have the following attributes:
  - particle_type (enum): Type of particle.
  - directory (str): Directory of the training sample.
  - pattern (str): File pattern of the training sample.
  - zenith_distance (float): Zenith distance in degrees.
  - azimuth (float): Azimuth in degrees.
  - energy_range (tuple of float): Minimum and maximum energy in TeV.
  - nsb_range (tuple of float): Minimum and maximum NSB in Hz.
- Raises:
ValueError – If the “model_dir” is not an absolute path.
AssertionError – If any of the following conditions are not met:
- Telescope names and IDs have the same length.
- telescope_ids is a 1-dimensional array.
- telescope_names is a 1-dimensional array.
- channels is a 1-dimensional array.
- reco is one of [‘type’, ‘energy’, ‘cameradirection’, ‘skydirection’].
- max_training_epochs is an integer.
- min_telescopes is an integer.
Notes
This method creates or updates an HDF5 file to store model parameters and training sample details. If the model nickname already exists in the index, it will not overwrite the existing entry. Training sample details are stored under paths specific to the particle type.
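The documented defaults and checks can be summarized as a pure function that works on a plain dict. This is a minimal sketch: `apply_index_defaults` is hypothetical, and it does not write any HDF5 file. The sketch requires reco to be passed explicitly, because the documented default “default_reco” is not in the list of accepted values. The 1-dimensional array checks are omitted.

```python
import os

VALID_RECO = ("type", "energy", "cameradirection", "skydirection")


def apply_index_defaults(params):
    """Hypothetical helper: fill documented defaults and run documented checks."""
    if not os.path.isabs(params["model_dir"]):
        raise ValueError("model_dir must be an absolute path")
    out = dict(params)  # do not mutate the caller's dict
    out.setdefault("channels", ["cleaned_image", "cleaned_relative_peak_time"])
    out.setdefault("telescope_names", [])
    out.setdefault("telescope_ids", [])
    out.setdefault("notes", "")
    out.setdefault("max_training_epochs", 10)
    out.setdefault("min_telescopes", 1)
    # Stereo defaults to True only when at least two telescopes are required
    out.setdefault("stereo", out["min_telescopes"] >= 2)
    assert len(out["telescope_names"]) == len(out["telescope_ids"])
    assert out.get("reco") in VALID_RECO
    assert isinstance(out["max_training_epochs"], int)
    assert isinstance(out["min_telescopes"], int)
    return out


params = apply_index_defaults({"model_dir": "/data/models", "reco": "energy", "min_telescopes": 2})
print(params["stereo"], params["max_training_epochs"])  # True 10
```

The checks use bare assert, as the AssertionError entries above suggest. Keep in mind that Python skips assert statements when run with -O.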
- launch_training(n_epochs, save_best_validation_only=None, transfer_learning_model_cpk=None, trainable_backbone=True, force_dl1_lookup=False, config_file=None, batch_size=64)[source]
Launch the training process for the model.
- Parameters:
n_epochs (int) – Number of epochs to train the model. If set to 0, training will not proceed.
save_best_validation_only (bool, optional) – Whether to save only the best validation model during training. Overrides the default behavior.
transfer_learning_model_cpk (str, optional) – Path to a checkpoint file for transfer learning. If provided, the model will be initialized from this checkpoint.
trainable_backbone (bool, default=True) – Whether the backbone of the model should be trainable.
force_dl1_lookup (bool, default=False) – Whether to force a lookup for DL1 data.
config_file (str, optional) – Path to a configuration file. If not provided, a new configuration file will be generated.
batch_size (int, default=64) – Batch size to use during training.
- Returns:
This method does not return any value. It either launches the training process or exits early if conditions are not met.
- Return type:
None
Notes
If the model has already been trained for the maximum number of epochs, training will not proceed.
Automatically handles model versioning and directory creation for saving models.
Generates a configuration file if none is provided.
Supports both local and cluster-based training execution.
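The two documented early exits can be expressed as a guard function. This is a sketch only. `should_launch_training` is hypothetical, and the epoch count it receives would come from something like get_n_epoch_trained. Versioning, config generation, and cluster submission are not modeled.

```python
def should_launch_training(n_epochs, n_epochs_trained, max_training_epochs):
    """Hypothetical guard mirroring the documented early exits of launch_training."""
    if n_epochs == 0:
        return False  # nothing was requested
    if n_epochs_trained >= max_training_epochs:
        return False  # the model already reached its epoch budget
    return True


print(should_launch_training(5, 10, 10))  # False
```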
- get_n_epoch_trained()[source]
Calculate the total number of epochs trained across all training logs.
This method searches for training log files in the model directory corresponding to the current model nickname, reads them, and sums up the number of epochs recorded in each log.
- Returns:
The total number of epochs trained.
- Return type:
int
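The epoch count can be sketched with the standard library alone. This sketch assumes the same log layout that plot_loss documents, {model_dir}/{model_nickname}*/training_log.csv. It also assumes that each CSV data row corresponds to one epoch. `count_epochs_trained` is a hypothetical name.

```python
import csv
import glob
import os


def count_epochs_trained(model_dir, model_nickname):
    """Hypothetical helper: sum data rows over all matching training logs."""
    pattern = os.path.join(model_dir, f"{model_nickname}*", "training_log.csv")
    total = 0
    for path in glob.glob(pattern):
        with open(path, newline="") as handle:
            # DictReader consumes the header, so each yielded row is one epoch
            total += sum(1 for _ in csv.DictReader(handle))
    return total
```

Because the glob uses a prefix match, a nickname such as "gamma" would also pick up directories of a model named "gamma_v2". The real method may guard against this differently.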
- plot_loss()[source]
Plot the training and validation loss over epochs.
This method reads training logs from CSV files, extracts the loss values for training and validation, and plots them against the epochs. If no training logs are found, it prints an error message and returns without plotting.
- Parameters:
None
- Return type:
None
Notes
The method assumes that training logs are stored in CSV files within directories matching the pattern {model_dir}/{model_nickname}*/training_log.csv.
The CSV files must contain columns named “loss” and “val_loss”.
If only one epoch is available, the losses are displayed as scatter points.
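The data-gathering half of plot_loss can be sketched as follows, leaving out the plotting call. The sketch assumes the logs should be concatenated in sorted path order. That assumption breaks for names like “_v10” versus “_v2”. `collect_losses` is hypothetical.

```python
import csv
import glob
import os


def collect_losses(model_dir, model_nickname):
    """Hypothetical helper: concatenate loss/val_loss across matching logs."""
    pattern = os.path.join(model_dir, f"{model_nickname}*", "training_log.csv")
    paths = sorted(glob.glob(pattern))  # assumption: sorted order == training order
    if not paths:
        print(f"No training logs found for {model_nickname}")
        return None
    loss, val_loss = [], []
    for path in paths:
        with open(path, newline="") as handle:
            for row in csv.DictReader(handle):
                loss.append(float(row["loss"]))
                val_loss.append(float(row["val_loss"]))
    epochs = list(range(1, len(loss) + 1))  # 1-based, continuous across logs
    return epochs, loss, val_loss
```

A caller would then use a scatter plot when len(epochs) == 1 and a line plot otherwise, matching the note above.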
- update_model_manager_parameters_in_index(parameters)[source]
Update the model manager parameters in the HDF5 index file.
- Parameters:
parameters (dict) – A dictionary containing the parameter names and their new values to update.
Notes
This method reads the model parameters table from the HDF5 file, updates the specified parameters, and writes the updated table back to the file.
If a parameter’s data type is a Unicode string, it is converted to a fixed-length byte-string format (‘S256’, up to 256 bytes) to accommodate long strings.
The method also updates the corresponding attributes in the instance’s __dict__.
- Raises:
KeyError – If a specified parameter key does not exist in the model table.
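The string-conversion rule can be illustrated with NumPy. In this sketch, the model table is a hypothetical stand-in: a dict that maps each column name to a one-element array. The HDF5 read/write and the __dict__ update are left out, and `update_parameters` is not part of ctlearn_manager.

```python
import numpy as np


def update_parameters(table, parameters):
    """Hypothetical helper: update one-row columns, applying the 'S256' rule."""
    for key, value in parameters.items():
        if key not in table:
            raise KeyError(f"{key!r} is not a column of the model table")
        new_column = np.asarray([value])
        if new_column.dtype.kind == "U":
            # Unicode -> fixed-length bytes; NumPy truncates values past 256 bytes
            # and raises UnicodeEncodeError for non-ASCII text
            new_column = new_column.astype("S256")
        table[key] = new_column
    return table


table = {"notes": np.array([b"old"], dtype="S256"), "max_training_epochs": np.array([10])}
update_parameters(table, {"notes": "retrained on new MC", "max_training_epochs": 20})
print(table["notes"].dtype)  # |S256
```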
- update_model_manager_testing_data(testing_data_sample)[source]
Update the testing data for the model manager with a new data sample.
- Parameters:
testing_data_sample (DataSample) – The data sample containing testing information to be added or updated. It includes the directory, zenith distance, azimuth, pattern, and particle type.
- Raises:
Exception – If there is an issue reading or writing the HDF5 file.
Notes
If the testing data for the given zenith distance and azimuth already exists, it updates the directory and pattern for that entry.
If no matching entry exists, it adds a new row to the testing data table.
The updated table is saved back to the HDF5 file.
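The update-or-append behavior can be sketched on a list of row dicts that stands in for one particle type's testing table. `upsert_testing_row` is hypothetical. The use of math.isclose for matching pointings is an assumption. Note that with this approach a NaN pointing never matches and is always appended.

```python
import math


def upsert_testing_row(rows, sample):
    """Hypothetical helper: update the row at the same pointing, else append."""
    for row in rows:
        same_pointing = (math.isclose(row["zenith_distance"], sample["zenith_distance"])
                         and math.isclose(row["azimuth"], sample["azimuth"]))
        if same_pointing:
            # Existing node: only directory and pattern change
            row["directory"] = sample["directory"]
            row["pattern"] = sample["pattern"]
            return rows
    rows.append(dict(sample))  # new node
    return rows


rows = []
upsert_testing_row(rows, {"zenith_distance": 20.0, "azimuth": 90.0, "directory": "/mc/a", "pattern": "*.h5"})
upsert_testing_row(rows, {"zenith_distance": 20.0, "azimuth": 90.0, "directory": "/mc/b", "pattern": "gamma_*.h5"})
print(len(rows), rows[0]["directory"])  # 1 /mc/b
```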
- plot_zenith_azimuth_ranges(ax=None, plot_testing_nodes=True)[source]
Plot the zenith and azimuth ranges on a polar plot.
- Parameters:
ax (matplotlib.axes._axes.Axes, optional) – The matplotlib axis to plot on. If None, a new polar plot is created.
plot_testing_nodes (bool, optional) – Whether to plot testing nodes (default is True).
Notes
The function visualizes the zenith and azimuth ranges for training and testing data.
Training data is represented with filled markers, while testing data is represented with outlined markers.
The plot includes zenith and azimuth ranges as circles or arcs, depending on the data.
The zenith range is displayed in degrees, and the azimuth range is displayed in radians.
The function handles cases where the azimuth range is not defined or contains NaN values.
- Raises:
Exception – If there is an issue reading the HDF5 tables for training or testing data.
See also
astropy.io.misc.hdf5.read_table_hdf5 – Used to read HDF5 tables.
matplotlib.pyplot – Used for plotting.
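Drawing a zenith boundary as either a full circle or an arc reduces to choosing the angular span. The sketch below assumes three things: a full circle when the azimuth range is undefined (NaN), an arc otherwise, and an arc that crosses north (for example 350° to 10°) is unwrapped. `zenith_outline` is hypothetical. It returns theta in radians and radius in degrees, which is the convention the notes above describe.

```python
import math


def zenith_outline(zenith_deg, azimuth_range_deg, n_points=100):
    """Hypothetical helper: (theta_rad, r_deg) samples of a zenith circle or arc."""
    if n_points < 2:
        raise ValueError("n_points must be at least 2")
    az_min, az_max = azimuth_range_deg
    if math.isnan(az_min) or math.isnan(az_max):
        start, stop = 0.0, 2 * math.pi  # undefined azimuth: full circle
    else:
        start, stop = math.radians(az_min), math.radians(az_max)
        if stop < start:
            stop += 2 * math.pi  # range crossing north, e.g. 350 deg -> 10 deg
    step = (stop - start) / (n_points - 1)
    thetas = [start + i * step for i in range(n_points)]
    return thetas, [zenith_deg] * n_points
```

On a matplotlib polar axis, `ax.plot(thetas, radii)` would draw the outline. The call would be repeated for the lower and upper zenith edges.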
- plot_training_nodes()[source]
Plot the training nodes for gamma and proton events in a polar coordinate system.
This method visualizes the training nodes for gamma and proton events using their zenith and azimuth angles. The plot is displayed in a polar coordinate system, with specific styling for gamma and proton events.
- Parameters:
None
Notes
Gamma training nodes are read from the HDF5 file at the path <model_nickname>/training/gamma_diffuse.
Proton training nodes are read from the HDF5 file at the path <model_nickname>/training/proton if the model parameter reco is set to “type”.
If zenith or azimuth values are undefined (NaN), those nodes are skipped.
The plot includes custom styling for gamma and proton nodes, with distinct colors and markers.
Warning
If no valid zenith or azimuth values are found for gamma or proton nodes, a message is printed to indicate that the corresponding training nodes cannot be shown.
See also
astropy.io.misc.hdf5.read_table_hdf5 – Used to read the training data tables.
matplotlib.pyplot.subplots – Used to create the polar plot.
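Converting nodes to polar coordinates and skipping undefined ones can be sketched without matplotlib. `polar_points` is hypothetical. It assumes input nodes are (zenith, azimuth) pairs in degrees. The printed message only paraphrases the warning described above and is not the library's actual text.

```python
import math


def polar_points(nodes):
    """Hypothetical helper: (zenith_deg, azimuth_deg) -> (theta_rad, r_deg), skipping NaN."""
    points = []
    for zenith, azimuth in nodes:
        if math.isnan(zenith) or math.isnan(azimuth):
            continue  # undefined pointing cannot be placed on the plot
        points.append((math.radians(azimuth), zenith))
    if not points:
        print("No valid zenith/azimuth values: training nodes cannot be shown.")
    return points


print(len(polar_points([(20.0, 90.0), (float("nan"), 0.0), (40.0, float("nan"))])))  # 1
```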
- class ctlearn_manager.model_manager.DataSample(directory, pattern, particle_type=None, zenith_distance=<Quantity nan deg>, azimuth=<Quantity nan deg>, energy_range=<Quantity [nan, nan] TeV>, nsb_range=<Quantity [nan, nan] Hz>)[source]
Bases: object
A class to represent a data sample (training or testing) for CTLearn.
- Parameters:
directory (str) – The directory where the sample's data files are stored.
pattern (str) – The pattern used to match the sample's data files.
particle_type (ParticleType | None) – Type of particle. Defaults to None.
zenith_distance (astropy.units.Quantity) – The zenith distance of the sample. Defaults to NaN degrees.
azimuth (astropy.units.Quantity) – The azimuth of the sample. Defaults to NaN degrees.
energy_range (list of astropy.units.Quantity) – The energy range of the sample. Defaults to [NaN, NaN] TeV.
nsb_range (list of astropy.units.Quantity) – The NSB (Night Sky Background) range of the sample. Defaults to [NaN, NaN] Hz.
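The NaN-default convention can be mirrored with a dataclass that uses plain floats, so it runs without astropy. `SampleSketch` and `has_pointing` are hypothetical. Units are fixed by assumption: degrees for angles, TeV for energy, Hz for NSB. The real DataSample stores astropy Quantities, and its particle_type is a ParticleType rather than a string.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class SampleSketch:
    """Hypothetical float-based mirror of DataSample; NaN marks an undefined value."""
    directory: str
    pattern: str
    particle_type: Optional[str] = None
    zenith_distance: float = math.nan  # deg
    azimuth: float = math.nan  # deg
    energy_range: tuple = (math.nan, math.nan)  # TeV
    nsb_range: tuple = (math.nan, math.nan)  # Hz

    def has_pointing(self) -> bool:
        """True when both zenith distance and azimuth are defined."""
        return not (math.isnan(self.zenith_distance) or math.isnan(self.azimuth))


print(SampleSketch("/mc/gamma", "*.h5").has_pointing())  # False
```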
- class ctlearn_manager.model_manager.ModelRangeOfValidity(model_manager)[source]
Bases: object
Class to represent the range of validity for a CTLearn model.
This class extracts and stores the ranges of zenith, azimuth, energy, and NSB values from the training gamma data of a CTLearn model. It also provides a method to check if given parameters fall within these ranges.
- Parameters:
model_manager (CTLearnModelManager) – An instance of CTLearnModelManager containing the model index file and model nickname.
- zenith_range
The range of zenith distances in the training gamma data.
- Type:
astropy.units.Quantity
- azimuth_range
The range of azimuths in the training gamma data.
- Type:
astropy.units.Quantity
- energy_range
The range of energies in the training gamma data.
- Type:
astropy.units.Quantity
- nsb_range
The range of NSB values in the training gamma data.
- Type:
astropy.units.Quantity
Initialize the instance with model parameters from the provided CTLearnModelManager.
- Parameters:
model_manager (CTLearnModelManager) – An instance of CTLearnModelManager containing the model index file and model nickname to retrieve training data.
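How the ranges might be derived from the training gamma diffuse nodes can be sketched with plain floats. This sketch assumes each range is the min/max over defined (non-NaN) values. It returns None when nothing is defined, which is consistent with the None ranges mentioned under matches. `value_range` and `span_range` are hypothetical.

```python
import math


def value_range(values):
    """Hypothetical helper: (min, max) over non-NaN values, or None if none are defined."""
    finite = [v for v in values if not math.isnan(v)]
    if not finite:
        return None
    return (min(finite), max(finite))


def span_range(pairs):
    """Hypothetical helper: merge per-sample [low, high] pairs (e.g. energy) into one."""
    lows = value_range([low for low, _ in pairs])
    highs = value_range([high for _, high in pairs])
    if lows is None or highs is None:
        return None
    return (lows[0], highs[1])


print(value_range([20.0, math.nan, 40.0, 10.0]))  # (10.0, 40.0)
```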
- matches(**kwargs)[source]
Check if the given parameters match the defined ranges.
- Parameters:
**kwargs (dict) – Keyword arguments representing the parameters to check. Supported keys are:
- “zenith”: float, the zenith angle to check.
- “azimuth”: float, the azimuth angle to check.
- “energy”: float, the energy value to check.
- “nsb”: float, the night sky background (NSB) value to check.
- Returns:
True if all provided parameters fall within their respective ranges, False otherwise.
- Return type:
bool
Notes
If a range (e.g., azimuth_range, energy_range, or nsb_range) is None, the corresponding parameter is not checked.
The zenith_range is always checked if the “zenith” key is provided.
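The rules in the notes above can be restated as a standalone function over a dict of (low, high) ranges. This is a sketch, not the method's code. The `matches` function here is hypothetical, and inclusive bounds are assumed. The sketch raises ValueError when "zenith" is given but no zenith range exists; the real behavior in that case is not documented. Azimuth ranges that wrap past 360° are not handled.

```python
def matches(ranges, **kwargs):
    """Hypothetical re-statement of ModelRangeOfValidity.matches.

    `ranges` maps "zenith", "azimuth", "energy", "nsb" to (low, high) or None.
    """
    for key in ("zenith", "azimuth", "energy", "nsb"):
        if key not in kwargs:
            continue  # parameter not provided: nothing to check
        bounds = ranges.get(key)
        if bounds is None:
            if key == "zenith":
                raise ValueError("zenith range must be defined to check 'zenith'")
            continue  # undefined range: parameter is not checked
        low, high = bounds
        if not (low <= kwargs[key] <= high):  # assumption: inclusive bounds
            return False
    return True


ranges = {"zenith": (10.0, 40.0), "azimuth": None, "energy": (0.02, 100.0), "nsb": None}
print(matches(ranges, zenith=20.0, energy=1.0))  # True
```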