Skip to main content

Training API Reference

Training utilities for WorldFlux world models, including the Trainer class, one-liner train function, replay buffer, and training callbacks.

Trainer has two modes:

  • local mode for native Torch models
  • delegated mode for OfficialBackendHandle jobs backed by external runtimes

Stable distributed entry points are intentionally narrow in v0.1.x:

  • DDPTrainer is the supported distributed trainer surface
  • planned FSDP support is not exported from worldflux.training
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer, train
from worldflux.training.callbacks import (
Callback, LoggingCallback, CheckpointCallback, EarlyStoppingCallback,
)

train

def train(
model: nn.Module,
data: BatchProvider | Any,
total_steps: int | None = None,
batch_size: int = 16,
sequence_length: int = 50,
learning_rate: float = 3e-4,
grad_clip: float = 100.0,
output_dir: str = "./outputs",
device: str = "auto",
**kwargs: Any,
) -> nn.Module

One-liner training function for quick experimentation.

Parameters

ParameterTypeDefaultDescription
modelnn.ModulerequiredWorld model to train (must implement loss(batch)).
dataBatchProvider | AnyrequiredBatchProvider or iterable yielding Batch/dict.
total_stepsint | NoneNoneNumber of training steps. Defaults to 100_000 if None.
batch_sizeint16Batch size.
sequence_lengthint50Sequence length for trajectory sampling.
learning_ratefloat3e-4Learning rate.
grad_clipfloat100.0Gradient clipping value.
output_dirstr"./outputs"Directory for outputs.
devicestr"auto"Device to train on ("cuda", "cpu", or "auto").
**kwargsAnyAdditional TrainingConfig options.

Returns

nn.Module -- The trained model.

Example

from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer

model = create_world_model("dreamerv3:size12m")
buffer = ReplayBuffer.load("data.npz")
trained_model = train(model, buffer, total_steps=50_000)

Trainer

class Trainer

HuggingFace-style trainer for WorldFlux. Provides automatic device placement, gradient clipping, checkpointing, logging (console and optional wandb), learning rate scheduling, gradient accumulation, and mixed-precision training.

Constructor

def __init__(
self,
model: WorldModel | nn.Module | OfficialBackendHandle,
config: TrainingConfig | None = None,
callbacks: list[Callback] | None = None,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
)
ParameterTypeDefaultDescription
modelWorldModel | nn.Module | OfficialBackendHandlerequiredNative model for local training, or delegated backend handle for external execution.
configTrainingConfig | NoneNoneTraining configuration. Uses defaults if None.
callbackslist[Callback] | NoneNoneAdditional callbacks for logging/checkpointing. LoggingCallback and CheckpointCallback are always included by default.
optimizerOptimizer | NoneNoneCustom optimizer. Created from config if None.
schedulerLRScheduler | NoneNoneCustom learning rate scheduler. Created from config if None.

When model is an OfficialBackendHandle, Trainer enters delegated mode and routes execution through submit(). Local loop operations such as train() remain available only for backend="native_torch".

Methods

train

def train(
self,
data: BatchProvider | Any,
num_steps: int | None = None,
resume_from: str | None = None,
) -> nn.Module

Train the model.

ParameterTypeDefaultDescription
dataBatchProvider | AnyrequiredBatchProvider or iterable yielding Batch/dict.
num_stepsint | NoneNoneNumber of steps to train. Uses config.total_steps if None.
resume_fromstr | NoneNonePath to checkpoint to resume from.

Returns: nn.Module -- The trained model.

submit / status / logs / cancel

def submit(self, *, resume_from: str | None = None) -> JobHandle
def status(self, handle: JobHandle) -> JobStatus
def logs(self, handle: JobHandle) -> Iterator[str]
def cancel(self, handle: JobHandle) -> None

Delegated execution helpers for non-native backends. submit() is only valid when the trainer was constructed from an OfficialBackendHandle.

evaluate

def evaluate(
self,
data: BatchProvider | Any,
num_batches: int = 10,
) -> dict[str, float]

Evaluate the model on data.

ParameterTypeDefaultDescription
dataBatchProvider | AnyrequiredData source for evaluation.
num_batchesint10Number of batches to evaluate.

Returns: dict[str, float] -- Dictionary of average metrics.

