Training API (Autogenerated)¶
This page is generated from Python docstrings via mkdocstrings.
train¶
One-liner training function for quick experimentation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | World model to train. | required |
data | BatchProvider \| Any | BatchProvider or iterable yielding Batch/dict. | required |
total_steps | int \| None | Number of training steps. | None |
batch_size | int | Batch size. | 16 |
sequence_length | int | Sequence length for trajectory sampling. | 50 |
learning_rate | float | Learning rate. | 0.0003 |
grad_clip | float | Gradient clipping value. | 100.0 |
output_dir | str | Directory for outputs. | './outputs' |
device | str | Device to train on. | 'auto' |
**kwargs | Any | Additional config options. | {} |
Returns:
| Type | Description |
|---|---|
Module | Trained model. |
Example
```python
from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer

model = create_world_model("dreamerv3:size12m")
buffer = ReplayBuffer.load("data.npz")
trained_model = train(model, buffer, total_steps=50_000)
```
Trainer¶
HuggingFace-style trainer for WorldFlux.
Provides a simple interface for training world models with:
- Automatic device placement
- Gradient clipping
- Checkpointing
- Logging (console and optional wandb)
- Learning rate scheduling
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model | Module | World model to train (must implement loss(batch)). | required |
config | TrainingConfig \| None | Training configuration. | None |
callbacks | list[Callback] \| None | List of callbacks for logging/checkpointing. | None |
optimizer | Optimizer \| None | Optional custom optimizer. | None |
scheduler | LRScheduler \| None | Optional learning rate scheduler. | None |
Example
```python
from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer

model = create_world_model("dreamerv3:size12m")
buffer = ReplayBuffer.load("data.npz")
config = TrainingConfig(total_steps=50_000, batch_size=32)

trainer = Trainer(model, config)
trainer.train(buffer)
```
add_callback(callback) ¶
Register a callback after trainer construction.
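A minimal sketch of attaching a callback after construction, reusing the `trainer` from the Trainer example above:

```python
from worldflux.training import LoggingCallback

# Register an extra callback on an already-constructed trainer,
# instead of passing it to Trainer(...) via callbacks=[...].
trainer.add_callback(LoggingCallback(log_interval=500))
```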
train(data, num_steps=None, resume_from=None) ¶
Train the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data | BatchProvider \| Any | BatchProvider or iterable yielding Batch/dict. | required |
num_steps | int \| None | Number of steps to train. If None, uses config.total_steps. | None |
resume_from | str \| None | Path to checkpoint to resume from. | None |
Returns:
| Type | Description |
|---|---|
Module | Trained model. |
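A sketch of resuming an interrupted run; the checkpoint filename here is illustrative, not a path the trainer is guaranteed to produce:

```python
# Continue training from a saved checkpoint. With num_steps=None,
# the trainer falls back to config.total_steps.
trained = trainer.train(buffer, resume_from="./outputs/checkpoint_10000.pt")
```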
runtime_profile() ¶
Return lightweight runtime profiling metrics for developer-experience (DX) instrumentation.
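A sketch of inspecting the profile, assuming the metrics come back as a plain mapping; the exact keys are not documented on this page:

```python
profile = trainer.runtime_profile()
for name, value in profile.items():
    print(f"{name}: {value}")
```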
evaluate(data, num_batches=10) ¶
Evaluate the model on data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
data | BatchProvider \| Any | ReplayBuffer containing evaluation data. | required |
num_batches | int | Number of batches to evaluate. | 10 |
Returns:
| Type | Description |
|---|---|
dict[str, float] | Dictionary of average metrics. |
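A minimal usage sketch, assuming `eval_buffer` is a ReplayBuffer holding held-out episodes:

```python
metrics = trainer.evaluate(eval_buffer, num_batches=20)
# `metrics` is a dict[str, float] of metric averages over the 20 batches.
print(metrics)
```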
save_checkpoint(path) ¶
Save training checkpoint atomically with validation.
Uses an atomic write pattern: write to a temp file, validate, then rename. This prevents corrupted checkpoints if the disk fills or the process is killed.
load_checkpoint(path) ¶
Load training checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
path | str | Path to checkpoint file. | required |
Raises:
| Type | Description |
|---|---|
CheckpointError | If checkpoint file is missing or corrupted. |
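A save/load round-trip sketch. The filename is illustrative, and because this page does not document where CheckpointError is importable from, the sketch catches a generic Exception:

```python
path = "./outputs/checkpoint_final.pt"
trainer.save_checkpoint(path)
try:
    trainer.load_checkpoint(path)
except Exception as exc:  # CheckpointError in practice; import path not shown here
    print(f"Failed to load checkpoint: {exc}")
```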
TrainingConfig¶
Configuration for training world models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
total_steps | int | Total number of training steps. | 100000 |
batch_size | int | Batch size for training. | 16 |
sequence_length | int | Sequence length for trajectory sampling. | 50 |
learning_rate | float | Learning rate for optimizer. | 0.0003 |
grad_clip | float | Maximum gradient norm for clipping. | 100.0 |
weight_decay | float | Weight decay for optimizer. | 0.0 |
warmup_steps | int | Number of warmup steps for learning rate scheduler. | 0 |
log_interval | int | Interval (in steps) for logging metrics. | 100 |
eval_interval | int | Interval (in steps) for evaluation. | 1000 |
save_interval | int | Interval (in steps) for saving checkpoints. | 10000 |
output_dir | str | Directory for saving outputs (checkpoints, logs). | './outputs' |
device | str | Device to train on ('cuda', 'cpu', 'auto'). | 'auto' |
seed | int | Random seed for reproducibility. | 42 |
mixed_precision | bool | Whether to use mixed precision training. | False |
num_workers | int | Number of workers for data loading. | 0 |
prefetch_factor | int | Number of batches to prefetch per worker. | 2 |
Example
```python
config = TrainingConfig(total_steps=100_000, batch_size=32)
config.save("training_config.json")
loaded = TrainingConfig.load("training_config.json")
```
to_dict() ¶
Convert config to dictionary.
from_dict(d) classmethod ¶
Create config from dictionary.
save(path) ¶
Save config to JSON file.
load(path) classmethod ¶
Load config from JSON file.
resolve_device() ¶
Resolve 'auto' device to actual device.
with_updates(**kwargs) ¶
Return a new config with updated values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
**kwargs | Any | Configuration values to update. | {} |
Returns:
| Type | Description |
|---|---|
TrainingConfig | New TrainingConfig with updated values. |
Raises:
| Type | Description |
|---|---|
ConfigurationError | If updated values are invalid. |
effective_total_steps() ¶
Return the total number of steps the trainer uses under the current mode.
effective_batch_size() ¶
Return the batch size the trainer uses under the current mode.
effective_sequence_length() ¶
Return the sequence length the trainer uses under the current mode.
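A sketch of reading the effective values; they may differ from the raw config fields when the current mode overrides them:

```python
config = TrainingConfig(batch_size=16, sequence_length=50)
print(config.effective_total_steps())
print(config.effective_batch_size())
print(config.effective_sequence_length())
```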
ReplayBuffer¶
Efficient trajectory storage for world model training.
Stores episodes as contiguous arrays and supports efficient random sampling of trajectory segments for training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
capacity | int | Maximum number of transitions to store. | required |
obs_shape | tuple[int, ...] | Shape of observations (e.g., (3, 64, 64) for images). | required |
action_dim | int | Dimension of action space. | required |
obs_dtype | type | NumPy dtype for observations. | float32 |
Example
```python
buffer = ReplayBuffer(capacity=100_000, obs_shape=(3, 64, 64), action_dim=6)
buffer.add_episode(obs, actions, rewards, dones)
batch = buffer.sample(batch_size=32, seq_len=50)
```
Raises:
| Type | Description |
|---|---|
ConfigurationError | If capacity, obs_shape, or action_dim are invalid. |
num_episodes property ¶
Return number of complete episodes stored.
add_episode(obs, actions, rewards, dones=None) ¶
Add a complete episode to the buffer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
obs | ndarray | Observations of shape [episode_len, *obs_shape]. | required |
actions | ndarray | Actions of shape [episode_len, action_dim]. | required |
rewards | ndarray | Rewards of shape [episode_len]. | required |
dones | ndarray \| None | Done flags of shape [episode_len]. If None, the last step is marked done. | None |
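A sketch with synthetic NumPy data matching the buffer from the class example above ((3, 64, 64) observations, 6-dimensional actions); the zero-filled arrays are placeholders:

```python
import numpy as np

episode_len = 200
obs = np.zeros((episode_len, 3, 64, 64), dtype=np.float32)
actions = np.zeros((episode_len, 6), dtype=np.float32)
rewards = np.zeros(episode_len, dtype=np.float32)

# dones omitted: only the final step is marked done.
buffer.add_episode(obs, actions, rewards)
```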
sample(batch_size, seq_len, device='cpu') ¶
Sample random trajectory segments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_size | int | Number of trajectory segments to sample. | required |
seq_len | int | Length of each trajectory segment. | required |
device | str \| device | Device to place tensors on. | 'cpu' |
Returns:
| Type | Description |
|---|---|
Batch | Dictionary with keys: `obs` [batch_size, seq_len, *obs_shape], `actions` [batch_size, seq_len, action_dim], `rewards` [batch_size, seq_len], `continues` [batch_size, seq_len] (1 - dones). |
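A sketch of sampling and checking the documented axis layout, using the buffer from the class example above:

```python
batch = buffer.sample(batch_size=32, seq_len=50, device="cpu")

print(batch["obs"].shape)        # (32, 50, 3, 64, 64)
print(batch["continues"].shape)  # (32, 50), equal to 1 - dones
```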
batch_layout() ¶
Return explicit axis layout for sampled batches.
save(path) ¶
Save buffer to disk.
load(path) classmethod ¶
Load buffer from disk with schema validation.
from_trajectories(trajectories, capacity=None) classmethod ¶
Create buffer from list of trajectory dictionaries.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
trajectories | list[dict[str, ndarray]] | List of dicts with keys 'obs', 'actions', 'rewards', 'dones'. | required |
capacity | int \| None | Buffer capacity. If None, uses total trajectory length. | None |
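A sketch building a buffer from one short synthetic trajectory; the zero-filled arrays are placeholders:

```python
import numpy as np

from worldflux.training import ReplayBuffer

trajectories = [
    {
        "obs": np.zeros((100, 3, 64, 64), dtype=np.float32),
        "actions": np.zeros((100, 6), dtype=np.float32),
        "rewards": np.zeros(100, dtype=np.float32),
        "dones": np.zeros(100, dtype=np.float32),
    }
]
# capacity=None sizes the buffer to the total trajectory length (100 here).
buffer = ReplayBuffer.from_trajectories(trajectories)
```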
Callback¶
Bases: ABC
Base class for training callbacks.
on_train_begin(trainer) ¶
Called at the start of training.
on_train_end(trainer) ¶
Called at the end of training.
on_epoch_begin(trainer) ¶
Called at the start of each epoch.
on_epoch_end(trainer) ¶
Called at the end of each epoch.
on_step_begin(trainer) ¶
Called before each training step.
on_step_end(trainer) ¶
Called after each training step.
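A minimal custom callback sketch that relies only on the hook signatures documented above; it keeps its own counter rather than assuming undocumented trainer attributes:

```python
from worldflux.training import Callback

class StepCounter(Callback):
    """Count training steps and report the total when training ends."""

    def __init__(self):
        self.steps = 0

    def on_step_end(self, trainer):
        self.steps += 1

    def on_train_end(self, trainer):
        print(f"Training finished after {self.steps} steps")
```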
LoggingCallback¶
Bases: Callback
Callback for logging training metrics.
Logs to console and optionally to wandb.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
log_interval | int | Steps between log outputs. | 100 |
use_wandb | bool | Whether to log to wandb. | False |
wandb_project | str \| None | wandb project name. | None |
wandb_run_name | str \| None | wandb run name. | None |
Example
```python
callback = LoggingCallback(log_interval=100, use_wandb=True)
trainer = Trainer(model, config, callbacks=[callback])
```
CheckpointCallback¶
Bases: Callback
Callback for saving model checkpoints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
save_interval | int | Steps between checkpoint saves. | 10000 |
output_dir | str | Directory to save checkpoints. | './outputs' |
save_best | bool | Whether to save the best model (lowest loss). | True |
max_checkpoints | int \| None | Maximum number of checkpoints to keep. | 5 |
Example
```python
callback = CheckpointCallback(
    save_interval=10000,
    output_dir="./checkpoints",
    save_best=True,
)
```
EarlyStoppingCallback¶
Bases: Callback
Callback for early stopping based on loss plateau.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
patience | int | Number of steps to wait before stopping. | 5000 |
min_delta | float | Minimum improvement to reset patience. | 0.0001 |
monitor | str | Metric to monitor. | 'loss' |
Example
```python
callback = EarlyStoppingCallback(patience=5000, min_delta=1e-4)
```