# Training API
Complete training infrastructure for world models.
`Trainer` supports two execution modes:

- local: native torch models returned by `create_world_model(..., backend="native_torch")`
- delegated: `OfficialBackendHandle` values returned by non-native backend requests

Distributed training guidance for the stable public surface:

- `DDPTrainer` is the current multi-GPU trainer entry point
- planned FSDP support is not part of the stable `worldflux.training` API in v0.1.x
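A minimal multi-GPU sketch, assuming `DDPTrainer` mirrors the `(model, config)` constructor and `train(buffer)` loop of `Trainer` and that processes are launched with `torchrun`; only the entry point itself is guaranteed above, so check the autogenerated docs for the exact signature:

```python
# Launch with: torchrun --nproc_per_node=4 train_ddp.py
# Assumption: DDPTrainer mirrors Trainer's constructor and train() loop.
from worldflux import create_world_model
from worldflux.training import DDPTrainer, ReplayBuffer, TrainingConfig

model = create_world_model("dreamer:ci", obs_shape=(3, 64, 64), action_dim=2)
buffer = ReplayBuffer.load("data.npz")
trainer = DDPTrainer(model, TrainingConfig(total_steps=100_000))
trained_model = trainer.train(buffer)
```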
For docstring-derived details, see Training API (Autogenerated).
## Quick Start

```python
from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer

model = create_world_model("dreamer:ci", obs_shape=(3, 64, 64), action_dim=2)
buffer = ReplayBuffer.load("data.npz")
trained_model = train(model, buffer, total_steps=50_000)
```
Delegated training uses `Trainer.submit()` instead of `train()`:

```python
from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig

handle = create_world_model(
    "dreamerv3:official_xl",
    backend="official_dreamerv3_jax_subprocess",
    device="cuda",
)
trainer = Trainer(
    handle,
    TrainingConfig(
        backend="official_dreamerv3_jax_subprocess",
        backend_profile="official_xl",
        device="cuda",
    ),
)
job = trainer.submit()
```
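Once submitted, the job is observed through the delegated-mode methods documented under `Trainer` below:

```python
# Stream backend log lines for the delegated job, then inspect its status.
for line in trainer.logs(job):
    print(line)

print(trainer.status(job))
trainer.cancel(job)  # stop the delegated job if it is no longer needed
```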
## train
One-liner training function.
```python
from worldflux.training import train

trained_model = train(
    model,
    buffer,
    total_steps=50_000,
    batch_size=16,
    learning_rate=3e-4,
)
```
### Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `WorldModel` | required | Model to train |
| `buffer` | `ReplayBuffer` or provider | required | Training data |
| `total_steps` | `int` | `100_000` | Training steps |
| `batch_size` | `int` | `16` | Batch size |
| `sequence_length` | `int` | `50` | Sequence length |
| `learning_rate` | `float` | `3e-4` | Learning rate |
### Returns

Trained model (`nn.Module`).
## TrainingConfig
Configuration for training.
```python
from worldflux.training import TrainingConfig

config = TrainingConfig(
    # Duration
    total_steps=100_000,
    # Batch settings
    batch_size=16,
    sequence_length=50,
    # Optimizer
    learning_rate=3e-4,
    weight_decay=0.0,
    grad_clip=100.0,
    # Logging
    log_interval=1000,
)
```
### All Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| `total_steps` | `int` | `100_000` | Total training steps |
| `batch_size` | `int` | `16` | Batch size |
| `sequence_length` | `int` | `50` | Sequence length for BPTT |
| `learning_rate` | `float` | `3e-4` | Adam learning rate |
| `weight_decay` | `float` | `0.0` | L2 regularization |
| `grad_clip` | `float` | `100.0` | Gradient clipping norm |
| `warmup_steps` | `int` | `0` | Learning rate warmup |
| `log_interval` | `int` | `100` | Steps between logging |
| `eval_interval` | `int` | `1000` | Steps between evaluation |
| `save_interval` | `int` | `10000` | Steps between checkpoints |
| `device` | `str` | `"auto"` | Training device |
| `seed` | `int` | `42` | Random seed |
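For instance, a config that exercises the scheduling and reproducibility fields from the table:

```python
config = TrainingConfig(
    total_steps=100_000,
    warmup_steps=1_000,    # ramp the learning rate over the first 1k steps
    eval_interval=1_000,
    save_interval=10_000,
    device="cuda",
    seed=7,
)
```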
## Trainer

Full training control with callbacks.

```python
from worldflux.training import Trainer, TrainingConfig

config = TrainingConfig(total_steps=100_000)
trainer = Trainer(model, config)
trained_model = trainer.train(buffer)
```
When `model` is an `OfficialBackendHandle`, `Trainer` switches to delegated mode. In delegated mode, local loop methods such as `train()` are unavailable; use `submit()`, `status()`, `logs()`, and `cancel()` instead.
Provider protocols:

- Legacy: `sample(batch_size, seq_len=None, device="cpu") -> Batch`
- V2: `sample(BatchRequest(batch_size, seq_len, device)) -> Batch`
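A minimal legacy-protocol provider sketch. The `Batch` import path and its field names are assumptions that mirror the `buffer.sample()` output documented under ReplayBuffer below; adjust to your installed version:

```python
import torch

# Assumption: Batch is importable from worldflux.training and accepts the
# same fields that buffer.sample() returns (obs, actions, rewards,
# terminations); verify against the autogenerated API docs.
from worldflux.training import Batch


class RandomProvider:
    """Serves random [B, T, ...] tensors via the legacy provider protocol."""

    def sample(self, batch_size, seq_len=None, device="cpu"):
        T = seq_len or 50
        return Batch(
            obs=torch.randn(batch_size, T, 4, device=device),
            actions=torch.randn(batch_size, T, 2, device=device),
            rewards=torch.zeros(batch_size, T, device=device),
            terminations=torch.zeros(batch_size, T, device=device),
        )
```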
### Methods

#### train

```python
trained_model = trainer.train(buffer)
```

#### submit / status / logs / cancel

```python
job = trainer.submit()
status = trainer.status(job)
lines = list(trainer.logs(job))
trainer.cancel(job)
```

#### add_callback

```python
from worldflux.training.callbacks import CheckpointCallback

trainer.add_callback(CheckpointCallback(output_dir="./checkpoints"))
```
## ReplayBuffer

Storage for trajectory data.

### Creation

```python
from worldflux.training import ReplayBuffer

buffer = ReplayBuffer(
    capacity=100_000,
    obs_shape=(4,),
    action_dim=2,
)
```
### Adding Data

```python
buffer.add_episode(
    obs=np.array([...]),      # [T, *obs_shape]
    actions=np.array([...]),  # [T, action_dim]
    rewards=np.array([...]),  # [T]
    dones=np.array([...]),    # [T]
)
```
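A concrete toy episode matching the `obs_shape=(4,)` and `action_dim=2` buffer created above (the `float32` dtypes are an assumption, chosen to match the save example further down):

```python
import numpy as np

T = 100  # episode length
buffer.add_episode(
    obs=np.random.randn(T, 4).astype(np.float32),
    actions=np.random.randn(T, 2).astype(np.float32),
    rewards=np.zeros(T, dtype=np.float32),
    dones=np.concatenate([np.zeros(T - 1), [1.0]]).astype(np.float32),
)
```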
### Sampling

```python
batch = buffer.sample(
    batch_size=16,
    seq_len=50,
    device="cuda",
)

# batch.obs          - [B, T, *obs_shape]
# batch.actions      - [B, T, action_dim]
# batch.rewards      - [B, T]
# batch.terminations - [B, T]
```
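The field shapes can be sanity-checked directly; with the `(4,)` observation buffer created above:

```python
assert batch.obs.shape == (16, 50, 4)
assert batch.actions.shape == (16, 50, 2)
assert batch.rewards.shape == (16, 50)
assert batch.terminations.shape == (16, 50)
```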
### Save/Load

```python
buffer.save("buffer.npz")
buffer = ReplayBuffer.load("buffer.npz")
```
### Required .npz Schema

`ReplayBuffer.load()` expects a serialized buffer with these keys:

| Key | Shape | Notes |
|---|---|---|
| `obs` | `[N, *obs_shape]` | Observation rows |
| `actions` | `[N, action_dim]` | Action rows |
| `rewards` | `[N]` | Reward per transition |
| `dones` | `[N]` | Episode termination flags (0/1) |
| `obs_shape` | `[len(obs_shape)]` | Stored observation shape metadata |
| `action_dim` | scalar | Stored action-dimension metadata |
| `capacity` | scalar | Replay buffer capacity used when saving |

If your pipeline uses `terminations`, store the same array under `dones` when
writing the `.npz` artifact for `ReplayBuffer.load()`.
```python
import numpy as np

# Example: valid ReplayBuffer.load() artifact
np.savez(
    "buffer.npz",
    obs=obs.astype(np.float32),            # [N, *obs_shape]
    actions=actions.astype(np.float32),    # [N, action_dim]
    rewards=rewards.astype(np.float32),    # [N]
    dones=terminations.astype(np.float32), # [N]
    obs_shape=np.array(obs.shape[1:], dtype=np.int64),
    action_dim=np.array(actions.shape[-1], dtype=np.int64),
    capacity=np.array(max(len(obs), 1), dtype=np.int64),
)
```
### Properties

| Property | Type | Description |
|---|---|---|
| `num_episodes` | `int` | Number of complete episodes stored |
| `__len__` | `int` | Total number of transitions stored |
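For example:

```python
print(f"{buffer.num_episodes} episodes, {len(buffer)} transitions")
```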
## Callbacks

### ProgressCallback

Progress bar with metrics.

```python
from worldflux.training.callbacks import ProgressCallback

trainer.add_callback(ProgressCallback())
```
### LoggingCallback

TensorBoard/console logging.

```python
from worldflux.training.callbacks import LoggingCallback

trainer.add_callback(LoggingCallback(
    log_interval=100,
    use_wandb=False,
))
```
### CheckpointCallback

Periodic model saving.

```python
from worldflux.training.callbacks import CheckpointCallback

trainer.add_callback(CheckpointCallback(
    output_dir="./checkpoints",
    save_interval=10_000,
    max_checkpoints=3,
))
```
### EarlyStoppingCallback

Stop on plateau.

```python
from worldflux.training.callbacks import EarlyStoppingCallback

trainer.add_callback(EarlyStoppingCallback(
    monitor="loss",
    patience=5000,
    min_delta=1e-4,
))
```
### Custom Callbacks

```python
from worldflux.training.callbacks import Callback


class MyCallback(Callback):
    def on_step_end(self, step, losses):
        if step % 1000 == 0:
            print(f"Step {step}: loss={losses['loss']:.4f}")


trainer.add_callback(MyCallback())
```
## Data Utilities

### create_random_buffer

Create a buffer with random data for testing.

```python
from worldflux.training.data import create_random_buffer

buffer = create_random_buffer(
    capacity=10_000,
    obs_shape=(4,),
    action_dim=2,
    num_episodes=100,
    episode_length=100,
    seed=42,
)
```
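A random buffer pairs with `train()` for a quick smoke test. Whether the `dreamer:ci` preset from the Quick Start accepts vector observations like `(4,)` is an assumption here, so substitute a preset that matches your data:

```python
from worldflux import create_world_model
from worldflux.training import train

# Assumption: the chosen preset supports obs_shape=(4,); adjust as needed.
model = create_world_model("dreamer:ci", obs_shape=(4,), action_dim=2)
trained = train(model, buffer, total_steps=1_000)
```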
## Complete Example

```python
from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer
from worldflux.training.callbacks import (
    ProgressCallback,
    LoggingCallback,
    CheckpointCallback,
)

# Create model
model = create_world_model(
    "dreamerv3:size12m",
    obs_shape=(3, 64, 64),
    action_dim=6,
    device="cuda",
)

# Load data
buffer = ReplayBuffer.load("trajectories.npz")

# Configure
config = TrainingConfig(
    total_steps=100_000,
    batch_size=16,
    sequence_length=50,
    learning_rate=3e-4,
    log_interval=1000,
)

# Setup trainer
trainer = Trainer(model, config)
trainer.add_callback(ProgressCallback())
trainer.add_callback(LoggingCallback(log_interval=100))
trainer.add_callback(CheckpointCallback(output_dir="./ckpt", save_interval=10_000))

# Train
trained_model = trainer.train(buffer)

# Save final model
trained_model.save_pretrained("./final_model")
```