# Source code for scripts.training.shared_training_infrastructure

"""Shared training utilities for TE run identity, artifacts, and registries."""

from __future__ import annotations

# Import Python Utilities
import hashlib
import os
from copy import deepcopy
from dataclasses import asdict
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

# Import PyTorch Utilities
import torch
import torch.nn as nn

# Import YAML Utilities
import yaml

# Import Project Utilities
from scripts.datasets.transmission_error_dataset import load_dataset_processing_config
from scripts.datasets.transmission_error_dataset import resolve_dataset_selection
from scripts.datasets.transmission_error_dataset import resolve_project_relative_path
from scripts.datasets.transmission_error_dataset import resolve_dataset_schema
from scripts.models.model_factory import create_model
from scripts.tooling import repository_path_support
from scripts.training.transmission_error_datamodule import NormalizationStatistics
from scripts.training.transmission_error_datamodule import TransmissionErrorDataModule
from scripts.training.transmission_error_regression_module import TransmissionErrorRegressionModule

# Static Repository Root -> Third Ancestor Directory of This File's Absolute Path
PROJECT_PATH = Path(os.path.abspath(__file__)).parents[2]
# Default Preset Loaded by `load_training_config` When No Path is Supplied
DEFAULT_CONFIG_PATH = PROJECT_PATH / "config" / "training" / "feedforward" / "presets" / "baseline.yaml"
# Runtime Defaults -> Overridden by the Optional `runtime` Section of a Training Config
DEFAULT_RUNTIME_CONFIG_DICTIONARY = {
    "accelerator": "auto",
    "devices": "auto",
    "precision": "32",
    "benchmark": True,
    "use_non_blocking_transfer": True,
}
# Upper Bound Applied When the DataLoader Worker Count is Resolved Automatically
DEFAULT_DATALOADER_WORKER_CAP = 8
# Environment Variables Read by `resolve_dataloader_num_workers` in "auto" Mode
ENVIRONMENT_DATALOADER_WORKERS_KEY = "PINNS_DATALOADER_WORKERS"
ENVIRONMENT_DATALOADER_WORKER_CAP_KEY = "PINNS_DATALOADER_WORKER_CAP"
# `strftime` Pattern Used as the Prefix of Every Run Instance Identifier
TRAINING_RUN_TIMESTAMP_FORMAT = "%Y-%m-%d-%H-%M-%S"
# Output Roots Anchored at the Static Repository Root
OUTPUT_PATH = PROJECT_PATH / "output"
VALIDATION_OUTPUT_ROOT = OUTPUT_PATH / "validation_checks"
SMOKE_TEST_OUTPUT_ROOT = OUTPUT_PATH / "smoke_tests"
REGISTRY_OUTPUT_ROOT = OUTPUT_PATH / "registries"
FAMILY_REGISTRY_OUTPUT_ROOT = REGISTRY_OUTPUT_ROOT / "families"
PROGRAM_REGISTRY_OUTPUT_ROOT = REGISTRY_OUTPUT_ROOT / "program"
# Artifact Kinds Accepted by `resolve_output_root`
RUN_OUTPUT_ARTIFACT_KIND = "training_run"
VALIDATION_OUTPUT_ARTIFACT_KIND = "validation_check"
SMOKE_TEST_OUTPUT_ARTIFACT_KIND = "smoke_test"
# Canonical Filenames Written Inside Artifact and Registry Directories
COMMON_TRAINING_CONFIG_FILENAME = "training_config.yaml"
COMMON_METRICS_FILENAME = "metrics_summary.yaml"
COMMON_VALIDATION_FILENAME = "validation_summary.yaml"
COMMON_SMOKE_TEST_FILENAME = "smoke_test_summary.yaml"
COMMON_RUN_METADATA_FILENAME = "run_metadata.yaml"
COMMON_RUN_REPORT_FILENAME = "training_test_report.md"
FAMILY_LEADERBOARD_FILENAME = "leaderboard.yaml"
FAMILY_BEST_FILENAME = "latest_family_best.yaml"
PROGRAM_BEST_FILENAME = "current_best_solution.yaml"
# Length Limits Used When Building Safe Validation-Report Filenames
SAFE_REPORT_MAX_PATH_LENGTH = 220
SAFE_REPORT_MAX_FILENAME_LENGTH = 120
SAFE_REPORT_HASH_HEX_LENGTH = 8
SAFE_REPORT_COMPACT_RUN_TOKEN_MAX_LENGTH = 48
# Training-Variant Labels Derived From the Forward / Backward Direction Flags
GLOBAL_TRAINING_VARIANT = "global"
FORWARD_ONLY_TRAINING_VARIANT = "Fw"
BACKWARD_ONLY_TRAINING_VARIANT = "Bw"
# Ranking Policy Keys -> NOTE(review): Presumably Consumed by Registry / Leaderboard Code Outside This View; Confirm
SELECTION_POLICY_DICTIONARY = {
    "primary_metric": "test_mae",
    "first_tie_breaker": "test_rmse",
    "second_tie_breaker": "val_mae",
    "third_tie_breaker": "trainable_parameter_count",
    "direction": "minimize",
}

def resolve_runtime_project_path() -> Path:

    """Return the repository root that applies to the current process.

    The current working directory is returned when it contains the
    `scripts`, `config`, and `doc` entries; otherwise the import-time
    `PROJECT_PATH` is returned.

    Returns:
        Path: Working directory or the static project root.
    """

    # Probe the Working Directory for Each Repository Marker Entry
    current_directory = Path(os.getcwd())
    for marker_entry_name in ("scripts", "config", "doc"):
        if not (current_directory / marker_entry_name).exists():
            # Any Missing Marker -> Fall Back to the Static Project Root
            return PROJECT_PATH

    return current_directory


def resolve_runtime_project_relative_path(path_value: str | Path) -> Path:

    """Resolve a path against the repository root active at call time.

    Args:
        path_value: Path handed to
            `repository_path_support.resolve_repository_path`.

    Returns:
        Path: Value returned by the repository path helper, called with
        `allow_absolute=True`.
    """

    # Resolve the Root First So the Helper Receives the Live Runtime Root
    active_repository_root = resolve_runtime_project_path()

    return repository_path_support.resolve_repository_path(
        path_value=path_value,
        repository_root=active_repository_root,
        allow_absolute=True,
    )

@dataclass
class ExperimentIdentity:

    """Logical identity of one experiment, as resolved from a training config.

    Attributes:
        model_family: Model-family name.
        model_type: Model-type name.
        run_name: Logical run name.
    """

    model_family: str
    model_type: str
    run_name: str
@dataclass
class ModelParameterSummary:

    """Parameter counts collected for one backbone.

    Attributes:
        backbone_name: Name of the backbone the counts belong to.
        trainable_parameter_count: Number of trainable parameters.
        frozen_parameter_count: Number of frozen parameters.
        total_parameter_count: Number of parameters overall.
    """

    backbone_name: str
    trainable_parameter_count: int
    frozen_parameter_count: int
    total_parameter_count: int
@dataclass
class RunArtifactIdentity:

    """Physical identity of one immutable output artifact.

    Attributes:
        artifact_kind: Output artifact kind.
        model_family: Model-family name.
        run_name: Output run name.
        run_instance_id: Run instance identifier.
        output_directory: Directory that holds the artifact.
    """

    artifact_kind: str
    model_family: str
    run_name: str
    run_instance_id: str
    output_directory: Path
def load_training_config(config_path: str | Path = DEFAULT_CONFIG_PATH) -> dict[str, Any]:

    """Read one YAML training configuration and check that it is a mapping.

    Args:
        config_path: Absolute or relative configuration path. A relative
            path is first tried against the current working directory.

    Returns:
        dict[str, Any]: Parsed training configuration dictionary.

    Raises:
        AssertionError: When the resolved path does not exist or the YAML
            document is not a dictionary.
    """

    # Build the Directly Addressable Candidate -> Absolute As-Is, Relative Against the Working Directory
    requested_config_path = Path(config_path)
    if requested_config_path.is_absolute():
        direct_config_path = requested_config_path
    else:
        direct_config_path = Path(os.path.abspath(str(requested_config_path)))

    # Fall Back to Project-Relative Resolution When the Direct Candidate is Missing
    if direct_config_path.exists():
        resolved_config_path = direct_config_path
    else:
        resolved_config_path = resolve_project_relative_path(requested_config_path)

    assert resolved_config_path.exists(), f"Training Config Path does not exist | {resolved_config_path}"

    # Parse the YAML Document
    with resolved_config_path.open("r", encoding="utf-8") as config_file:
        training_config = yaml.safe_load(config_file)

    # An Empty or Non-Mapping Document is Rejected Here
    assert isinstance(training_config, dict), "Training Config must be a dictionary"

    return training_config
