"""Feedforward TE training entry point and terminal-reporting helpers."""
from __future__ import annotations
# Import Python Utilities
import sys, shutil, logging, warnings, argparse
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
# Define Project Path
PROJECT_PATH = Path(__file__).resolve().parents[2]
# Ensure Repository Root Is Available For Direct Script Execution
if str(PROJECT_PATH) not in sys.path: sys.path.insert(0, str(PROJECT_PATH))
# Import PyTorch Lightning Utilities
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
# Import PyTorch Utilities
import torch
# Import Terminal Formatting Utilities
try:
from colorama import Fore
from colorama import Style
from colorama import init as colorama_init
except ImportError:
class _PlainTerminalColor:
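""" Plain No-Color Fallback For Colorama Fore And Style Attributes """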
BLACK = ""
BLUE = ""
CYAN = ""
GREEN = ""
MAGENTA = ""
RED = ""
RESET = ""
WHITE = ""
YELLOW = ""
BRIGHT = ""
NORMAL = ""
RESET_ALL = ""
Fore = _PlainTerminalColor()
Style = _PlainTerminalColor()
def colorama_init(*args, **kwargs) -> None:
""" Fallback Colorama Init """
return None
# Import Project Models And Training Utilities
from scripts.models.model_factory import create_model
from scripts.training import shared_training_infrastructure
from scripts.training.transmission_error_datamodule import TransmissionErrorDataModule
from scripts.training.transmission_error_regression_module import TransmissionErrorRegressionModule
DEFAULT_CONFIG_PATH = shared_training_infrastructure.DEFAULT_CONFIG_PATH
SECTION_DIVIDER_WIDTH = 96
KEY_LABEL_WIDTH = 34
PROGRESS_BAR_REFRESH_RATE = 10
LIGHTNING_INFO_LOGGER_NAME_LIST = ["lightning.pytorch.utilities.rank_zero", "lightning.fabric.utilities.rank_zero"]
INPUT_FEATURE_NAME_LIST = ["angular_position_deg", "input_speed_rpm", "input_torque_nm", "oil_temperature_deg", "direction_flag"]
TARGET_FEATURE_NAME_LIST = ["transmission_error_deg"]
# Set Torch Matmul Precision
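# "high" Permits TF32 Tensor-Core Matmuls On Supported NVIDIA GPUs, Trading A Small Amount Of Float32 Precision For Throughput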
torch.set_float32_matmul_precision("high")
# Initialize Terminal Colors
colorama_init(autoreset=True)
# Suppress Known Lightning Internal Warning
warnings.filterwarnings("ignore", message=r"`isinstance\(treespec, LeafSpec\)` is deprecated.*", category=FutureWarning, module=r"lightning\.pytorch\.utilities\._pytree")
def format_terminal_value(value: object) -> str:
"""Format one value for compact terminal rendering.
Args:
value: Arbitrary runtime value to render for the terminal summary.
Returns:
String representation aligned with the training terminal style.
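Example:
Illustrative outputs, computed from the formatting rules below:
>>> format_terminal_value(0.123456789)
'0.123457'
>>> format_terminal_value(1.0e-6)
'1.000000e-06'
>>> format_terminal_value([1, None])
'[1, None]'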
"""
if isinstance(value, float):
# Format Float With 6 Decimal Places If Not Too Small, Otherwise Use Scientific Notation
if abs(value) >= 1.0e-4: return f"{value:.6f}"
return f"{value:.6e}"
if isinstance(value, Path):
return str(value)
if isinstance(value, (list, tuple)):
# Recursively Format Each Item In The List Or Tuple
formatted_value_list = [format_terminal_value(item) for item in value]
return "[" + ", ".join(formatted_value_list) + "]"
if value is None:
return "None"
return str(value)
def print_section_header(section_title: str) -> None:
""" Print Section Header """
print()
print(Fore.CYAN + Style.BRIGHT + "=" * SECTION_DIVIDER_WIDTH)
print(Fore.CYAN + Style.BRIGHT + section_title)
print(Fore.CYAN + Style.BRIGHT + "=" * SECTION_DIVIDER_WIDTH)
def print_subsection_header(subsection_title: str) -> None:
""" Print Subsection Header """
print()
print(Fore.MAGENTA + Style.BRIGHT + subsection_title)
print(Fore.MAGENTA + "-" * len(subsection_title))
def print_key_value(label: str, value: object, value_color: str = Fore.WHITE) -> None:
""" Print Key Value """
formatted_label = f"{label:<{KEY_LABEL_WIDTH}}"
formatted_value = format_terminal_value(value)
print(f"{Fore.WHITE}{Style.BRIGHT}{formatted_label}{Style.RESET_ALL}{value_color}{formatted_value}{Style.RESET_ALL}")
def print_info_message(message: str) -> None:
""" Print Info Message """
print(f"{Fore.BLUE}{Style.BRIGHT}[INFO]{Style.RESET_ALL} {message}")
def print_success_message(message: str) -> None:
""" Print Success Message """
print(f"{Fore.GREEN}{Style.BRIGHT}[DONE]{Style.RESET_ALL} {message}")
def print_warning_message(message: str) -> None:
""" Print Warning Message """
print(f"{Fore.YELLOW}{Style.BRIGHT}[WARN]{Style.RESET_ALL} {message}")
@contextmanager
def suppress_lightning_info_logs() -> Iterator[None]:
""" Suppress Lightning Info Logs """
# Store Current Logger Levels
logger_state_list: list[tuple[logging.Logger, int]] = []
# Set Lightning Loggers To Warning Level To Suppress Info Logs
for logger_name in LIGHTNING_INFO_LOGGER_NAME_LIST:
lightning_logger = logging.getLogger(logger_name)
logger_state_list.append((lightning_logger, lightning_logger.level))
lightning_logger.setLevel(logging.WARNING)
try:
yield
finally:
# Restore Previous Logger Levels
for lightning_logger, previous_log_level in logger_state_list:
lightning_logger.setLevel(previous_log_level)
def print_feature_statistics(feature_name_list: list[str], mean_value_list: list[float], std_value_list: list[float]) -> None:
""" Print Feature Statistics """
for feature_name, feature_mean, feature_std in zip(feature_name_list, mean_value_list, std_value_list):
print_key_value(label=f"{feature_name} | mean", value=feature_mean, value_color=Fore.YELLOW)
print_key_value(label=f"{feature_name} | std", value=feature_std, value_color=Fore.YELLOW)
def print_training_configuration_summary(training_config: dict) -> None:
"""Print the resolved training configuration in the repository terminal style.
Args:
training_config: Fully resolved training configuration dictionary for
the current run.
"""
# Read Config Sections
path_config = training_config["paths"]
experiment_config = training_config["experiment"]
dataset_config = training_config["dataset"]
model_config = training_config["model"]
optimization_config = training_config["training"]
runtime_config = resolve_runtime_config(training_config)
model_type_display_name = str(experiment_config["model_type"]).replace("_", " ").title()
# Print Config Overview
print_section_header(f"{model_type_display_name} Training Configuration")
print_info_message("Resolved YAML configuration for the current training run")
# Print Path Configuration
print_subsection_header("Paths")
print_key_value("Dataset Config Path", path_config["dataset_config_path"], value_color=Fore.YELLOW)
print_key_value("Output Root", path_config["output_root"], value_color=Fore.YELLOW)
# Print Experiment Configuration
print_subsection_header("Experiment")
print_key_value("Run Name", experiment_config["run_name"], value_color=Fore.YELLOW)
print_key_value("Model Family", experiment_config.get("model_family", experiment_config["model_type"]), value_color=Fore.YELLOW)
print_key_value("Model Type", experiment_config["model_type"], value_color=Fore.YELLOW)
# Print Dataset Configuration
print_subsection_header("Dataset")
print_key_value("Curve Batch Size", dataset_config["curve_batch_size"], value_color=Fore.YELLOW)
print_key_value("Point Stride", dataset_config["point_stride"], value_color=Fore.YELLOW)
print_key_value("Maximum Points Per Curve", dataset_config["maximum_points_per_curve"], value_color=Fore.YELLOW)
print_key_value("Num Workers", dataset_config["num_workers"], value_color=Fore.YELLOW)
print_key_value("Pin Memory", dataset_config["pin_memory"], value_color=Fore.YELLOW)
# Print Model Configuration
print_subsection_header("Model")
print_key_value("Input Size", model_config["input_size"], value_color=Fore.YELLOW)
print_key_value("Output Size", model_config.get("output_size", 1), value_color=Fore.YELLOW)
# Resolve Model-Specific Configuration Fields
normalized_model_type = str(experiment_config["model_type"]).strip().lower()
# Print Feedforward Network Configuration
if normalized_model_type in ["feedforward", "periodic_mlp"]:
print_key_value("Hidden Layers", model_config["hidden_size"], value_color=Fore.YELLOW)
print_key_value("Activation", model_config["activation_name"], value_color=Fore.YELLOW)
print_key_value("Dropout Probability", model_config["dropout_probability"], value_color=Fore.YELLOW)
print_key_value("Use Layer Norm", model_config["use_layer_norm"], value_color=Fore.YELLOW)
# Print Periodic MLP Configuration
if normalized_model_type == "periodic_mlp":
print_key_value("Harmonic Order", model_config["harmonic_order"], value_color=Fore.YELLOW)
print_key_value("Include Raw Angle Feature", model_config.get("include_raw_angle_feature", True), value_color=Fore.YELLOW)
# Print Harmonic Regression Configuration
elif normalized_model_type == "harmonic_regression":
print_key_value("Harmonic Order", model_config["harmonic_order"], value_color=Fore.YELLOW)
print_key_value("Coefficient Mode", model_config.get("coefficient_mode", "static"), value_color=Fore.YELLOW)
# Print Residual Harmonic Regression Configuration
elif normalized_model_type == "residual_harmonic_mlp":
print_key_value("Harmonic Order", model_config["harmonic_order"], value_color=Fore.YELLOW)
print_key_value("Coefficient Mode", model_config.get("coefficient_mode", "static"), value_color=Fore.YELLOW)
print_key_value("Residual Hidden Layers", model_config["residual_hidden_size"], value_color=Fore.YELLOW)
print_key_value("Residual Activation", model_config.get("residual_activation_name", "GELU"), value_color=Fore.YELLOW)
print_key_value("Residual Dropout Probability", model_config.get("residual_dropout_probability", 0.10), value_color=Fore.YELLOW)
print_key_value("Residual Use Layer Norm", model_config.get("residual_use_layer_norm", True), value_color=Fore.YELLOW)
print_key_value("Freeze Structured Branch", model_config.get("freeze_structured_branch", False), value_color=Fore.YELLOW)
# Print Model Config Keys
else: print_key_value("Model Config Keys", sorted(model_config.keys()), value_color=Fore.YELLOW)
# Print Optimization Configuration
print_subsection_header("Optimization")
print_key_value("Learning Rate", optimization_config["learning_rate"], value_color=Fore.YELLOW)
print_key_value("Weight Decay", optimization_config["weight_decay"], value_color=Fore.YELLOW)
print_key_value("Min Epochs", optimization_config["min_epochs"], value_color=Fore.YELLOW)
print_key_value("Max Epochs", optimization_config["max_epochs"], value_color=Fore.YELLOW)
print_key_value("Patience", optimization_config["patience"], value_color=Fore.YELLOW)
print_key_value("Min Delta", optimization_config["min_delta"], value_color=Fore.YELLOW)
print_key_value("Log Every N Steps", optimization_config["log_every_n_steps"], value_color=Fore.YELLOW)
print_key_value("Fast Dev Run", optimization_config["fast_dev_run"], value_color=Fore.YELLOW)
print_key_value("Deterministic", optimization_config["deterministic"], value_color=Fore.YELLOW)
# Print Runtime Configuration
print_subsection_header("Runtime")
print_key_value("Accelerator", runtime_config["accelerator"], value_color=Fore.YELLOW)
print_key_value("Devices", runtime_config["devices"], value_color=Fore.YELLOW)
print_key_value("Precision", runtime_config["precision"], value_color=Fore.YELLOW)
print_key_value("Benchmark", runtime_config["benchmark"], value_color=Fore.YELLOW)
print_key_value("Non-Blocking Transfer", runtime_config["use_non_blocking_transfer"], value_color=Fore.YELLOW)
def print_dataset_summary(datamodule: TransmissionErrorDataModule, input_feature_dim: int, target_feature_dim: int) -> None:
""" Print Dataset Summary """
# Get Dataset Split Summary From The DataModule
dataset_split_summary = datamodule.get_dataset_split_summary()
# Print Dataset Summary Statistics
print_section_header("Dataset Summary")
print_key_value("Input Feature Dim", input_feature_dim, value_color=Fore.YELLOW)
print_key_value("Target Feature Dim", target_feature_dim, value_color=Fore.YELLOW)
print_key_value("Train Curves", dataset_split_summary.train_curve_count, value_color=Fore.YELLOW)
print_key_value("Validation Curves", dataset_split_summary.validation_curve_count, value_color=Fore.YELLOW)
print_key_value("Test Curves", dataset_split_summary.test_curve_count, value_color=Fore.YELLOW)
print_key_value("Point Stride", datamodule.point_stride, value_color=Fore.YELLOW)
print_key_value("Maximum Points Per Curve", datamodule.maximum_points_per_curve, value_color=Fore.YELLOW)
print_key_value("Persistent Workers", datamodule.num_workers > 0, value_color=Fore.YELLOW)
def print_model_summary(regression_backbone: torch.nn.Module) -> None:
""" Print Model Summary """
# Compute Parameter Counts
trainable_parameter_count = sum(parameter.numel() for parameter in regression_backbone.parameters() if parameter.requires_grad)
total_parameter_count = sum(parameter.numel() for parameter in regression_backbone.parameters())
frozen_parameter_count = total_parameter_count - trainable_parameter_count
# Print Compact Model Summary
print_section_header("Model Summary")
print_key_value("Backbone", regression_backbone.__class__.__name__, value_color=Fore.YELLOW)
print_key_value("Trainable Parameters", trainable_parameter_count, value_color=Fore.YELLOW)
print_key_value("Frozen Parameters", frozen_parameter_count, value_color=Fore.YELLOW)
print_key_value("Total Parameters", total_parameter_count, value_color=Fore.YELLOW)
def print_normalization_statistics_summary(normalization_statistics) -> None:
""" Print Normalization Statistics Summary """
print_section_header("Normalization Statistics")
# Print Input Normalization Statistics
print_subsection_header("Input Features")
print_feature_statistics(
feature_name_list=INPUT_FEATURE_NAME_LIST,
mean_value_list=normalization_statistics.input_feature_mean.tolist(),
std_value_list=normalization_statistics.input_feature_std.tolist(),
)
# Print Target Normalization Statistics
print_subsection_header("Target")
print_feature_statistics(
feature_name_list=TARGET_FEATURE_NAME_LIST,
mean_value_list=normalization_statistics.target_mean.tolist(),
std_value_list=normalization_statistics.target_std.tolist(),
)
def print_runtime_summary(runtime_config: dict[str, object]) -> None:
""" Print Runtime Summary """
# Print Runtime Configuration Summary
print_section_header("Runtime Summary")
print_key_value("Configured Accelerator", runtime_config["accelerator"], value_color=Fore.YELLOW)
print_key_value("Configured Devices", runtime_config["devices"], value_color=Fore.YELLOW)
print_key_value("Configured Precision", runtime_config["precision"], value_color=Fore.YELLOW)
print_key_value("cuDNN Benchmark", runtime_config["benchmark"], value_color=Fore.YELLOW)
print_key_value("Non-Blocking Transfer", runtime_config["use_non_blocking_transfer"], value_color=Fore.YELLOW)
print_key_value("CUDA Available", torch.cuda.is_available(), value_color=Fore.YELLOW)
print_key_value("CUDA Device Count", torch.cuda.device_count(), value_color=Fore.YELLOW)
# Print CUDA Device Name If Available, Otherwise Print Warning About CPU Training
if torch.cuda.is_available(): print_key_value("CUDA Device Name", torch.cuda.get_device_name(0), value_color=Fore.YELLOW)
else: print_warning_message("CUDA is not available -> training will run on CPU")
def print_output_artifact_summary(output_directory: Path, logger: TensorBoardLogger, best_model_path: str) -> None:
""" Print Output Artifact Summary """
# Print Output Artifact Summary
print_section_header("Output Artifacts")
print_key_value("Output Directory", output_directory, value_color=Fore.YELLOW)
print_key_value("Checkpoint Directory", output_directory / "checkpoints", value_color=Fore.YELLOW)
print_key_value("Config Snapshot", output_directory / shared_training_infrastructure.COMMON_TRAINING_CONFIG_FILENAME, value_color=Fore.YELLOW)
print_key_value("Metrics Snapshot", output_directory / shared_training_infrastructure.COMMON_METRICS_FILENAME, value_color=Fore.YELLOW)
print_key_value("Run Metadata", output_directory / shared_training_infrastructure.COMMON_RUN_METADATA_FILENAME, value_color=Fore.YELLOW)
print_key_value("Run Report", output_directory / shared_training_infrastructure.COMMON_RUN_REPORT_FILENAME, value_color=Fore.YELLOW)
print_key_value(
"Family Best Registry",
shared_training_infrastructure.build_family_registry_directory(output_directory.parent.name) / shared_training_infrastructure.FAMILY_BEST_FILENAME,
value_color=Fore.YELLOW,
)
print_key_value(
"Program Best Registry",
shared_training_infrastructure.PROGRAM_REGISTRY_OUTPUT_ROOT / shared_training_infrastructure.PROGRAM_BEST_FILENAME,
value_color=Fore.YELLOW,
)
# Print Logger Artifact Summary
if logger.log_dir: print_key_value("TensorBoard Log Directory", logger.log_dir, value_color=Fore.YELLOW)
print_key_value("Best Checkpoint", best_model_path, value_color=Fore.YELLOW)
def build_metric_interpretation(metric_dictionary: dict[str, object], metric_prefix: str) -> str:
""" Build Metric Interpretation """
# Extract MAE And RMSE Values From The Metric Dictionary For The Given Prefix
metric_mae = metric_dictionary.get(f"{metric_prefix}_mae")
metric_rmse = metric_dictionary.get(f"{metric_prefix}_rmse")
# Interpret The Metrics Based On Their Values
if isinstance(metric_mae, (int, float)) and isinstance(metric_rmse, (int, float)):
return (
f"The held-out {metric_prefix} error stayed finite with MAE={metric_mae:.6f} deg and "
f"RMSE={metric_rmse:.6f} deg, which indicates a numerically stable baseline run."
)
return f"The held-out {metric_prefix} metrics were not fully available in serialized form, so only raw metric files should be trusted."
def save_training_test_report(output_directory: Path, training_config: dict, metrics_snapshot_dictionary: dict[str, object]) -> None:
"""Write the Markdown training and test summary for one completed run.
Args:
output_directory: Run artifact directory where the report will be saved.
training_config: Resolved training configuration used for the run.
metrics_snapshot_dictionary: Serialized metrics payload returned by the
shared training infrastructure.
"""
# Extract Relevant Information From The Metrics Snapshot Dictionary To Build The Report
experiment_dictionary = metrics_snapshot_dictionary["experiment"]
dataset_split_dictionary = metrics_snapshot_dictionary["dataset_split"]
validation_metric_dictionary = metrics_snapshot_dictionary["validation_metrics"]
test_metric_dictionary = metrics_snapshot_dictionary["test_metrics"]
model_family_display_name = str(experiment_dictionary["model_family"]).replace("_", " ").title()
# Build The Report As A List Of Lines To Be Joined With Newlines For Output
report_line_list = [
f"# {model_family_display_name} Training And Testing Report",
"",
"## Overview",
"",
f"- Run Name: `{experiment_dictionary['run_name']}`",
f"- Model Family: `{experiment_dictionary['model_family']}`",
f"- Model Type: `{experiment_dictionary['model_type']}`",
f"- Best Checkpoint: `{metrics_snapshot_dictionary['artifacts']['best_checkpoint_path']}`",
"",
"## Dataset Split",
"",
f"- Train Curves: `{dataset_split_dictionary['train_curve_count']}`",
f"- Validation Curves: `{dataset_split_dictionary['validation_curve_count']}`",
f"- Test Curves: `{dataset_split_dictionary['test_curve_count']}`",
"",
"## Validation Metrics",
"",
]
# Add Each Validation Metric To The Report
for metric_name, metric_value in validation_metric_dictionary.items():
report_line_list.append(f"- {metric_name}: `{format_terminal_value(metric_value)}`")
# Add Test Metrics Section Header
report_line_list.extend([
"",
"## Test Metrics",
"",
])
# Add Each Test Metric Line To The Report
for metric_name, metric_value in test_metric_dictionary.items():
report_line_list.append(f"- {metric_name}: `{format_terminal_value(metric_value)}`")
# Add Interpretation Of The Metrics To The Report
report_line_list.extend([
"",
"## Interpretation",
"",
build_metric_interpretation(validation_metric_dictionary, "val"),
build_metric_interpretation(test_metric_dictionary, "test"),
])
# Save The Report To The Output Directory As A Markdown File
report_path = output_directory / shared_training_infrastructure.COMMON_RUN_REPORT_FILENAME
report_path.write_text("\n".join(report_line_list) + "\n", encoding="utf-8")
def load_training_config(config_path: str | Path = DEFAULT_CONFIG_PATH) -> dict:
"""Load one training configuration file through the shared infrastructure.
Args:
config_path: Training YAML path to load.
Returns:
Parsed training configuration dictionary.
"""
# Delegate Config Loading To The Shared Training Infrastructure
return shared_training_infrastructure.load_training_config(config_path)
def resolve_runtime_config(training_config: dict) -> dict[str, object]:
"""Resolve runtime execution options for the current training configuration.
Args:
training_config: Parsed training configuration dictionary.
Returns:
Runtime configuration dictionary after repository-specific checks and
warning-driven adjustments.
"""
# Initialize Default Runtime Configuration
runtime_config = shared_training_infrastructure.resolve_runtime_config(training_config)
# Disable Benchmark In Deterministic Mode
if bool(training_config["training"]["deterministic"]) and bool(runtime_config["benchmark"]):
print_warning_message("Deterministic mode is enabled -> disabling cuDNN benchmark to avoid conflicting runtime behavior")
runtime_config["benchmark"] = False
# Warn If Non-Blocking Transfer Has Limited Value
if bool(runtime_config["use_non_blocking_transfer"]) and not bool(training_config["dataset"]["pin_memory"]):
print_warning_message("Non-blocking transfer is enabled but pin_memory is disabled -> host-to-device copy overlap may be limited")
return runtime_config
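# Illustrative shape of the resolved runtime_config (keys inferred from usage in this module; values are examples only):
#   {"accelerator": "gpu", "devices": 1, "precision": "32-true", "benchmark": True, "use_non_blocking_transfer": True}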
def train_feedforward_network(config_path: str | Path = DEFAULT_CONFIG_PATH) -> None:
"""Run the full feedforward TE training workflow for one configuration.
This entry point loads the configuration, prepares artifact directories,
initializes the datamodule and model stack, runs training, validation, and
held-out testing, then serializes the common artifact contract used by the
repository registries and campaign tooling.
Args:
config_path: Path to the YAML training configuration to execute.
"""
# Load Training Configuration
training_config = shared_training_infrastructure.prepare_output_artifact_training_config(load_training_config(config_path))
experiment_identity = shared_training_infrastructure.resolve_experiment_identity(training_config)
runtime_config = resolve_runtime_config(training_config)
# Resolve Output Directory
output_directory = shared_training_infrastructure.resolve_output_directory(training_config)
output_directory.mkdir(parents=True, exist_ok=True)
# Save Configuration Snapshot
shared_training_infrastructure.save_training_config_snapshot(training_config, output_directory)
shared_training_infrastructure.save_run_metadata_snapshot(training_config, output_directory)
# Initialize Shared Training Components
datamodule, regression_backbone, regression_module, normalization_statistics = shared_training_infrastructure.initialize_training_components(training_config)
input_feature_dim = datamodule.get_input_feature_dim()
target_feature_dim = datamodule.get_target_feature_dim()
parameter_summary = shared_training_infrastructure.summarize_model_parameters(regression_backbone)
# Print Training Summary
print_training_configuration_summary(training_config)
print_dataset_summary(datamodule, input_feature_dim, target_feature_dim)
print_model_summary(regression_backbone)
print_normalization_statistics_summary(normalization_statistics)
print_runtime_summary(runtime_config)
# Create Logger
logger = TensorBoardLogger(save_dir=str(output_directory / "logs"), name="", version="")
# Checkpoint Callback To Save Best Model Based On Validation MAE, As Well As The Last Model For Resuming Training If Needed
checkpoint_callback = ModelCheckpoint(
dirpath=str(output_directory / "checkpoints"),
filename=f"{experiment_identity.model_type}" + "-{epoch:03d}-{val_mae:.8f}",
monitor="val_mae",
mode="min",
save_top_k=1,
save_last=True,
)
# Early Stopping Callback To Stop Training If Validation MAE Does Not Improve For A Certain Number Of Epochs
early_stopping_callback = EarlyStopping(
monitor="val_mae",
mode="min",
patience=int(training_config["training"]["patience"]),
min_delta=float(training_config["training"]["min_delta"]),
verbose=True,
)
# Progress Bar Callback To Display Training Progress
progress_bar_callback = TQDMProgressBar(
refresh_rate=PROGRESS_BAR_REFRESH_RATE,
leave=True,
)
# Create Trainer - Suppress Lightning Internal Info Logs To Reduce Terminal Clutter During Training
with suppress_lightning_info_logs():
trainer = Trainer(
accelerator=runtime_config["accelerator"],
devices=runtime_config["devices"],
precision=runtime_config["precision"],
benchmark=bool(runtime_config["benchmark"]),
min_epochs=int(training_config["training"]["min_epochs"]),
max_epochs=int(training_config["training"]["max_epochs"]),
log_every_n_steps=int(training_config["training"]["log_every_n_steps"]),
deterministic=bool(training_config["training"]["deterministic"]),
fast_dev_run=bool(training_config["training"]["fast_dev_run"]),
enable_model_summary=False,
enable_progress_bar=True,
logger=logger,
callbacks=[
checkpoint_callback,
early_stopping_callback,
progress_bar_callback,
],
)
# Start Training
print_section_header("Training Loop")
print_info_message("Starting Lightning fit loop")
trainer.fit(regression_module, datamodule=datamodule)
print_success_message("Training loop completed")
# Start Validation
print_section_header("Validation Loop")
print_info_message("Starting final Lightning validation loop")
validation_metric_list = trainer.validate(regression_module, datamodule=datamodule)
print_success_message("Validation loop completed")
# Save Best Checkpoint Path
best_model_path = checkpoint_callback.best_model_path
if not best_model_path: best_model_path = "Best checkpoint not available | fast_dev_run or checkpointing disabled"
# Default To The In-Memory Module, Then Reload The Best Checkpoint For Final Held-Out Evaluation When It Exists
best_regression_module = regression_module
if Path(best_model_path).exists():
# Load The Best Checkpoint For Reproducible Validation And Test Evaluation
print_section_header("Best Checkpoint Evaluation")
print_info_message("Loading best checkpoint for reproducible validation and test evaluation")
best_regression_module = TransmissionErrorRegressionModule.load_from_checkpoint(
checkpoint_path=best_model_path,
regression_model=create_model(
model_type=str(training_config["experiment"]["model_type"]),
model_configuration=training_config["model"],
),
input_feature_dim=input_feature_dim,
target_feature_dim=target_feature_dim,
normalization_statistics=normalization_statistics,
)
validation_metric_list = trainer.validate(best_regression_module, datamodule=datamodule)
print_success_message("Best-checkpoint validation loop completed")
# Start Testing
print_section_header("Test Loop")
print_info_message("Starting held-out Lightning test loop")
test_metric_list = trainer.test(best_regression_module, datamodule=datamodule)
print_success_message("Test loop completed")
# Save Best Checkpoint Path To File For Easy Reference
best_model_path_file = output_directory / "best_checkpoint_path.txt"
best_model_path_file.write_text(best_model_path, encoding="utf-8")
# Save Machine-Readable Metrics And Human-Readable Report
metrics_snapshot_dictionary = shared_training_infrastructure.build_common_metrics_snapshot(
training_config,
config_path,
output_directory,
datamodule,
parameter_summary,
runtime_config,
best_model_path,
validation_metric_list,
test_metric_list,
)
# Save Metrics Snapshot
shared_training_infrastructure.save_common_metrics_snapshot(
metrics_snapshot_dictionary,
output_directory,
)
# Save Training/Test Report
save_training_test_report(
output_directory,
training_config,
metrics_snapshot_dictionary,
)
# Update Best-Result Registries
family_best_entry = shared_training_infrastructure.update_family_registry(metrics_snapshot_dictionary)
shared_training_infrastructure.update_program_registry(family_best_entry)
# Mirror The Best Checkpoint Path File Into The TensorBoard Log Directory
if logger.log_dir:
logger_directory = Path(logger.log_dir)
if logger_directory.exists(): shutil.copyfile(best_model_path_file, logger_directory / "best_checkpoint_path.txt")
# Print Output Artifact Summary
print_output_artifact_summary(output_directory, logger, best_model_path)
print_success_message(f"{experiment_identity.model_type} training workflow completed")
def parse_command_line_arguments() -> argparse.Namespace:
"""Parse the command-line arguments for the training entry point.
Returns:
Parsed command-line namespace containing the selected configuration
path.
"""
# Initialize Argument Parser
argument_parser = argparse.ArgumentParser(description="Train the configured static TE neural baseline.")
# Add Config Path Argument
argument_parser.add_argument(
"--config-path",
type=Path,
default=DEFAULT_CONFIG_PATH,
help="Path to the YAML training configuration file.",
)
return argument_parser.parse_args()
def main() -> None:
"""Run the command-line training entry point."""
# Parse Command Line Arguments
command_line_arguments = parse_command_line_arguments()
# Train Static Neural Model
train_feedforward_network(command_line_arguments.config_path)
if __name__ == "__main__":
main()
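# Example invocation (illustrative; the script path and config filename are assumptions, not part of this module):
#   python path/to/this_script.py --config-path path/to/training_config.yaml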