save_checkpoint / load_checkpoint

def save_checkpoint(self, path: str) -> None
def load_checkpoint(self, path: str) -> None

Save/load training checkpoint. Save uses atomic write pattern (write to temp file, validate, then rename) to prevent corrupted checkpoints. Load raises CheckpointError if file is missing or corrupted.

add_callback

def add_callback(self, callback: Callback) -> None

Register a callback after trainer construction.

runtime_profile

def runtime_profile(self) -> dict[str, float | None]

Return lightweight runtime profiling metrics for DX instrumentation. Keys: "ttfi_sec", "elapsed_sec", "throughput_steps_per_sec".

Example

from worldflux import create_world_model
from worldflux.training import Trainer, TrainingConfig, ReplayBuffer

model = create_world_model("dreamerv3:size12m")
buffer = ReplayBuffer.load("data.npz")
config = TrainingConfig(total_steps=50_000, batch_size=32)

trainer = Trainer(model, config)
trainer.train(buffer)

TrainingConfig

@dataclass
class TrainingConfig

Configuration for training world models.

Fields

FieldTypeDefaultDescription
total_stepsint100_000Total number of training steps.
batch_sizeint16Batch size for training.
sequence_lengthint50Sequence length for trajectory sampling.
instant_modeboolFalseEnable instant mode for fast testing.
instant_total_stepsint8Total steps when instant mode is active.
instant_batch_sizeint8Batch size when instant mode is active.
instant_sequence_lengthint10Sequence length when instant mode is active.
learning_ratefloat3e-4Learning rate for optimizer.
grad_clipfloat100.0Maximum gradient norm for clipping.
weight_decayfloat0.0Weight decay for optimizer.
warmup_stepsint0Number of warmup steps for learning rate scheduler.
log_intervalint100Interval (in steps) for logging metrics.
eval_intervalint1000Interval (in steps) for evaluation.
save_intervalint10000Interval (in steps) for saving checkpoints.
output_dirstr"./outputs"Directory for saving outputs (checkpoints, logs).
devicestr"auto"Device to train on ("cuda", "cpu", "auto").
seedint42Random seed for reproducibility.
mixed_precisionboolFalseWhether to use mixed precision training via torch.amp.
num_workersint0Number of workers for data loading.
prefetch_factorint2Number of batches to prefetch per worker.
optimizerstr"adamw"Optimizer type: "adamw", "adam", or "sgd".
schedulerstr"none"LR scheduler: "none", "linear", "cosine", or "constant".
gradient_accumulation_stepsint1Number of gradient accumulation steps. Effective batch size becomes batch_size * gradient_accumulation_steps.
auto_quality_checkboolTrueRun a smoke-level quality check after training completes.

Advanced/Internal placeholders

The following config knobs exist in the broader type surface, but they are not part of the supported MVP training path:

  • ema_decay: reserved placeholder, rejected by Trainer
  • model_overrides: reserved trainer-level patch surface, currently unsupported

Methods

save / load

def save(self, path: str | Path) -> None
def load(cls, path: str | Path) -> TrainingConfig

Save/load config to/from JSON file.

with_updates

def with_updates(self, **kwargs: Any) -> TrainingConfig

Return a new config with updated values.

resolve_device

def resolve_device(self) -> str

Resolve "auto" device to actual device ("cuda" or "cpu").

effective_total_steps / effective_batch_size / effective_sequence_length

def effective_total_steps(self) -> int
def effective_batch_size(self) -> int
def effective_sequence_length(self) -> int

Return the effective value under the current mode (normal or instant).

Example

config = TrainingConfig(total_steps=100_000, batch_size=32)
config.save("training_config.json")
loaded = TrainingConfig.load("training_config.json")

ReplayBuffer

class ReplayBuffer

Efficient trajectory storage for world model training. Stores episodes as contiguous arrays and supports efficient random sampling of trajectory segments.

warning

Not thread-safe. Concurrent calls to add_episode() from multiple threads may cause race conditions. Use a single writer thread or external synchronization.

Constructor

def __init__(
self,
capacity: int,
obs_shape: tuple[int, ...],
action_dim: int,
obs_dtype: type = np.float32,
)
ParameterTypeDefaultDescription
capacityintrequiredMaximum number of transitions to store.
obs_shapetuple[int, ...]requiredShape of observations (e.g. (3, 64, 64) for images).
action_dimintrequiredDimension of action space.
obs_dtypetypenp.float32NumPy dtype for observations.