def clone_training_config(training_config: dict[str, Any]) -> dict[str, Any]:

    """Return a deep copy of the training configuration."""

    return deepcopy(training_config)


def apply_dataset_override(
    training_config: dict[str, Any],
    dataset_name: str | None,
) -> dict[str, Any]:

    """Return a cloned config with an optional dataset selector applied.

    Args:
        training_config: Source training configuration (never mutated).
        dataset_name: Dataset selector, or `None` to leave the clone as-is.

    Returns:
        dict[str, Any]: Cloned configuration. When a selector is given,
        `dataset.name` holds its normalized form, and `model.input_size`
        is set to `"auto"` for the `polished_dataset` selector.
    """

    overridden_training_config = clone_training_config(training_config)

    if dataset_name is not None:

        # Imported Lazily, Exactly as the Original Block Does
        from scripts.datasets.transmission_error_dataset import normalize_dataset_name

        canonical_dataset_name = normalize_dataset_name(dataset_name)
        dataset_section = overridden_training_config.setdefault("dataset", {})
        dataset_section["name"] = canonical_dataset_name

        if canonical_dataset_name == "polished_dataset":
            model_section = overridden_training_config.setdefault("model", {})
            model_section["input_size"] = "auto"

    return overridden_training_config


def resolve_training_dataset_schema(training_config: dict[str, Any]):

    """Resolve the dataset schema that applies to a training configuration.

    Lookup order: explicit `dataset.name`, then the dataset selected by the
    processing config at `paths.dataset_config_path`, then the argument-free
    `resolve_dataset_schema()` call.
    """

    # Explicit Dataset Name Wins
    dataset_name = training_config.get("dataset", {}).get("name")
    if dataset_name is not None:
        return resolve_dataset_schema(dataset_name)

    # No Processing Config Either -> Use the Schema Resolver Without Arguments
    processing_config_path = training_config.get("paths", {}).get("dataset_config_path")
    if processing_config_path is None:
        return resolve_dataset_schema()

    # Otherwise Use the Selection Stored in the Dataset Processing Config
    processing_config = load_dataset_processing_config(processing_config_path)
    selected_dataset_name, _ = resolve_dataset_selection(processing_config)
    return resolve_dataset_schema(selected_dataset_name)


def sanitize_name(name: str) -> str:

    """Lower-case a name and replace every non-alphanumeric character with `_`.

    Surrounding whitespace and underscores are removed; an empty result
    becomes `"run"`.
    """

    replaced_character_list = [
        character.lower() if character.isalnum() else "_"
        for character in name.strip()
    ]
    return "".join(replaced_character_list).strip("_") or "run"


def build_compact_hashed_name_token(
    raw_name: str,
    maximum_token_length: int = SAFE_REPORT_COMPACT_RUN_TOKEN_MAX_LENGTH,
) -> str:

    """Shorten a long name to `<prefix>_<sha1 fragment>` deterministically.

    Args:
        raw_name: Name to sanitize and, when too long, compact.
        maximum_token_length: Length up to which the sanitized name is
            returned unchanged.

    Returns:
        str: Sanitized name, or a prefix of it joined to the first
        `SAFE_REPORT_HASH_HEX_LENGTH` hex characters of its SHA-1 digest.
    """

    token = sanitize_name(raw_name)
    if len(token) <= maximum_token_length:
        return token

    hash_fragment = hashlib.sha1(token.encode("utf-8")).hexdigest()[:SAFE_REPORT_HASH_HEX_LENGTH]

    # One Character is Reserved for the Separator; the Prefix Never Drops Below Four Characters
    shortest_prefix_length = 4
    prefix_length = maximum_token_length - len(hash_fragment) - 1
    if prefix_length < shortest_prefix_length:
        prefix_length = shortest_prefix_length

    prefix = token[:prefix_length].rstrip("_") or token[:shortest_prefix_length]

    return f"{prefix}_{hash_fragment}"


