Training API (Autogenerated)

This page is generated from Python docstrings via mkdocstrings.

train

One-liner training function for quick experimentation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | World model to train. | *required* |
| `data` | `BatchProvider \| Any` | `BatchProvider` or iterable yielding `Batch`/dict. | *required* |
| `total_steps` | `int \| None` | Number of training steps. | `None` |
| `batch_size` | `int` | Batch size. | `16` |
| `sequence_length` | `int` | Sequence length for trajectory sampling. | `50` |
| `learning_rate` | `float` | Learning rate. | `0.0003` |
| `grad_clip` | `float` | Gradient clipping value. | `100.0` |
| `output_dir` | `str` | Directory for outputs. | `'./outputs'` |
| `device` | `str` | Device to train on. | `'auto'` |
| `**kwargs` | `Any` | Additional config options. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Module` | Trained model. |

Example

```python
from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer

model = create_world_model("dreamerv3:size12m")
buffer = ReplayBuffer.load("data.npz")
trained_model = train(model, buffer, total_steps=50_000)
```

Trainer

HuggingFace-style trainer for WorldFlux.

Provides a simple interface for training world models with:

- Automatic device placement
- Gradient clipping
- Checkpointing
- Logging (console and optional wandb)
- Learning rate scheduling

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | World model to train (must implement `loss(batch)`). | *required* |
| `config` | `TrainingConfig \| None` | Training configuration. | `None` |
| `callbacks` | `list[Callback] \| None` | List of callbacks for logging/checkpointing. | `None` |
| `optimizer` | `Optimizer \| None` | Optional custom optimizer. | `None` |
| `scheduler` | `LRScheduler \| None` | Optional learning rate scheduler. | `None` |
Example

```python
from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer

model = create_world_model("dreamerv3:size12m")
buffer = ReplayBuffer.load("data.npz")
config = TrainingConfig(total_steps=50_000, batch_size=32)

trainer = Trainer(model, config)
trainer.train(buffer)
```

add_callback(callback)

Register a callback after trainer construction.
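
Continuing the `Trainer` example above, a minimal sketch of attaching the `LoggingCallback` documented later on this page:

```python
from worldflux.training import Trainer, LoggingCallback

trainer = Trainer(model, config)
# Attach a callback after the trainer has been built.
trainer.add_callback(LoggingCallback(log_interval=100))
```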

train(data, num_steps=None, resume_from=None)

Train the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `BatchProvider \| Any` | `BatchProvider` or iterable yielding `Batch`/dict. | *required* |
| `num_steps` | `int \| None` | Number of steps to train. If `None`, uses `config.total_steps`. | `None` |
| `resume_from` | `str \| None` | Path to checkpoint to resume from. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Module` | Trained model. |
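
A hedged sketch of resuming an interrupted run; the checkpoint path is hypothetical:

```python
# Resume from a saved checkpoint, then train for an explicit number
# of steps instead of config.total_steps. The path is hypothetical.
trained = trainer.train(
    buffer,
    num_steps=10_000,
    resume_from="./outputs/checkpoint_40000.pt",
)
```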

runtime_profile()

Return lightweight runtime profiling metrics for developer-experience (DX) instrumentation.
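
A small usage sketch; the exact metric names are not documented here, so this assumes only a dict-like return value:

```python
profile = trainer.runtime_profile()
# Print whatever metrics the trainer exposes (assumed dict-like).
for name, value in profile.items():
    print(name, value)
```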

evaluate(data, num_batches=10)

Evaluate the model on data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `BatchProvider \| Any` | `BatchProvider` or iterable containing evaluation data (e.g., a `ReplayBuffer`). | *required* |
| `num_batches` | `int` | Number of batches to evaluate. | `10` |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, float]` | Dictionary of average metrics. |
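
A usage sketch; the `'loss'` key below is an assumption, since the metric names are not documented above:

```python
metrics = trainer.evaluate(buffer, num_batches=10)
# Each value is averaged over the evaluated batches.
print(metrics.get("loss"))  # 'loss' key is an assumption
```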

save_checkpoint(path)

Save training checkpoint atomically with validation.

Uses an atomic write pattern: write to a temp file, validate, then rename. This prevents corrupted checkpoints if the disk fills or the process is killed.

load_checkpoint(path)

Load training checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `str` | Path to checkpoint file. | *required* |

Raises:

| Type | Description |
| --- | --- |
| `CheckpointError` | If checkpoint file is missing or corrupted. |
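
A save/load round-trip sketch with error handling; the import path for `CheckpointError` is an assumption based on the exception name documented above:

```python
from worldflux.training import CheckpointError  # import path is an assumption

trainer.save_checkpoint("./outputs/checkpoint_final.pt")
try:
    trainer.load_checkpoint("./outputs/checkpoint_final.pt")
except CheckpointError:
    print("checkpoint missing or corrupted")
```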

TrainingConfig

Configuration for training world models.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `total_steps` | `int` | Total number of training steps. | `100000` |
| `batch_size` | `int` | Batch size for training. | `16` |
| `sequence_length` | `int` | Sequence length for trajectory sampling. | `50` |
| `learning_rate` | `float` | Learning rate for optimizer. | `0.0003` |
| `grad_clip` | `float` | Maximum gradient norm for clipping. | `100.0` |
| `weight_decay` | `float` | Weight decay for optimizer. | `0.0` |
| `warmup_steps` | `int` | Number of warmup steps for learning rate scheduler. | `0` |
| `log_interval` | `int` | Interval (in steps) for logging metrics. | `100` |
| `eval_interval` | `int` | Interval (in steps) for evaluation. | `1000` |
| `save_interval` | `int` | Interval (in steps) for saving checkpoints. | `10000` |
| `output_dir` | `str` | Directory for saving outputs (checkpoints, logs). | `'./outputs'` |
| `device` | `str` | Device to train on (`'cuda'`, `'cpu'`, `'auto'`). | `'auto'` |
| `seed` | `int` | Random seed for reproducibility. | `42` |
| `mixed_precision` | `bool` | Whether to use mixed precision training. | `False` |
| `num_workers` | `int` | Number of workers for data loading. | `0` |
| `prefetch_factor` | `int` | Number of batches to prefetch per worker. | `2` |
Example

```python
config = TrainingConfig(total_steps=100_000, batch_size=32)
config.save("training_config.json")
loaded = TrainingConfig.load("training_config.json")
```

to_dict()

Convert config to dictionary.

from_dict(d) classmethod

Create config from dictionary.

save(path)

Save config to JSON file.

load(path) classmethod

Load config from JSON file.

resolve_device()

Resolve 'auto' device to actual device.

with_updates(**kwargs)

Return a new config with updated values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**kwargs` | `Any` | Configuration values to update. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `TrainingConfig` | New `TrainingConfig` with updated values. |

Raises:

| Type | Description |
| --- | --- |
| `ConfigurationError` | If updated values are invalid. |
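
A short sketch of deriving a variant config without mutating the original:

```python
base = TrainingConfig(total_steps=100_000, batch_size=16)
# with_updates returns a new config; `base` itself is unchanged.
debug = base.with_updates(total_steps=1_000, log_interval=10)
```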

effective_total_steps()

Total steps used by the trainer under the current mode.

effective_batch_size()

Batch size used by the trainer under the current mode.

effective_sequence_length()

Sequence length used by the trainer under the current mode.
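
A sketch comparing the effective accessors with the raw fields; whether they differ depends on the trainer's current mode, which is not documented here:

```python
config = TrainingConfig(total_steps=100_000, batch_size=32)
# Effective values may differ from the raw fields when a mode overrides them.
print(config.effective_total_steps(), config.total_steps)
print(config.effective_batch_size(), config.batch_size)
print(config.effective_sequence_length(), config.sequence_length)
```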

ReplayBuffer

Efficient trajectory storage for world model training.

Stores episodes as contiguous arrays and supports efficient random sampling of trajectory segments for training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `capacity` | `int` | Maximum number of transitions to store. | *required* |
| `obs_shape` | `tuple[int, ...]` | Shape of observations (e.g., `(3, 64, 64)` for images). | *required* |
| `action_dim` | `int` | Dimension of action space. | *required* |
| `obs_dtype` | `type` | NumPy dtype for observations. | `float32` |
Example

```python
buffer = ReplayBuffer(capacity=100_000, obs_shape=(3, 64, 64), action_dim=6)
buffer.add_episode(obs, actions, rewards, dones)
batch = buffer.sample(batch_size=32, seq_len=50)
```

Raises:

| Type | Description |
| --- | --- |
| `ConfigurationError` | If capacity, obs_shape, or action_dim are invalid. |

num_episodes property

Return number of complete episodes stored.

add_episode(obs, actions, rewards, dones=None)

Add a complete episode to the buffer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `obs` | `ndarray` | Observations of shape `[episode_len, *obs_shape]`. | *required* |
| `actions` | `ndarray` | Actions of shape `[episode_len, action_dim]`. | *required* |
| `rewards` | `ndarray` | Rewards of shape `[episode_len]`. | *required* |
| `dones` | `ndarray \| None` | Done flags of shape `[episode_len]`. If `None`, last step is done. | `None` |
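
A sketch building one synthetic episode with NumPy, continuing the buffer from the class example above (`obs_shape=(3, 64, 64)`, `action_dim=6`):

```python
import numpy as np

episode_len = 200
obs = np.random.rand(episode_len, 3, 64, 64).astype(np.float32)
actions = np.random.rand(episode_len, 6).astype(np.float32)
rewards = np.zeros(episode_len, dtype=np.float32)

# Omitting dones marks only the last step as done.
buffer.add_episode(obs, actions, rewards)
```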

sample(batch_size, seq_len, device='cpu')

Sample random trajectory segments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch_size` | `int` | Number of trajectory segments to sample. | *required* |
| `seq_len` | `int` | Length of each trajectory segment. | *required* |
| `device` | `str \| device` | Device to place tensors on. | `'cpu'` |

Returns:

| Type | Description |
| --- | --- |
| `Batch` | Dictionary with the keys listed below. |

- `obs`: `[batch_size, seq_len, *obs_shape]`
- `actions`: `[batch_size, seq_len, action_dim]`
- `rewards`: `[batch_size, seq_len]`
- `continues`: `[batch_size, seq_len]` (`1 - dones`)
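
A sketch checking the documented shapes, assuming the buffer from the class example (`obs_shape=(3, 64, 64)`, `action_dim=6`) and dict-style access to the returned `Batch`, as the key listing suggests:

```python
batch = buffer.sample(batch_size=32, seq_len=50, device="cpu")
# Shapes follow the key layout documented above.
assert batch["obs"].shape == (32, 50, 3, 64, 64)
assert batch["actions"].shape == (32, 50, 6)
assert batch["rewards"].shape == (32, 50)
assert batch["continues"].shape == (32, 50)
```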

batch_layout()

Return explicit axis layout for sampled batches.

save(path)

Save buffer to disk.

load(path) classmethod

Load buffer from disk with schema validation.
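
A round-trip sketch; the `.npz` path mirrors the examples earlier on this page:

```python
buffer.save("data.npz")
# load() validates the on-disk schema before reconstructing the buffer.
restored = ReplayBuffer.load("data.npz")
print(restored.num_episodes)
```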

from_trajectories(trajectories, capacity=None) classmethod

Create buffer from list of trajectory dictionaries.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `trajectories` | `list[dict[str, ndarray]]` | List of dicts with keys `'obs'`, `'actions'`, `'rewards'`, `'dones'`. | *required* |
| `capacity` | `int \| None` | Buffer capacity. If `None`, uses total trajectory length. | `None` |
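
A sketch building a buffer from in-memory trajectories; shapes follow the `add_episode` table above, and the boolean dtype for `dones` is an assumption:

```python
import numpy as np

dones = np.zeros(100, dtype=bool)
dones[-1] = True  # mark the episode end

trajectories = [{
    "obs": np.zeros((100, 3, 64, 64), dtype=np.float32),
    "actions": np.zeros((100, 6), dtype=np.float32),
    "rewards": np.zeros(100, dtype=np.float32),
    "dones": dones,
}]
# capacity=None sizes the buffer to the total trajectory length.
buffer = ReplayBuffer.from_trajectories(trajectories)
```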

Callback

Bases: ABC

Base class for training callbacks.

on_train_begin(trainer)

Called at the start of training.

on_train_end(trainer)

Called at the end of training.

on_epoch_begin(trainer)

Called at the start of each epoch.

on_epoch_end(trainer)

Called at the end of each epoch.

on_step_begin(trainer)

Called before each training step.

on_step_end(trainer)

Called after each training step.
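
A minimal subclass sketch using only the hooks documented above; it assumes the hook methods have default no-op implementations, so only the needed ones are overridden:

```python
from worldflux.training import Callback, Trainer

class StepCounter(Callback):
    """Counts training steps without relying on trainer internals."""

    def __init__(self):
        self.steps = 0

    def on_train_begin(self, trainer):
        self.steps = 0

    def on_step_end(self, trainer):
        self.steps += 1

trainer = Trainer(model, config, callbacks=[StepCounter()])
```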

LoggingCallback

Bases: Callback

Callback for logging training metrics.

Logs to console and optionally to wandb.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_interval` | `int` | Steps between log outputs. | `100` |
| `use_wandb` | `bool` | Whether to log to wandb. | `False` |
| `wandb_project` | `str \| None` | wandb project name. | `None` |
| `wandb_run_name` | `str \| None` | wandb run name. | `None` |
Example

```python
callback = LoggingCallback(log_interval=100, use_wandb=True)
trainer = Trainer(model, config, callbacks=[callback])
```

CheckpointCallback

Bases: Callback

Callback for saving model checkpoints.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `save_interval` | `int` | Steps between checkpoint saves. | `10000` |
| `output_dir` | `str` | Directory to save checkpoints. | `'./outputs'` |
| `save_best` | `bool` | Whether to save the best model (lowest loss). | `True` |
| `max_checkpoints` | `int \| None` | Maximum number of checkpoints to keep. | `5` |
Example

```python
callback = CheckpointCallback(
    save_interval=10000,
    output_dir="./checkpoints",
    save_best=True,
)
```

EarlyStoppingCallback

Bases: Callback

Callback for early stopping based on loss plateau.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `patience` | `int` | Number of steps to wait for improvement before stopping. | `5000` |
| `min_delta` | `float` | Minimum improvement to reset patience. | `0.0001` |
| `monitor` | `str` | Metric to monitor. | `'loss'` |
Example

```python
callback = EarlyStoppingCallback(patience=5000, min_delta=1e-4)
```