Methods

add_episode

def add_episode(
self,
obs: np.ndarray,
actions: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray | None = None,
) -> None

Add a complete episode to the buffer.

ParameterTypeDefaultDescription
obsnp.ndarrayrequiredObservations of shape [episode_len, *obs_shape].
actionsnp.ndarrayrequiredActions of shape [episode_len, action_dim].
rewardsnp.ndarrayrequiredRewards of shape [episode_len].
donesnp.ndarray | NoneNoneDone flags of shape [episode_len]. If None, last step is marked done.

sample

def sample(
self,
batch_size: int,
seq_len: int,
device: str | torch.device = "cpu",
) -> Batch

Sample random trajectory segments.

ParameterTypeDefaultDescription
batch_sizeintrequiredNumber of trajectory segments.
seq_lenintrequiredLength of each trajectory segment.
devicestr | torch.device"cpu"Device to place tensors on.

Returns: Batch with keys obs [B, T, *obs_shape], actions [B, T, action_dim], rewards [B, T], terminations [B, T].

save / load

def save(self, path: str | Path) -> None

@classmethod
def load(cls, path: str | Path) -> ReplayBuffer

Save/load buffer to/from .npz format with schema validation.

to_parquet / from_parquet

def to_parquet(self, path: str | Path, *, compression: str = "zstd") -> None

@classmethod
def from_parquet(cls, path: str | Path) -> ReplayBuffer

Save/load buffer to/from Parquet format for cloud-native analytics pipelines. Requires optional pyarrow dependency.

from_trajectories

@classmethod
def from_trajectories(
cls,
trajectories: list[dict[str, np.ndarray]],
capacity: int | None = None,
) -> ReplayBuffer

Create buffer from a list of trajectory dictionaries with keys "obs", "actions", "rewards", "dones".

Properties

PropertyTypeDescription
num_episodesintNumber of complete episodes stored.

Example

buffer = ReplayBuffer(capacity=100_000, obs_shape=(3, 64, 64), action_dim=6)
buffer.add_episode(obs, actions, rewards, dones)
batch = buffer.sample(batch_size=32, seq_len=50)

Callbacks

Callback (Base Class)

class Callback(ABC)

Base class for training callbacks. Override any of the following hooks:

HookWhen Called
on_train_begin(trainer)At the start of training.
on_train_end(trainer)At the end of training.
on_epoch_begin(trainer)At the start of each epoch.
on_epoch_end(trainer)At the end of each epoch.
on_step_begin(trainer)Before each training step.
on_step_end(trainer)After each training step.

LoggingCallback

class LoggingCallback(Callback)

Callback for logging training metrics to console and optionally to wandb.

def __init__(
self,
log_interval: int = 100,
use_wandb: bool = False,
wandb_project: str | None = None,
wandb_run_name: str | None = None,
)
ParameterTypeDefaultDescription
log_intervalint100Steps between log outputs.
use_wandbboolFalseWhether to log to wandb.
wandb_projectstr | NoneNonewandb project name.
wandb_run_namestr | NoneNonewandb run name.

CheckpointCallback

class CheckpointCallback(Callback)

Callback for saving model checkpoints at regular intervals.

def __init__(
self,
save_interval: int = 10000,
output_dir: str = "./outputs",
save_best: bool = True,
max_checkpoints: int | None = 5,
)
ParameterTypeDefaultDescription
save_intervalint10000Steps between checkpoint saves.
output_dirstr"./outputs"Directory to save checkpoints.
save_bestboolTrueWhether to save the best model (lowest loss).
max_checkpointsint | None5Maximum number of checkpoints to keep. None for unlimited.

EarlyStoppingCallback

class EarlyStoppingCallback(Callback)

Callback for early stopping based on loss plateau detection.

def __init__(
self,
patience: int = 5000,
min_delta: float = 1e-4,
monitor: str = "loss",
)
ParameterTypeDefaultDescription
patienceint5000Number of steps to wait without improvement before stopping.
min_deltafloat1e-4Minimum improvement required to reset patience counter.
monitorstr"loss"Metric name to monitor.