def build_safe_validation_report_filename(
    report_root: Path,
    timestamp_string: str,
    model_family: str,
    output_run_name: str,
    report_suffix: str,
) -> str:

    """Build a validation-report filename that respects the safe length limits.

    The verbose `<timestamp>_<family>_<run>_<suffix>` form is returned when
    it fits. Otherwise the family and run tokens are compacted with hashed
    tokens, first with preferred budgets and then with the minimum family
    budget.

    Raises:
        AssertionError: When even the fallback filename or its full path
            exceeds the limits.
    """

    family_name = sanitize_name(model_family)
    run_name = sanitize_name(output_run_name)
    filename_limit = min(
        SAFE_REPORT_MAX_FILENAME_LENGTH,
        SAFE_REPORT_MAX_PATH_LENGTH - len(str(report_root)) - 1,
    )

    def compose_filename(family_token: str, run_token: str) -> str:
        return f"{timestamp_string}_{family_token}_{run_token}_{report_suffix}"

    def fits_limits(filename: str) -> bool:
        return (
            len(filename) <= filename_limit
            and len(str(report_root / filename)) <= SAFE_REPORT_MAX_PATH_LENGTH
        )

    # Verbose Form First
    verbose_filename = compose_filename(family_name, run_name)
    if fits_limits(verbose_filename):
        return verbose_filename

    # Token Budgets -> Three Separators Plus Timestamp and Suffix are Fixed Cost
    token_budget = filename_limit - (len(timestamp_string) + len(report_suffix) + 3)
    family_budget_floor = 12
    run_budget_floor = 16
    family_budget = min(24, max(family_budget_floor, token_budget // 3))
    run_budget = max(run_budget_floor, token_budget - family_budget)

    compact_filename = compose_filename(
        build_compact_hashed_name_token(family_name, maximum_token_length=family_budget),
        build_compact_hashed_name_token(run_name, maximum_token_length=run_budget),
    )

    # Second Attempt With the Minimum Family Budget
    if not fits_limits(compact_filename):
        fallback_family_token = build_compact_hashed_name_token(
            family_name,
            maximum_token_length=family_budget_floor,
        )
        fallback_run_token = build_compact_hashed_name_token(
            run_name,
            maximum_token_length=max(run_budget_floor, token_budget - len(fallback_family_token)),
        )
        compact_filename = compose_filename(fallback_family_token, fallback_run_token)

    compact_report_path = report_root / compact_filename

    assert len(compact_filename) <= filename_limit, (
        "Safe validation report filename remains too long | "
        f"{compact_filename}"
    )
    assert len(str(compact_report_path)) <= SAFE_REPORT_MAX_PATH_LENGTH, (
        "Safe validation report path remains too long | "
        f"{compact_report_path}"
    )

    return compact_filename
def resolve_experiment_identity(training_config: dict[str, Any]) -> ExperimentIdentity:

    """Build the logical experiment identity from the `experiment` section.

    Args:
        training_config: Parsed training configuration dictionary.

    Returns:
        ExperimentIdentity: Lower-cased model family and model type plus the
        stripped run name. The family defaults to the model type when
        `model_family` is absent.

    Raises:
        AssertionError: When model type, model family, or run name is empty
            after stripping.
    """

    experiment_section = training_config["experiment"]

    # Model Type is Mandatory
    stripped_model_type = str(experiment_section["model_type"]).strip()
    assert stripped_model_type, "Experiment model_type must not be empty"

    # Model Family Defaults to the Model Type
    raw_model_family = experiment_section.get("model_family", stripped_model_type)
    normalized_model_family = str(raw_model_family).strip().lower()
    assert normalized_model_family, "Experiment model_family must not be empty"

    # Run Name is Mandatory and Keeps Its Original Case
    stripped_run_name = str(experiment_section["run_name"]).strip()
    assert stripped_run_name, "Experiment run_name must not be empty"

    return ExperimentIdentity(
        model_family=normalized_model_family,
        model_type=stripped_model_type.lower(),
        run_name=stripped_run_name,
    )
def resolve_runtime_config(training_config: dict[str, Any]) -> dict[str, object]:

    """Merge the config's `runtime` overrides over the repository defaults.

    `benchmark` is forced to `False` when `training.deterministic` is truthy.
    A non-dictionary `runtime` section is ignored.
    """

    runtime_override_section = training_config.get("runtime", {})
    if not isinstance(runtime_override_section, dict):
        runtime_override_section = {}

    # Defaults First, Overrides Last
    merged_runtime_config = {**DEFAULT_RUNTIME_CONFIG_DICTIONARY, **runtime_override_section}

    if bool(training_config["training"]["deterministic"]):
        merged_runtime_config["benchmark"] = False

    return merged_runtime_config


def build_run_name(training_config: dict[str, Any], run_name_suffix: str | None = None) -> str:

    """Return the logical run name, with `_<suffix>` appended when a suffix is given."""

    base_run_name = resolve_experiment_identity(training_config).run_name

    return f"{base_run_name}_{run_name_suffix}" if run_name_suffix else base_run_name


def build_run_instance_id(run_name: str) -> str:

    """Return `<local timestamp>__<sanitized run name>` for a new run instance."""

    current_timestamp = datetime.now().strftime(TRAINING_RUN_TIMESTAMP_FORMAT)
    sanitized_run_name = sanitize_name(run_name)

    return f"{current_timestamp}__{sanitized_run_name}"
def prepare_output_artifact_training_config(
    training_config: dict[str, Any],
    artifact_kind: str = RUN_OUTPUT_ARTIFACT_KIND,
    run_name_suffix: str | None = None,
    run_instance_id: str | None = None,
) -> dict[str, Any]:

    """Attach output-artifact metadata to a cloned training configuration.

    Args:
        training_config: Source training configuration (never mutated).
        artifact_kind: Artifact kind stored under `metadata.output_artifact_kind`.
        run_name_suffix: Optional suffix appended to the logical run name.
        run_instance_id: Optional explicit run instance identifier.

    Returns:
        dict[str, Any]: Cloned training configuration whose `metadata`
        section holds `output_artifact_kind`, `output_run_name`, and
        `run_instance_id`.
    """

    # Clone the Training Config
    prepared_training_config = clone_training_config(training_config)
    prepared_run_name = build_run_name(prepared_training_config, run_name_suffix)

    # A YAML `metadata:` Key Left Empty Loads as None -> Treat It Like a Missing Section
    metadata_dictionary = prepared_training_config.get("metadata")
    if metadata_dictionary is None:
        metadata_dictionary = {}
        prepared_training_config["metadata"] = metadata_dictionary

    existing_artifact_kind = str(metadata_dictionary.get("output_artifact_kind", "")).strip()
    existing_output_run_name = str(metadata_dictionary.get("output_run_name", "")).strip()
    existing_run_instance_id = str(metadata_dictionary.get("run_instance_id", "")).strip()

    # Preserve Existing Artifact Identity When the Config is Already Prepared
    preserve_existing_identity = (
        run_instance_id is None
        and run_name_suffix is None
        and bool(existing_run_instance_id)
        and existing_output_run_name == prepared_run_name
        and existing_artifact_kind in ["", artifact_kind]
    )

    if preserve_existing_identity:
        resolved_run_instance_id = existing_run_instance_id
    else:
        resolved_run_instance_id = run_instance_id or build_run_instance_id(prepared_run_name)

    # Persist Output Artifact Identity Inside Training Metadata
    metadata_dictionary["output_artifact_kind"] = artifact_kind
    metadata_dictionary["output_run_name"] = prepared_run_name
    metadata_dictionary["run_instance_id"] = resolved_run_instance_id

    return prepared_training_config
def resolve_output_artifact_kind(training_config: dict[str, Any]) -> str:

    """Return `metadata.output_artifact_kind`, defaulting to the training-run kind."""

    metadata_section = training_config.get("metadata", {})
    if not isinstance(metadata_section, dict):
        return RUN_OUTPUT_ARTIFACT_KIND

    # A Blank Value Falls Back to the Default Kind
    stored_artifact_kind = str(metadata_section.get("output_artifact_kind", RUN_OUTPUT_ARTIFACT_KIND)).strip()
    return stored_artifact_kind or RUN_OUTPUT_ARTIFACT_KIND


def resolve_output_run_name(training_config: dict[str, Any]) -> str:

    """Return `metadata.output_run_name`, or the logical run name when it is unset."""

    metadata_section = training_config.get("metadata", {})
    stored_run_name = ""
    if isinstance(metadata_section, dict):
        stored_run_name = str(metadata_section.get("output_run_name", "")).strip()

    return stored_run_name if stored_run_name else build_run_name(training_config)


def resolve_run_instance_id(training_config: dict[str, Any]) -> str:

    """Return the stripped `metadata.run_instance_id`.

    Raises:
        AssertionError: When `metadata` is not a dictionary or the
            identifier is empty.
    """

    metadata_section = training_config.get("metadata", {})
    assert isinstance(metadata_section, dict), "Training metadata dictionary is required to resolve run_instance_id"

    stored_run_instance_id = str(metadata_section.get("run_instance_id", "")).strip()
    assert stored_run_instance_id, "Training metadata must contain a non-empty run_instance_id"

    return stored_run_instance_id


def resolve_output_root(training_config: dict[str, Any]) -> Path:

    """Return the root directory under which the artifact folder is created.

    Training runs use `paths.output_root`; validation checks and smoke
    tests use `output/<kind folder>/<model family>` below the runtime
    project path.

    Raises:
        ValueError: For an artifact kind other than the three supported ones.
    """

    experiment_identity = resolve_experiment_identity(training_config)
    artifact_kind = resolve_output_artifact_kind(training_config)

    # Training Runs -> Root Comes From the Training Config
    if artifact_kind == RUN_OUTPUT_ARTIFACT_KIND:
        return resolve_runtime_project_relative_path(training_config["paths"]["output_root"])

    # Validation and Smoke-Test Artifacts -> Fixed Folders Grouped by Model Family
    folder_name_by_artifact_kind = {
        VALIDATION_OUTPUT_ARTIFACT_KIND: "validation_checks",
        SMOKE_TEST_OUTPUT_ARTIFACT_KIND: "smoke_tests",
    }
    if artifact_kind not in folder_name_by_artifact_kind:
        raise ValueError(f"Unsupported output_artifact_kind | {artifact_kind}")

    kind_folder_name = folder_name_by_artifact_kind[artifact_kind]
    return resolve_runtime_project_path() / "output" / kind_folder_name / experiment_identity.model_family


def resolve_output_directory(training_config: dict[str, Any]) -> Path:

    """Return `<output root>/<run instance id>` for a prepared configuration."""

    artifact_output_root = resolve_output_root(training_config)
    artifact_run_instance_id = resolve_run_instance_id(training_config)

    return artifact_output_root / artifact_run_instance_id
def resolve_run_artifact_identity(training_config: dict[str, Any]) -> RunArtifactIdentity:

    """Collect the physical artifact identity of one prepared configuration.

    Args:
        training_config: Training configuration whose `metadata` section
            already holds the output-artifact fields.

    Returns:
        RunArtifactIdentity: Artifact kind, model family, output run name,
        run instance identifier, and output directory.
    """

    # Resolve Every Component Through Its Dedicated Resolver
    experiment_identity = resolve_experiment_identity(training_config)
    artifact_output_directory = resolve_output_directory(training_config)
    artifact_kind = resolve_output_artifact_kind(training_config)
    artifact_run_name = resolve_output_run_name(training_config)
    artifact_run_instance_id = resolve_run_instance_id(training_config)

    return RunArtifactIdentity(
        artifact_kind=artifact_kind,
        model_family=experiment_identity.model_family,
        run_name=artifact_run_name,
        run_instance_id=artifact_run_instance_id,
        output_directory=artifact_output_directory,
    )
def resolve_training_variant_details(training_config: dict[str, Any]) -> dict[str, Any]:

    """Resolve directional-variant metadata for one training configuration.

    Direction flags come from `metadata` when written there; any missing
    flag is read from the `directions` section of the dataset config at
    `paths.dataset_config_path`, defaulting to `True`.

    Returns:
        dict[str, Any]: `base_model_family`, `training_variant`,
        `direction_scope_label`, `use_forward_direction`, and
        `use_backward_direction`.

    Raises:
        AssertionError: When the dataset config or its `directions` section
            is not a dictionary.
    """

    metadata_dictionary = training_config.get("metadata", {})
    experiment_identity = resolve_experiment_identity(training_config)

    explicit_base_model_family = ""
    explicit_training_variant = ""
    use_forward_direction = None
    use_backward_direction = None

    # Read Explicit Values From Metadata -> Blank and None Count as "Not Written"
    if isinstance(metadata_dictionary, dict):
        explicit_base_model_family = str(metadata_dictionary.get("base_model_family", "")).strip().lower()
        explicit_training_variant = str(metadata_dictionary.get("training_variant", "")).strip()

        raw_forward_direction = metadata_dictionary.get("use_forward_direction")
        if raw_forward_direction not in [None, ""]:
            use_forward_direction = bool(raw_forward_direction)

        raw_backward_direction = metadata_dictionary.get("use_backward_direction")
        if raw_backward_direction not in [None, ""]:
            use_backward_direction = bool(raw_backward_direction)

    # Fall Back to the Dataset Config When Direction Flags Were Not Written Explicitly
    if use_forward_direction is None or use_backward_direction is None:
        dataset_config_path = resolve_runtime_project_relative_path(
            training_config["paths"]["dataset_config_path"]
        )
        with dataset_config_path.open("r", encoding="utf-8") as dataset_config_file:
            dataset_config_dictionary = yaml.safe_load(dataset_config_file)

        assert isinstance(dataset_config_dictionary, dict), (
            f"Dataset config must contain a dictionary | {dataset_config_path}"
        )
        direction_dictionary = dataset_config_dictionary.get("directions", {})
        assert isinstance(direction_dictionary, dict), (
            f"Dataset config directions must contain a dictionary | {dataset_config_path}"
        )

        if use_forward_direction is None:
            use_forward_direction = bool(direction_dictionary.get("use_forward_direction", True))
        if use_backward_direction is None:
            use_backward_direction = bool(direction_dictionary.get("use_backward_direction", True))

    # Both Labels Depend Only on the Two Flags -> One Lookup Table Serves Both
    label_by_direction_flags = {
        (True, True): (GLOBAL_TRAINING_VARIANT, "bidirectional"),
        (True, False): (FORWARD_ONLY_TRAINING_VARIANT, "forward_only"),
        (False, True): (BACKWARD_ONLY_TRAINING_VARIANT, "backward_only"),
        (False, False): ("custom", "custom"),
    }
    derived_training_variant, direction_scope_label = label_by_direction_flags[
        (bool(use_forward_direction), bool(use_backward_direction))
    ]

    return {
        "base_model_family": explicit_base_model_family or experiment_identity.model_family,
        "training_variant": explicit_training_variant or derived_training_variant,
        "direction_scope_label": direction_scope_label,
        "use_forward_direction": bool(use_forward_direction),
        "use_backward_direction": bool(use_backward_direction),
    }


def load_run_metadata_snapshot(output_directory: Path) -> dict[str, Any] | None:

    """Load one run-metadata snapshot when it exists.

    Args:
        output_directory: Candidate immutable artifact directory.

    Returns:
        dict[str, Any] | None: Parsed run metadata dictionary, or `None`
        when the snapshot is absent or malformed (invalid YAML, undecodable
        text, or a document that is not a dictionary).
    """

    # Resolve Run-Metadata Path
    run_metadata_path = output_directory / COMMON_RUN_METADATA_FILENAME
    if not run_metadata_path.exists():
        return None

    # Load Run Metadata -> A Corrupted Snapshot Must Not Abort a Directory Scan
    try:
        with run_metadata_path.open("r", encoding="utf-8") as metadata_file:
            run_metadata_dictionary = yaml.safe_load(metadata_file)
    except (yaml.YAMLError, UnicodeDecodeError):
        return None

    if not isinstance(run_metadata_dictionary, dict):
        return None

    return run_metadata_dictionary


def find_run_output_directory(
    model_family: str,
    run_name: str | None = None,
    run_instance_id: str | None = None,
) -> Path | None:

    """Find one immutable run output directory from canonical run metadata.

    Args:
        model_family: Model-family folder under `output/training_runs/`.
        run_name: Optional logical run name to match.
        run_instance_id: Optional immutable run instance identifier to match.

    Returns:
        Path | None: Resolved directory whose metadata matches
        `run_instance_id`; otherwise the last (in sorted path order)
        directory matching `run_name`; otherwise `None`.
    """

    # Resolve Candidate Family Root
    family_output_root = OUTPUT_PATH / "training_runs" / str(model_family).strip().lower()
    if not family_output_root.exists():
        return None

    normalized_run_name = str(run_name).strip() if run_name not in [None, ""] else ""
    normalized_run_instance_id = str(run_instance_id).strip() if run_instance_id not in [None, ""] else ""
    run_name_match_path_list: list[Path] = []

    # Scan Immutable Run Directories for Matching Metadata
    for candidate_output_directory in sorted(path for path in family_output_root.iterdir() if path.is_dir()):
        run_metadata_dictionary = load_run_metadata_snapshot(candidate_output_directory)
        if run_metadata_dictionary is None:
            continue

        metadata_run_instance_id = str(run_metadata_dictionary.get("run_instance_id", "")).strip()
        metadata_run_name = str(run_metadata_dictionary.get("run_name", "")).strip()

        # An Instance-Id Match is Unique -> Stop Scanning Immediately
        if normalized_run_instance_id and metadata_run_instance_id == normalized_run_instance_id:
            return candidate_output_directory.resolve()

        if normalized_run_name and metadata_run_name == normalized_run_name:
            run_name_match_path_list.append(candidate_output_directory.resolve())

    if run_name_match_path_list:
        return sorted(run_name_match_path_list)[-1]

    return None


def summarize_model_parameters(regression_backbone: nn.Module) -> ModelParameterSummary:

    """Count trainable, frozen, and total parameters for one backbone."""

    # Single Pass Over the Parameters
    trainable_parameter_count = 0
    total_parameter_count = 0
    for parameter in regression_backbone.parameters():
        parameter_element_count = parameter.numel()
        total_parameter_count += parameter_element_count
        if parameter.requires_grad:
            trainable_parameter_count += parameter_element_count

    return ModelParameterSummary(
        backbone_name=regression_backbone.__class__.__name__,
        trainable_parameter_count=int(trainable_parameter_count),
        frozen_parameter_count=int(total_parameter_count - trainable_parameter_count),
        total_parameter_count=int(total_parameter_count),
    )


def parse_non_negative_integer(value: Any, value_name: str) -> int:

    """Parse one non-negative integer configuration value.

    Raises:
        ValueError: When `int(value)` fails or the result is negative.
    """

    try:
        parsed_value = int(value)
    except (TypeError, ValueError) as parse_error:
        raise ValueError(f"{value_name} must be a non-negative integer | {value}") from parse_error

    if parsed_value < 0:
        raise ValueError(f"{value_name} must be non-negative | {parsed_value}")

    return parsed_value


def resolve_dataloader_num_workers(raw_num_workers: Any) -> int:

    """Resolve explicit or automatic PyTorch DataLoader worker count.

    `None` means zero workers. Any value other than the string `"auto"` is
    parsed as a non-negative integer. In `"auto"` mode the worker-count
    environment variable wins; otherwise the result is the logical CPU
    count minus one, capped by the worker-cap environment variable or
    `DEFAULT_DATALOADER_WORKER_CAP`.

    Raises:
        ValueError: When an explicit or environment value is not a
            non-negative integer.
    """

    # Preserve Explicit Integer Settings
    if raw_num_workers is None:
        raw_num_workers = 0
    if not isinstance(raw_num_workers, str):
        return parse_non_negative_integer(raw_num_workers, "dataset.num_workers")

    normalized_num_workers = raw_num_workers.strip().lower()
    if normalized_num_workers != "auto":
        return parse_non_negative_integer(normalized_num_workers, "dataset.num_workers")

    # Prefer Operator-Provided Runtime Override
    environment_worker_value = os.environ.get(ENVIRONMENT_DATALOADER_WORKERS_KEY)
    if environment_worker_value not in [None, ""]:
        return parse_non_negative_integer(environment_worker_value, ENVIRONMENT_DATALOADER_WORKERS_KEY)

    # Resolve Capped Automatic Worker Count
    environment_cap_value = os.environ.get(ENVIRONMENT_DATALOADER_WORKER_CAP_KEY)
    if environment_cap_value not in [None, ""]:
        worker_cap = parse_non_negative_integer(environment_cap_value, ENVIRONMENT_DATALOADER_WORKER_CAP_KEY)
    else:
        worker_cap = DEFAULT_DATALOADER_WORKER_CAP
    if worker_cap <= 0:
        return 0

    # Leave One Logical CPU Free
    logical_cpu_count = os.cpu_count() or 1
    cpu_based_worker_count = max(logical_cpu_count - 1, 0)
    if cpu_based_worker_count <= 0:
        return 0

    return max(1, min(cpu_based_worker_count, worker_cap))


def parse_boolean_config_value(value: Any, value_name: str) -> bool:

    """Parse one boolean configuration value from YAML or CLI-like strings.

    Accepts real booleans, the case-insensitive strings
    `true/1/yes/y/on` and `false/0/no/n/off`, and values equal to 0 or 1.

    Raises:
        ValueError: For any other value.
    """

    if isinstance(value, bool):
        return value

    if isinstance(value, str):
        normalized_value = value.strip().lower()
        if normalized_value in ["true", "1", "yes", "y", "on"]:
            return True
        if normalized_value in ["false", "0", "no", "n", "off"]:
            return False

    if value in [0, 1]:
        return bool(value)

    raise ValueError(f"{value_name} must be a boolean value | {value}")
def create_datamodule_from_training_config(training_config: dict[str, Any]) -> TransmissionErrorDataModule:

    """Build the `TransmissionErrorDataModule` described by a training config.

    Args:
        training_config: Parsed training configuration dictionary. The
            `paths` and `dataset` sections are read, along with the resolved
            runtime configuration.

    Returns:
        TransmissionErrorDataModule: Newly constructed datamodule.
    """

    # Runtime Settings Supply the Non-Blocking Transfer Flag
    runtime_config = resolve_runtime_config(training_config)

    # Read the Sections in the Same Order the Constructor Arguments Need Them
    dataset_config_path = training_config["paths"]["dataset_config_path"]
    dataset_name = training_config.get("dataset", {}).get("name")
    dataset_section = training_config["dataset"]

    # Collect Constructor Arguments -> Optional Keys Carry Their Defaults Here
    datamodule_argument_dictionary = {
        "dataset_config_path": dataset_config_path,
        "dataset_name": dataset_name,
        "curve_batch_size": int(dataset_section["curve_batch_size"]),
        "point_stride": int(dataset_section["point_stride"]),
        "maximum_points_per_curve": dataset_section["maximum_points_per_curve"],
        "collate_mode": str(dataset_section.get("collate_mode", "point")),
        "sequence_length": int(dataset_section.get("sequence_length", 17)),
        "sequence_stride": int(dataset_section.get("sequence_stride", 1)),
        "sequence_target_position": str(dataset_section.get("sequence_target_position", "center")),
        "maximum_sequences_per_curve": dataset_section.get("maximum_sequences_per_curve"),
        "shuffle_training_batch_elements": bool(dataset_section.get("shuffle_training_batch_elements", True)),
        "num_workers": resolve_dataloader_num_workers(dataset_section.get("num_workers", 0)),
        "pin_memory": parse_boolean_config_value(dataset_section.get("pin_memory", False), "dataset.pin_memory"),
        "use_non_blocking_transfer": bool(runtime_config["use_non_blocking_transfer"]),
    }

    return TransmissionErrorDataModule(**datamodule_argument_dictionary)
def resolve_model_configuration_for_input_dim(training_config: dict[str, Any], input_feature_dim: int) -> dict[str, Any]:

    """Return a model configuration with a concrete dataset input dimension.

    Args:
        training_config: Parsed training configuration dictionary.
        input_feature_dim: Input dimension resolved from the prepared dataset.

    Returns:
        dict[str, Any]: Deep copy of the `model` section whose `input_size`
        is always an integer: the dataset dimension for a missing, `None`,
        or `"auto"` setting (case-insensitive, whitespace ignored), or the
        validated explicit value otherwise.

    Raises:
        AssertionError: When an explicit `input_size` differs from
            `input_feature_dim`.
        ValueError: When an explicit `input_size` cannot be parsed as an integer.
    """

    # Work on a Copy So the Caller's Config is Never Mutated
    model_configuration = deepcopy(training_config["model"])
    configured_input_size = model_configuration.get("input_size", "auto")

    # Automatic Mode -> Missing, None, or Any Spelling of "auto"
    is_automatic_input_size = configured_input_size is None or (
        isinstance(configured_input_size, str)
        and configured_input_size.strip().lower() == "auto"
    )
    if is_automatic_input_size:
        model_configuration["input_size"] = input_feature_dim
        return model_configuration

    # Explicit Mode -> Validate Against the Dataset
    configured_input_size = int(configured_input_size)
    assert configured_input_size == input_feature_dim, (
        f"Configured Input Size and Dataset Input Feature Dim mismatch | {configured_input_size} vs {input_feature_dim}"
    )

    # Store the Normalized Integer So a Value Such as "8" Does Not Reach the Model Factory as a String
    model_configuration["input_size"] = configured_input_size

    return model_configuration
def create_regression_backbone_from_training_config(training_config: dict[str, Any], input_feature_dim: int) -> nn.Module:

    """Create the regression backbone named by `experiment.model_type`.

    Args:
        training_config: Parsed training configuration dictionary.
        input_feature_dim: Input dimension resolved from the prepared dataset.

    Returns:
        nn.Module: Backbone returned by `create_model`.
    """

    # Model Section With a Concrete Input Size
    resolved_model_configuration = resolve_model_configuration_for_input_dim(training_config, input_feature_dim)

    # The Model Factory Receives the Type Name as a String
    model_type_name = str(training_config["experiment"]["model_type"])

    return create_model(
        model_type=model_type_name,
        model_configuration=resolved_model_configuration,
    )
def create_regression_module_from_training_config(
    training_config: dict[str, Any],
    regression_backbone: nn.Module,
    input_feature_dim: int,
    target_feature_dim: int,
    normalization_statistics: NormalizationStatistics,
) -> TransmissionErrorRegressionModule:

    """Wrap a backbone in `TransmissionErrorRegressionModule`.

    Learning rate, weight decay, and the optional `loss` configuration are
    read from the `training` section of the config.
    """

    # Optimization Hyperparameters Live Under `training`
    training_section = training_config["training"]
    learning_rate = float(training_section["learning_rate"])
    weight_decay = float(training_section["weight_decay"])
    loss_configuration = training_section.get("loss", {})

    return TransmissionErrorRegressionModule(
        regression_model=regression_backbone,
        input_feature_dim=input_feature_dim,
        target_feature_dim=target_feature_dim,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        normalization_statistics=normalization_statistics,
        loss_configuration=loss_configuration,
    )
def initialize_training_components(
    training_config: dict[str, Any],
) -> tuple[TransmissionErrorDataModule, nn.Module, TransmissionErrorRegressionModule, NormalizationStatistics]:

    """Create the datamodule, backbone, regression module, and statistics.

    Args:
        training_config: Parsed training configuration dictionary.

    Returns:
        tuple: `(datamodule, regression_backbone, regression_module,
        normalization_statistics)`, with the datamodule already set up for
        the `"fit"` stage.
    """

    # DataModule First -> Its Setup Provides Dimensions and Statistics
    datamodule = create_datamodule_from_training_config(training_config)
    datamodule.setup(stage="fit")

    resolved_input_feature_dim = datamodule.get_input_feature_dim()
    resolved_target_feature_dim = datamodule.get_target_feature_dim()
    normalization_statistics = datamodule.get_normalization_statistics()

    # Backbone Sized From the Dataset Input Dimension
    regression_backbone = create_regression_backbone_from_training_config(
        training_config=training_config,
        input_feature_dim=resolved_input_feature_dim,
    )

    # Regression Module Shares the Dataset Normalization Statistics
    regression_module = create_regression_module_from_training_config(
        training_config=training_config,
        regression_backbone=regression_backbone,
        input_feature_dim=resolved_input_feature_dim,
        target_feature_dim=resolved_target_feature_dim,
        normalization_statistics=normalization_statistics,
    )

    return datamodule, regression_backbone, regression_module, normalization_statistics
def fetch_first_batch(datamodule: TransmissionErrorDataModule, split_name: str = "train") -> dict[str, Any]:

    """Return the first batch produced by the dataloader of one split.

    Args:
        datamodule: Datamodule that provides the dataloader factories.
        split_name: One of `"train"`, `"validation"`, or `"test"`.

    Raises:
        ValueError: For any other split name.
    """

    # Split Name -> Name of the Dataloader Factory Method
    dataloader_factory_name_by_split = (
        ("train", "train_dataloader"),
        ("validation", "val_dataloader"),
        ("test", "test_dataloader"),
    )

    for supported_split_name, dataloader_factory_name in dataloader_factory_name_by_split:
        if split_name == supported_split_name:
            dataloader = getattr(datamodule, dataloader_factory_name)()
            return next(iter(dataloader))

    raise ValueError(f"Unsupported split_name | {split_name}")
def validate_batch_dictionary(batch_dictionary: dict[str, Any], input_feature_dim: int, target_feature_dim: int) -> dict[str, Any]:

    """Check the tensor layout of one collated batch and summarize it.

    Args:
        batch_dictionary: Batch holding `input_tensor` and `target_tensor`,
            plus an optional `curve_count`.
        input_feature_dim: Expected size of the last input dimension.
        target_feature_dim: Expected size of the last target dimension.

    Returns:
        dict[str, Any]: Batch mode (`"point"` for rank-2 input, `"sequence"`
        for rank-3 input), batch sizes, sequence length, feature dimensions,
        and curve count.

    Raises:
        AssertionError: When a tensor has the wrong type, rank, feature
            size, or the leading dimensions differ.
    """

    input_tensor = batch_dictionary["input_tensor"]
    target_tensor = batch_dictionary["target_tensor"]

    # Type Checks
    assert isinstance(input_tensor, torch.Tensor), "input_tensor must be a torch.Tensor"
    assert isinstance(target_tensor, torch.Tensor), "target_tensor must be a torch.Tensor"

    # Rank Checks
    assert input_tensor.ndim in [2, 3], f"input_tensor must be rank-2 or rank-3 | {tuple(input_tensor.shape)}"
    assert target_tensor.ndim == 2, f"target_tensor must be rank-2 | {tuple(target_tensor.shape)}"

    # Feature and Batch Dimension Checks
    assert input_tensor.shape[-1] == input_feature_dim, (f"Input feature mismatch | {input_tensor.shape[-1]} vs {input_feature_dim}")
    assert target_tensor.shape[-1] == target_feature_dim, (f"Target feature mismatch | {target_tensor.shape[-1]} vs {target_feature_dim}")
    assert input_tensor.shape[0] == target_tensor.shape[0], (
        f"Input and Target batch dimensions must match | {input_tensor.shape[0]} vs {target_tensor.shape[0]}"
    )

    # Rank-3 Input Means Sequence Mode -> Only the Active Mode Reports Non-Zero Sizes
    is_sequence_batch = input_tensor.ndim == 3
    leading_batch_size = int(input_tensor.shape[0])

    batch_summary: dict[str, Any] = {
        "batch_mode": "sequence" if is_sequence_batch else "point",
        "point_batch_size": 0 if is_sequence_batch else leading_batch_size,
        "sequence_batch_size": leading_batch_size if is_sequence_batch else 0,
        "sequence_length": int(input_tensor.shape[1]) if is_sequence_batch else 0,
        "input_feature_dim": int(input_tensor.shape[-1]),
        "target_feature_dim": int(target_tensor.shape[-1]),
        "curve_count": int(batch_dictionary.get("curve_count", 0)),
    }

    return batch_summary
def serialize_metric_dictionary(metric_dictionary: dict[str, object]) -> dict[str, object]:
    """Convert metric values into plain Python scalars suitable for YAML/JSON.

    Args:
        metric_dictionary: Mapping of metric names to raw metric values.

    Returns:
        dict[str, object]: New mapping with the same keys in the same order.
        Tensors become ``float`` via ``.item()``, ``float`` and ``int``
        instances are normalized to the builtin type, and every other value is
        replaced by its ``str()`` representation.
    """

    def _to_plain_scalar(raw_value: object) -> object:
        # Tensors First: detach and move to CPU before extracting the scalar
        if isinstance(raw_value, torch.Tensor):
            return float(raw_value.detach().cpu().item())

        # Builtin Numeric Types (float is tested before int, as in the checks below)
        if isinstance(raw_value, float):
            return float(raw_value)
        if isinstance(raw_value, int):
            return int(raw_value)

        # Anything Else is Stored as Text
        return str(raw_value)

    return {metric_name: _to_plain_scalar(metric_value) for metric_name, metric_value in metric_dictionary.items()}


def build_comparison_payload(
    experiment_identity: ExperimentIdentity,
    parameter_summary: ModelParameterSummary,
    validation_metric_dictionary: dict[str, object],
    test_metric_dictionary: dict[str, object],
) -> dict[str, object]:
    """Assemble the flat payload used to compare runs against each other.

    Args:
        experiment_identity: Provides ``model_family``, ``model_type`` and ``run_name``.
        parameter_summary: Provides ``backbone_name`` and both parameter counts.
        validation_metric_dictionary: Source of ``val_mae`` and ``val_rmse``.
        test_metric_dictionary: Source of ``test_mae`` and ``test_rmse``.

    Returns:
        dict[str, object]: Identity fields, parameter counts, the four metrics
        (``None`` for any metric that is absent), and two empty note fields.
    """

    # Identity and Model Size Fields
    comparison_payload: dict[str, object] = {
        "model_family": experiment_identity.model_family,
        "model_type": experiment_identity.model_type,
        "run_name": experiment_identity.run_name,
        "backbone_name": parameter_summary.backbone_name,
        "trainable_parameter_count": parameter_summary.trainable_parameter_count,
        "total_parameter_count": parameter_summary.total_parameter_count,
    }

    # Metric Fields: missing metrics are recorded as None
    for metric_name in ("val_mae", "val_rmse"):
        comparison_payload[metric_name] = validation_metric_dictionary.get(metric_name)
    for metric_name in ("test_mae", "test_rmse"):
        comparison_payload[metric_name] = test_metric_dictionary.get(metric_name)

    # Free-Text Note Fields Start Empty
    comparison_payload["deployment_notes"] = ""
    comparison_payload["interpretability_notes"] = ""
    return comparison_payload
def build_common_metrics_snapshot(
    training_config: dict[str, Any],
    config_path: str | Path,
    output_directory: Path,
    datamodule: TransmissionErrorDataModule,
    parameter_summary: ModelParameterSummary,
    runtime_config: dict[str, object],
    best_model_path: str,
    validation_metric_list: list[dict[str, object]],
    test_metric_list: list[dict[str, object]],
) -> dict[str, object]:
    """Build the canonical metrics snapshot stored with a training artifact.

    Args:
        training_config: Effective training configuration; used to resolve the
            experiment identity, the run artifact identity, and the training
            variant details.
        config_path: Path of the configuration file, stored in project-relative form.
        output_directory: Artifact folder recorded under ``artifacts``.
        datamodule: Source of the dataset split summary and normalization statistics.
        parameter_summary: Model parameter summary, stored via ``asdict``.
        runtime_config: Runtime settings; ``Path`` values are stored as strings.
        best_model_path: Recorded unchanged as ``best_checkpoint_path``.
        validation_metric_list: Only the first dictionary is used; an empty
            list yields empty validation metrics.
        test_metric_list: Only the first dictionary is used; an empty list
            yields empty test metrics.

    Returns:
        dict[str, object]: Snapshot with ``schema_version`` 1 and the sections
        ``experiment``, ``artifacts``, ``dataset_split``, ``dataset``,
        ``model_summary``, ``runtime_config``, ``normalization_statistics``,
        ``validation_metrics``, ``test_metrics`` and ``comparison_payload``.
    """

    def _as_float_list(statistic_values: Any) -> list[float]:
        # Flatten a statistics container exposing tolist() into builtin floats
        return [float(value) for value in statistic_values.tolist()]

    # Resolve Identities, Variant Details, and Datamodule Summaries
    experiment_identity = resolve_experiment_identity(training_config)
    run_artifact_identity = resolve_run_artifact_identity(training_config)
    training_variant_details = resolve_training_variant_details(training_config)
    split_summary = datamodule.get_dataset_split_summary()
    statistics = datamodule.get_normalization_statistics()

    # Only the First Metric Dictionary of Each List is Kept
    first_validation_metrics = validation_metric_list[0] if validation_metric_list else {}
    validation_metrics = serialize_metric_dictionary(first_validation_metrics)
    first_test_metrics = test_metric_list[0] if test_metric_list else {}
    test_metrics = serialize_metric_dictionary(first_test_metrics)

    # Assemble the Snapshot Section by Section (key order is the stored order)
    snapshot: dict[str, object] = {"schema_version": 1}
    snapshot["config_path"] = str(resolve_project_relative_path(config_path))
    snapshot["experiment"] = {
        **asdict(experiment_identity),
        "output_run_name": run_artifact_identity.run_name,
        "run_instance_id": run_artifact_identity.run_instance_id,
        "output_artifact_kind": run_artifact_identity.artifact_kind,
        **training_variant_details,
    }
    snapshot["artifacts"] = {
        "output_directory": str(output_directory),
        "best_checkpoint_path": best_model_path,
    }
    snapshot["dataset_split"] = asdict(split_summary)
    snapshot["dataset"] = {
        "dataset_id": split_summary.dataset_name,
        "dataset_schema": split_summary.dataset_schema,
        "input_feature_names": list(split_summary.input_feature_name_list),
        "target_feature_names": list(split_summary.target_feature_name_list),
        "input_feature_dim": split_summary.input_feature_dim,
        "target_feature_dim": split_summary.target_feature_dim,
    }
    snapshot["model_summary"] = asdict(parameter_summary)

    # Path Objects are Not YAML-Safe, so They are Stored as Strings
    snapshot["runtime_config"] = {
        setting_name: str(setting_value) if isinstance(setting_value, Path) else setting_value
        for setting_name, setting_value in runtime_config.items()
    }
    snapshot["normalization_statistics"] = {
        "input_feature_mean": _as_float_list(statistics.input_feature_mean),
        "input_feature_std": _as_float_list(statistics.input_feature_std),
        "target_mean": _as_float_list(statistics.target_mean),
        "target_std": _as_float_list(statistics.target_std),
    }
    snapshot["validation_metrics"] = validation_metrics
    snapshot["test_metrics"] = test_metrics
    snapshot["comparison_payload"] = build_comparison_payload(
        experiment_identity,
        parameter_summary,
        validation_metrics,
        test_metrics,
    )
    return snapshot
def format_project_relative_path(path_value: str | Path | None) -> str:
    """Render a path relative to the project root for reports and registries.

    Args:
        path_value: Path to render. ``None`` and strings containing
            ``"Best checkpoint not available"`` are treated as placeholders.

    Returns:
        str: ``"N/A"`` for ``None``; the unchanged input for the checkpoint
        placeholder; otherwise the repository-relative rendering against the
        runtime project path or ``PROJECT_PATH`` (first one that does not raise
        ``ValueError``), falling back to the POSIX form of the path as given.
    """

    # Placeholder Inputs are Returned Without Any Path Handling
    if path_value is None:
        return "N/A"
    is_checkpoint_placeholder = isinstance(path_value, str) and "Best checkpoint not available" in path_value
    if is_checkpoint_placeholder:
        return path_value

    target_path = Path(path_value)

    # Try Each Candidate Root in Order; a ValueError Means "Not Under This Root"
    candidate_root_list = (resolve_runtime_project_path(), PROJECT_PATH)
    for project_root in candidate_root_list:
        try:
            relative_text = repository_path_support.format_repository_relative_path(
                target_path,
                project_root,
            )
        except ValueError:
            continue
        return relative_text

    # No Root Matched: Fall Back to the Path as Given, in POSIX Form
    return target_path.as_posix()


def load_yaml_snapshot(input_path: Path) -> dict[str, Any]:
    """Load a YAML snapshot and validate that it is dictionary-shaped.

    Args:
        input_path: UTF-8 encoded YAML file to read.

    Returns:
        dict[str, Any]: Parsed top-level mapping.

    Raises:
        AssertionError: If the parsed document is not a dictionary.
    """

    # Parse with the Safe Loader so Only Plain YAML Types are Constructed
    snapshot_text = input_path.read_text(encoding="utf-8")
    parsed_snapshot = yaml.safe_load(snapshot_text)

    # Reject Empty Files, Lists, and Scalars
    assert isinstance(parsed_snapshot, dict), f"YAML snapshot must contain a dictionary | {input_path}"
    return parsed_snapshot
def save_yaml_snapshot(snapshot_dictionary: dict[str, Any], output_path: Path) -> None:
    """Persist one YAML snapshot to disk, creating parent folders as needed.

    The snapshot is serialized to text before the target file is opened, so a
    serialization failure leaves any existing file at ``output_path`` intact.

    Args:
        snapshot_dictionary: Mapping to serialize; key order is preserved.
        output_path: Destination file, written as UTF-8 and overwritten if present.

    Raises:
        yaml.YAMLError: If the mapping contains a value the safe dumper cannot
            represent; nothing is written in that case.
    """

    # Serialize First: opening the file in "w" mode truncates it, so an error
    # raised mid-dump would otherwise wipe a previously valid snapshot
    serialized_snapshot = yaml.safe_dump(snapshot_dictionary, sort_keys=False)

    # Write the Fully Serialized Text to the Specified Output Path
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as output_file:
        output_file.write(serialized_snapshot)
def save_training_config_snapshot(training_config: dict[str, Any], output_directory: Path) -> None:
    """Persist the effective training configuration inside an artifact folder.

    Args:
        training_config: Configuration mapping to store.
        output_directory: Artifact folder; created (with parents) when missing.
            The file name is ``COMMON_TRAINING_CONFIG_FILENAME``.
    """

    # Make Sure the Artifact Folder Exists Before Writing Into It
    output_directory.mkdir(parents=True, exist_ok=True)

    # Write the Configuration Next to the Other Run Artifacts
    config_snapshot_path = output_directory / COMMON_TRAINING_CONFIG_FILENAME
    save_yaml_snapshot(training_config, config_snapshot_path)
def save_common_metrics_snapshot(metrics_snapshot_dictionary: dict[str, Any], output_directory: Path) -> None:
    """Persist the common metrics snapshot inside an artifact folder.

    Args:
        metrics_snapshot_dictionary: Metrics snapshot mapping to store.
        output_directory: Artifact folder that receives the file named
            ``COMMON_METRICS_FILENAME``.
    """

    # Write the Metrics Snapshot Under its Canonical File Name
    metrics_snapshot_path = output_directory / COMMON_METRICS_FILENAME
    save_yaml_snapshot(metrics_snapshot_dictionary, metrics_snapshot_path)
def build_registry_entry(metrics_snapshot_dictionary: dict[str, Any]) -> dict[str, Any]:
    """Convert a metrics snapshot into a comparable registry entry.

    Args:
        metrics_snapshot_dictionary: Snapshot providing the ``experiment``,
            ``comparison_payload``, ``model_summary`` and ``artifacts``
            sections. ``dataset`` is optional; when it is missing or empty the
            dataset fields are rebuilt from ``dataset_split`` if that is a dict.

    Returns:
        dict[str, Any]: Flat entry with run identity, dataset description,
        training variant flags, parameter counts, the four comparison metrics,
        project-relative artifact paths, a copy of the selection policy, and a
        ``selected_at`` timestamp taken from ``datetime.now()``.

    Raises:
        KeyError: If a required section or required identity field is missing.
    """

    # Extract the Sections the Entry is Built From
    experiment_dictionary = metrics_snapshot_dictionary["experiment"]
    comparison_payload = metrics_snapshot_dictionary["comparison_payload"]
    model_summary_dictionary = metrics_snapshot_dictionary["model_summary"]
    artifacts_dictionary = metrics_snapshot_dictionary["artifacts"]
    dataset_dictionary = metrics_snapshot_dictionary.get("dataset", {})
    dataset_split_dictionary = metrics_snapshot_dictionary.get("dataset_split", {})

    # Snapshots Without a "dataset" Section Fall Back to the Split Summary Fields
    if not dataset_dictionary and isinstance(dataset_split_dictionary, dict):
        dataset_dictionary = {
            "dataset_id": dataset_split_dictionary.get("dataset_name"),
            "dataset_schema": dataset_split_dictionary.get("dataset_schema"),
            "input_feature_names": dataset_split_dictionary.get("input_feature_name_list", []),
            "target_feature_names": dataset_split_dictionary.get("target_feature_name_list", []),
            "input_feature_dim": dataset_split_dictionary.get("input_feature_dim"),
            "target_feature_dim": dataset_split_dictionary.get("target_feature_dim"),
        }

    # Format the Output Directory Once; the Metrics and Report Paths Derive From It
    raw_output_directory = artifacts_dictionary.get("output_directory")
    formatted_output_directory = format_project_relative_path(raw_output_directory)
    has_output_directory = raw_output_directory not in [None, ""]

    return {
        "run_instance_id": experiment_dictionary["run_instance_id"],
        "run_name": experiment_dictionary["run_name"],
        "output_run_name": experiment_dictionary["output_run_name"],
        "output_artifact_kind": experiment_dictionary["output_artifact_kind"],
        "model_family": comparison_payload["model_family"],
        "base_model_family": experiment_dictionary.get("base_model_family", comparison_payload["model_family"]),
        "model_type": comparison_payload["model_type"],
        "dataset_id": dataset_dictionary.get("dataset_id", "simplified_dataset"),
        "dataset_schema": dataset_dictionary.get("dataset_schema", "simplified_curve_v1"),
        "input_feature_names": dataset_dictionary.get("input_feature_names", []),
        "target_feature_names": dataset_dictionary.get("target_feature_names", []),
        "input_feature_dim": dataset_dictionary.get("input_feature_dim"),
        "training_variant": experiment_dictionary.get("training_variant", GLOBAL_TRAINING_VARIANT),
        "direction_scope_label": experiment_dictionary.get("direction_scope_label", "bidirectional"),
        "use_forward_direction": bool(experiment_dictionary.get("use_forward_direction", True)),
        "use_backward_direction": bool(experiment_dictionary.get("use_backward_direction", True)),
        "trainable_parameter_count": model_summary_dictionary["trainable_parameter_count"],
        "total_parameter_count": model_summary_dictionary["total_parameter_count"],
        "val_mae": comparison_payload.get("val_mae"),
        "val_rmse": comparison_payload.get("val_rmse"),
        "test_mae": comparison_payload.get("test_mae"),
        "test_rmse": comparison_payload.get("test_rmse"),
        "output_directory": formatted_output_directory,
        "best_checkpoint_path": format_project_relative_path(artifacts_dictionary.get("best_checkpoint_path")),
        "metrics_path": (
            f"{formatted_output_directory}/{COMMON_METRICS_FILENAME}"
            if has_output_directory
            else "N/A"
        ),
        "report_path": (
            f"{formatted_output_directory}/{COMMON_RUN_REPORT_FILENAME}"
            if has_output_directory
            else "N/A"
        ),
        "selection_policy": dict(SELECTION_POLICY_DICTIONARY),
        "selected_at": datetime.now().isoformat(timespec="seconds"),
    }


def resolve_selection_value(metric_value: object) -> float:
    """Map a metric value onto a float that is safe to rank by (lower is better).

    Args:
        metric_value: Raw metric value taken from a registry entry.

    Returns:
        float: The value as ``float`` when it is an ``int`` or ``float`` other
        than NaN; ``float("inf")`` for NaN, ``None``, strings, and any other type.
    """

    import math

    if isinstance(metric_value, (int, float)):
        numeric_value = float(metric_value)

        # NaN compares False against everything, which would make the ordering
        # undefined and let a diverged run displace a valid best entry; rank it
        # last together with missing metrics instead
        if not math.isnan(numeric_value):
            return numeric_value

    return float("inf")


def build_selection_key(registry_entry: dict[str, Any]) -> tuple[float, float, float, float, str]:
    """Build the sort key that ranks registry entries (smaller key wins).

    Args:
        registry_entry: Entry whose metrics are read with ``.get``.

    Returns:
        tuple[float, float, float, float, str]: ``test_mae``, ``test_rmse``,
        ``val_mae`` and ``trainable_parameter_count`` as selection values,
        followed by ``run_instance_id`` as a string tie-breaker.
    """

    # Earlier Fields Dominate; Later Fields Only Break Ties
    return (
        resolve_selection_value(registry_entry.get("test_mae")),
        resolve_selection_value(registry_entry.get("test_rmse")),
        resolve_selection_value(registry_entry.get("val_mae")),
        resolve_selection_value(registry_entry.get("trainable_parameter_count")),
        str(registry_entry.get("run_instance_id", "")),
    )


def sort_registry_entries(registry_entry_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return a new list of registry entries ordered best-first.

    Args:
        registry_entry_list: Entries to rank; the input list is not modified.

    Returns:
        list[dict[str, Any]]: Entries sorted ascending by ``build_selection_key``.
    """

    return sorted(registry_entry_list, key=build_selection_key)


def load_registry_entry_list(leaderboard_path: Path) -> list[dict[str, Any]]:
    """Load the entries stored in a leaderboard file.

    Args:
        leaderboard_path: Leaderboard YAML file; it may not exist yet.

    Returns:
        list[dict[str, Any]]: Dictionary-shaped items of ``entry_list``; an
        empty list when the file does not exist or has no ``entry_list`` key.

    Raises:
        AssertionError: If ``entry_list`` is present but is not a list.
    """

    # A Missing Leaderboard Simply Means No Entries Were Recorded Yet
    if not leaderboard_path.exists():
        return []

    leaderboard_dictionary = load_yaml_snapshot(leaderboard_path)
    registry_entry_list = leaderboard_dictionary.get("entry_list", [])
    assert isinstance(registry_entry_list, list), f"Registry entry_list must be a list | {leaderboard_path}"

    # Drop Any Item That is Not Dictionary-Shaped
    return [registry_entry for registry_entry in registry_entry_list if isinstance(registry_entry, dict)]


def build_family_registry_directory(model_family: str) -> Path:
    """Return the resolved registry folder of one model family.

    Args:
        model_family: Family name used as the folder name.

    Returns:
        Path: ``FAMILY_REGISTRY_OUTPUT_ROOT / model_family`` after ``resolve()``.
    """

    return (FAMILY_REGISTRY_OUTPUT_ROOT / model_family).resolve()
def update_family_registry(metrics_snapshot_dictionary: dict[str, Any]) -> dict[str, Any]:
    """Update the family leaderboard and latest-family-best snapshots.

    Args:
        metrics_snapshot_dictionary: Common metrics snapshot for one completed
            training artifact.

    Returns:
        dict[str, Any]: Selected best entry for the model family after update.
    """

    # Derive the New Entry and the Registry Files of its Model Family
    new_entry = build_registry_entry(metrics_snapshot_dictionary)
    family_name = new_entry["model_family"]
    registry_directory = build_family_registry_directory(family_name)
    leaderboard_path = registry_directory / FAMILY_LEADERBOARD_FILENAME
    best_entry_path = registry_directory / FAMILY_BEST_FILENAME

    # Keep Every Stored Entry Except an Older Record of the Same Run Instance
    merged_entry_list = []
    for stored_entry in load_registry_entry_list(leaderboard_path):
        if stored_entry.get("run_instance_id") == new_entry["run_instance_id"]:
            continue
        merged_entry_list.append(stored_entry)
    merged_entry_list.append(new_entry)

    # Rank All Entries; the First One is the Family Best
    ranked_entry_list = sort_registry_entries(merged_entry_list)
    family_best_entry = ranked_entry_list[0]

    # Leaderboard Snapshot: Full Ranked List
    leaderboard_snapshot = {
        "schema_version": 1,
        "model_family": family_name,
        "selection_policy": dict(SELECTION_POLICY_DICTIONARY),
        "updated_at": datetime.now().isoformat(timespec="seconds"),
        "entry_count": len(ranked_entry_list),
        "entry_list": ranked_entry_list,
    }

    # Best-Entry Snapshot: Only the Top-Ranked Entry
    best_entry_snapshot = {
        "schema_version": 1,
        "model_family": family_name,
        "selection_policy": dict(SELECTION_POLICY_DICTIONARY),
        "updated_at": datetime.now().isoformat(timespec="seconds"),
        "best_entry": family_best_entry,
    }

    # Persist Both Snapshots in the Family Registry Directory
    save_yaml_snapshot(leaderboard_snapshot, leaderboard_path)
    save_yaml_snapshot(best_entry_snapshot, best_entry_path)
    return family_best_entry
def update_program_registry(best_registry_entry: dict[str, Any]) -> dict[str, Any]:
    """Update the program-wide best-solution registry entry.

    Args:
        best_registry_entry: Candidate entry to compare against the stored
            program-wide best.

    Returns:
        dict[str, Any]: The entry written as ``best_entry``: the stored entry
        when its selection key is less than or equal to the candidate's,
        otherwise the candidate.
    """

    registry_path = PROGRAM_REGISTRY_OUTPUT_ROOT / PROGRAM_BEST_FILENAME

    # Load the Stored Best Entry, Ignoring Files Without a Dictionary-Shaped "best_entry"
    incumbent_entry = None
    if registry_path.exists():
        stored_entry = load_yaml_snapshot(registry_path).get("best_entry")
        if isinstance(stored_entry, dict):
            incumbent_entry = stored_entry

    # The Incumbent Wins Ties, so an Equal Candidate Does Not Replace It
    incumbent_wins = isinstance(incumbent_entry, dict) and (
        build_selection_key(incumbent_entry) <= build_selection_key(best_registry_entry)
    )
    winning_entry = incumbent_entry if incumbent_wins else best_registry_entry

    # The File is Rewritten on Every Call, so "updated_at" Always Advances
    save_yaml_snapshot(
        {
            "schema_version": 1,
            "selection_policy": dict(SELECTION_POLICY_DICTIONARY),
            "updated_at": datetime.now().isoformat(timespec="seconds"),
            "best_entry": winning_entry,
        },
        registry_path,
    )
    return winning_entry
def save_run_metadata_snapshot(training_config: dict[str, Any], output_directory: Path) -> None:
    """Persist the resolved artifact identity inside an output directory.

    Args:
        training_config: Training configuration used to resolve the run
            artifact identity, the training variant details, and the dataset schema.
        output_directory: Folder that receives the file named
            ``COMMON_RUN_METADATA_FILENAME``.
    """

    # Resolve Everything the Metadata File Records
    artifact_identity = resolve_run_artifact_identity(training_config)
    variant_details = resolve_training_variant_details(training_config)
    schema = resolve_training_dataset_schema(training_config)

    # Identity Fields First, Then the Variant Details (later keys win on a name clash)
    run_metadata: dict[str, Any] = {
        "schema_version": 1,
        "artifact_kind": artifact_identity.artifact_kind,
        "model_family": artifact_identity.model_family,
    }
    run_metadata.update(variant_details)

    # Run Identity and Dataset Description
    run_metadata.update(
        {
            "run_name": artifact_identity.run_name,
            "run_instance_id": artifact_identity.run_instance_id,
            "output_directory": format_project_relative_path(artifact_identity.output_directory),
            "dataset_id": schema.dataset_name,
            "dataset_schema": schema.schema_name,
            "input_feature_names": list(schema.input_feature_name_list),
            "target_feature_names": list(schema.target_feature_name_list),
            "input_feature_dim": schema.input_feature_dim,
        }
    )

    metadata_path = output_directory / COMMON_RUN_METADATA_FILENAME
    save_yaml_snapshot(run_metadata, metadata_path)