Training API

Complete training infrastructure for world models.

Trainer supports two execution modes:

  • local: native torch models returned by create_world_model(..., backend="native_torch")
  • delegated: OfficialBackendHandle values returned by non-native backend requests

Distributed training guidance for the stable public surface:

  • DDPTrainer is the current multi-GPU trainer entry point (a usage sketch follows this list)
  • planned FSDP support is not part of the stable worldflux.training API in v0.1.x
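
A minimal multi-GPU sketch, assuming DDPTrainer is exported from worldflux.training and mirrors Trainer's (model, config) constructor and train(buffer) method; verify both against the autogenerated reference before use. Such a script is typically launched with torchrun, one process per GPU:

from worldflux import create_world_model
from worldflux.training import DDPTrainer, ReplayBuffer, TrainingConfig

# Assumption: DDPTrainer mirrors Trainer's interface.
# Launch via: torchrun --nproc-per-node=N train_ddp.py
model = create_world_model("dreamerv3:size12m", obs_shape=(3, 64, 64), action_dim=6)
buffer = ReplayBuffer.load("trajectories.npz")

trainer = DDPTrainer(model, TrainingConfig(total_steps=100_000, device="cuda"))
trained_model = trainer.train(buffer)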

For docstring-derived details, see Training API (Autogenerated).

Quick Start

from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer

model = create_world_model("dreamer:ci", obs_shape=(3, 64, 64), action_dim=2)
buffer = ReplayBuffer.load("data.npz")

trained_model = train(model, buffer, total_steps=50_000)

Delegated training uses Trainer.submit() instead of train():

from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig

handle = create_world_model(
    "dreamerv3:official_xl",
    backend="official_dreamerv3_jax_subprocess",
    device="cuda",
)
trainer = Trainer(
    handle,
    TrainingConfig(
        backend="official_dreamerv3_jax_subprocess",
        backend_profile="official_xl",
        device="cuda",
    ),
)
job = trainer.submit()

train

One-liner training function.

from worldflux.training import train

trained_model = train(
    model,
    buffer,
    total_steps=50_000,
    batch_size=16,
    learning_rate=3e-4,
)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model | WorldModel | required | Model to train |
| buffer | ReplayBuffer or provider | required | Training data |
| total_steps | int | 100_000 | Training steps |
| batch_size | int | 16 | Batch size |
| sequence_length | int | 50 | Sequence length |
| learning_rate | float | 3e-4 | Learning rate |

Returns

Trained model (nn.Module)


TrainingConfig

Configuration for training.

from worldflux.training import TrainingConfig

config = TrainingConfig(
    # Duration
    total_steps=100_000,

    # Batch settings
    batch_size=16,
    sequence_length=50,

    # Optimizer
    learning_rate=3e-4,
    weight_decay=0.0,
    grad_clip=100.0,

    # Logging
    log_interval=1000,
)

All Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| total_steps | int | 100_000 | Total training steps |
| batch_size | int | 16 | Batch size |
| sequence_length | int | 50 | Sequence length for BPTT |
| learning_rate | float | 3e-4 | Adam learning rate |
| weight_decay | float | 0.0 | L2 regularization |
| grad_clip | float | 100.0 | Gradient clipping norm |
| warmup_steps | int | 0 | Learning rate warmup steps |
| log_interval | int | 100 | Steps between logging |
| eval_interval | int | 1000 | Steps between evaluation |
| save_interval | int | 10000 | Steps between checkpoints |
| device | str | "auto" | Training device |
| seed | int | 42 | Random seed |

Trainer

Full training control with callbacks.

from worldflux.training import Trainer, TrainingConfig

config = TrainingConfig(total_steps=100_000)
trainer = Trainer(model, config)

trained_model = trainer.train(buffer)

When model is an OfficialBackendHandle, Trainer switches to delegated mode. In delegated mode, local loop methods such as train() are unavailable; use submit(), status(), logs(), and cancel() instead.

Provider protocols (a custom-provider sketch follows this list):

  • Legacy: sample(batch_size, seq_len=None, device="cpu") -> Batch
  • V2: sample(BatchRequest(batch_size, seq_len, device)) -> Batch
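
A minimal V2 provider sketch. The import paths for Batch and BatchRequest, and the Batch constructor fields (taken from the sampling section under ReplayBuffer below), are assumptions; check the autogenerated API reference for the actual definitions:

import torch

# NOTE: these import paths are assumptions, not a confirmed public surface.
from worldflux.training import Batch, BatchRequest

class RandomProvider:
    """Toy V2 provider: sample(BatchRequest) -> Batch."""

    def __init__(self, obs_shape=(4,), action_dim=2):
        self.obs_shape = obs_shape
        self.action_dim = action_dim

    def sample(self, request: BatchRequest) -> Batch:
        b, t = request.batch_size, request.seq_len
        # Field names mirror the batch attributes documented below.
        return Batch(
            obs=torch.randn(b, t, *self.obs_shape, device=request.device),
            actions=torch.randn(b, t, self.action_dim, device=request.device),
            rewards=torch.zeros(b, t, device=request.device),
            terminations=torch.zeros(b, t, device=request.device),
        )

A provider like this can be passed anywhere a ReplayBuffer is accepted, e.g. train(model, RandomProvider()).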

Methods

train

trained_model = trainer.train(buffer)

submit / status / logs / cancel

job = trainer.submit()
status = trainer.status(job)
lines = list(trainer.logs(job))
trainer.cancel(job)
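
A minimal polling loop for a delegated job. The status object's state field and its terminal values are hypothetical placeholders; inspect the real object returned by status() in your environment:

import time

job = trainer.submit()
while True:
    status = trainer.status(job)
    # "state" and these terminal values are assumptions, not confirmed API.
    if getattr(status, "state", None) in ("succeeded", "failed", "cancelled"):
        break
    time.sleep(30)

for line in trainer.logs(job):  # stream the backend's log lines
    print(line)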

add_callback

from worldflux.training.callbacks import CheckpointCallback

trainer.add_callback(CheckpointCallback(output_dir="./checkpoints"))

ReplayBuffer

Storage for trajectory data.

Creation

from worldflux.training import ReplayBuffer

buffer = ReplayBuffer(
    capacity=100_000,
    obs_shape=(4,),
    action_dim=2,
)

Adding Data

import numpy as np

buffer.add_episode(
    obs=np.array([...]),      # [T, *obs_shape]
    actions=np.array([...]),  # [T, action_dim]
    rewards=np.array([...]),  # [T]
    dones=np.array([...]),    # [T]
)
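
A concrete version of the placeholder call above, with synthetic data for a buffer created with obs_shape=(4,) and action_dim=2:

T = 100                                  # episode length
dones = np.zeros(T, dtype=np.float32)
dones[-1] = 1.0                          # mark the final transition as terminal

buffer.add_episode(
    obs=np.random.randn(T, 4).astype(np.float32),      # [T, *obs_shape]
    actions=np.random.randn(T, 2).astype(np.float32),  # [T, action_dim]
    rewards=np.zeros(T, dtype=np.float32),             # [T]
    dones=dones,                                       # [T]
)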

Sampling

batch = buffer.sample(
    batch_size=16,
    seq_len=50,
    device="cuda",
)

# batch.obs - [B, T, *obs_shape]
# batch.actions - [B, T, action_dim]
# batch.rewards - [B, T]
# batch.terminations - [B, T]

Save/Load

buffer.save("buffer.npz")
buffer = ReplayBuffer.load("buffer.npz")

Required .npz Schema

ReplayBuffer.load() expects a serialized buffer with these keys:

| Key | Shape | Notes |
|---|---|---|
| obs | [N, *obs_shape] | Observation rows |
| actions | [N, action_dim] | Action rows |
| rewards | [N] | Reward per transition |
| dones | [N] | Episode termination flags (0/1) |
| obs_shape | [len(obs_shape)] | Stored observation shape metadata |
| action_dim | scalar | Stored action-dimension metadata |
| capacity | scalar | Replay buffer capacity used when saving |

If your pipeline uses terminations, store the same array under dones when writing the .npz artifact for ReplayBuffer.load().

import numpy as np

# Example: valid ReplayBuffer.load() artifact
np.savez(
    "buffer.npz",
    obs=obs.astype(np.float32),             # [N, *obs_shape]
    actions=actions.astype(np.float32),     # [N, action_dim]
    rewards=rewards.astype(np.float32),     # [N]
    dones=terminations.astype(np.float32),  # [N]
    obs_shape=np.array(obs.shape[1:], dtype=np.int64),
    action_dim=np.array(actions.shape[-1], dtype=np.int64),
    capacity=np.array(max(len(obs), 1), dtype=np.int64),
)
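
A round-trip check of the artifact above, using only the documented load and sample calls:

from worldflux.training import ReplayBuffer

buffer = ReplayBuffer.load("buffer.npz")
batch = buffer.sample(batch_size=16, seq_len=50, device="cpu")
assert batch.obs.shape[:2] == (16, 50)  # [B, T, *obs_shape]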

Properties

| Property | Type | Description |
|---|---|---|
| num_episodes | int | Number of complete episodes stored |
| __len__ | int | Total number of transitions stored |
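
For example:

print(f"{buffer.num_episodes} episodes, {len(buffer)} transitions stored")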

Callbacks

ProgressCallback

Progress bar with metrics.

from worldflux.training.callbacks import ProgressCallback

trainer.add_callback(ProgressCallback())

LoggingCallback

TensorBoard/console logging.

from worldflux.training.callbacks import LoggingCallback

trainer.add_callback(LoggingCallback(
    log_interval=100,
    use_wandb=False,
))

CheckpointCallback

Periodic model saving.

from worldflux.training.callbacks import CheckpointCallback

trainer.add_callback(CheckpointCallback(
    output_dir="./checkpoints",
    save_interval=10_000,
    max_checkpoints=3,
))

EarlyStoppingCallback

Stop on plateau.

from worldflux.training.callbacks import EarlyStoppingCallback

trainer.add_callback(EarlyStoppingCallback(
    monitor="loss",
    patience=5000,
    min_delta=1e-4,
))

Custom Callbacks

from worldflux.training.callbacks import Callback

class MyCallback(Callback):
    def on_step_end(self, step, losses):
        if step % 1000 == 0:
            print(f"Step {step}: loss={losses['loss']:.4f}")

trainer.add_callback(MyCallback())

Data Utilities

create_random_buffer

Create a buffer filled with random data for testing.

from worldflux.training.data import create_random_buffer

buffer = create_random_buffer(
    capacity=10_000,
    obs_shape=(4,),
    action_dim=2,
    num_episodes=100,
    episode_length=100,
    seed=42,
)
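
A quick end-to-end smoke test combining create_random_buffer with the one-liner train. The dreamer:ci spec comes from the Quick Start above; the assumption here is that it also accepts a flat obs_shape=(4,):

from worldflux import create_world_model
from worldflux.training import train
from worldflux.training.data import create_random_buffer

model = create_world_model("dreamer:ci", obs_shape=(4,), action_dim=2)
buffer = create_random_buffer(
    capacity=10_000,
    obs_shape=(4,),
    action_dim=2,
    num_episodes=100,
    episode_length=100,
    seed=42,
)
trained_model = train(model, buffer, total_steps=1_000)  # short run for a sanity check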

Complete Example

from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer
from worldflux.training.callbacks import (
    ProgressCallback,
    LoggingCallback,
    CheckpointCallback,
)

# Create model
model = create_world_model(
    "dreamerv3:size12m",
    obs_shape=(3, 64, 64),
    action_dim=6,
    device="cuda",
)

# Load data
buffer = ReplayBuffer.load("trajectories.npz")

# Configure
config = TrainingConfig(
    total_steps=100_000,
    batch_size=16,
    sequence_length=50,
    learning_rate=3e-4,
    log_interval=1000,
)

# Setup trainer
trainer = Trainer(model, config)
trainer.add_callback(ProgressCallback())
trainer.add_callback(LoggingCallback(log_interval=100))
trainer.add_callback(CheckpointCallback(output_dir="./ckpt", save_interval=10000))

# Train
trained_model = trainer.train(buffer)

# Save final
trained_model.save_pretrained("./final_model")