
Configuration Guide

This page explains how to configure world models and training runs with the current WorldFlux implementation.

For full, code-generated parameter details, use:

Import Paths

Use one of these import styles:

# Top-level package
from worldflux import DreamerV3Config, TDMPC2Config
from worldflux.training import TrainingConfig

# Explicit module paths
from worldflux.core.config import DreamerV3Config, TDMPC2Config
from worldflux.training.config import TrainingConfig

Model Configs

Use size presets, then override only what you need:

from worldflux import DreamerV3Config, TDMPC2Config

dreamer_cfg = DreamerV3Config.from_size(
    "size50m",
    obs_shape=(3, 64, 64),
    action_dim=18,
)

tdmpc_cfg = TDMPC2Config.from_size(
    "19m",
    obs_shape=(39,),
    action_dim=6,
)
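The preset-then-override pattern can be sketched with a plain dataclass. `SketchConfig`, `PRESETS`, the preset names `"small"`/`"large"`, and the `from_size` helper below are hypothetical stand-ins, not WorldFlux's classes or preset values; the sketch only shows the ordering: environment-specific fields first, preset values next, explicit overrides last.

```python
from dataclasses import dataclass, replace


# Hypothetical, simplified stand-in for a model config; NOT worldflux's DreamerV3Config.
@dataclass(frozen=True)
class SketchConfig:
    obs_shape: tuple
    action_dim: int
    deter_dim: int = 512
    hidden_dim: int = 512


# Hypothetical preset table; the real preset names and values live in the library.
PRESETS = {
    "small": {"deter_dim": 512, "hidden_dim": 512},
    "large": {"deter_dim": 2048, "hidden_dim": 768},
}


def from_size(size, obs_shape, action_dim, **overrides):
    """Build a config from a named preset, then apply explicit overrides on top."""
    if size not in PRESETS:
        raise ValueError(f"unknown size preset {size!r}; choose from {sorted(PRESETS)}")
    base = SketchConfig(obs_shape=obs_shape, action_dim=action_dim, **PRESETS[size])
    # Overrides win over preset values because they are applied last.
    return replace(base, **overrides)


cfg = from_size("large", obs_shape=(3, 64, 64), action_dim=18, hidden_dim=1024)
print(cfg.deter_dim, cfg.hidden_dim)  # 2048 1024
```

Because `dataclasses.replace` rejects unknown field names with a `TypeError`, a misspelled override fails loudly instead of being silently ignored.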

Commonly Tuned DreamerV3 Fields

  • deter_dim
  • stoch_discrete
  • stoch_classes
  • hidden_dim
  • cnn_depth
  • kl_free
  • kl_balance
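Two of these fields interact with simple arithmetic. The sketch below assumes the layout described in the DreamerV3 paper: the stochastic state is `stoch_discrete` categorical variables with `stoch_classes` classes each, flattened and concatenated with the `deter_dim` deterministic state, and `kl_free` acts as a "free bits" floor on the summed KL. Whether WorldFlux uses exactly this layout, and how it applies `kl_balance`, is an assumption to check against the API reference; the helper functions are hypothetical.

```python
import math


def latent_feature_dim(deter_dim, stoch_discrete, stoch_classes):
    """Size of the concatenated [deterministic, flattened one-hot stochastic] feature."""
    return deter_dim + stoch_discrete * stoch_classes


def categorical_kl(p, q):
    """KL(p || q) in nats for one categorical distribution given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def free_bits_kl(posterior, prior, kl_free):
    """Sum KL over all categorical latents, then clip from below at kl_free nats.

    Below the threshold the clipped value is constant, so in an autodiff framework
    it contributes no gradient; this is the "free bits" idea.
    """
    total = sum(categorical_kl(p, q) for p, q in zip(posterior, prior))
    return max(total, kl_free)


print(latent_feature_dim(deter_dim=512, stoch_discrete=32, stoch_classes=32))  # 1536

# Two latents with 2 classes each; posterior and prior are nearly identical.
post = [[0.5, 0.5], [0.6, 0.4]]
prior = [[0.5, 0.5], [0.5, 0.5]]
print(free_bits_kl(post, prior, kl_free=1.0))  # 1.0 (the raw KL of ~0.02 is below the floor)
```

Raising `stoch_discrete` or `stoch_classes` therefore grows the feature fed to the heads multiplicatively, which is worth keeping in mind when tuning them together with `hidden_dim`.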

Commonly Tuned TD-MPC2 Fields

  • latent_dim
  • hidden_dim
  • num_hidden_layers
  • num_q_networks
  • horizon
  • num_samples
  • num_elites
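To see what `horizon`, `num_samples`, and `num_elites` control, here is a simplified cross-entropy method (CEM) planner. This is a swapped-in illustration, not WorldFlux's planner: the TD-MPC2 paper describes an MPPI-style planner that weights elites by their estimated returns and also uses policy samples. The toy return function stands in for latent rollouts plus Q-value estimates, and `cem_plan`/`toy_return` are hypothetical names.

```python
import random


def cem_plan(score_fn, horizon, action_dim, num_samples, num_elites,
             iterations=8, init_std=2.0, min_std=0.05, seed=0):
    """Cross-entropy-method planning over action sequences in [-1, 1].

    Returns the mean action sequence after the final iteration, shape [horizon][action_dim].
    """
    rng = random.Random(seed)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[init_std] * action_dim for _ in range(horizon)]
    for _ in range(iterations):
        # 1) Sample num_samples sequences of length `horizon`, clipped to the action bounds.
        candidates = [
            [[max(-1.0, min(1.0, rng.gauss(mean[t][d], std[t][d])))
              for d in range(action_dim)]
             for t in range(horizon)]
            for _ in range(num_samples)
        ]
        # 2) Keep the num_elites highest-scoring sequences.
        elites = sorted(candidates, key=score_fn, reverse=True)[:num_elites]
        # 3) Refit a per-step, per-dimension Gaussian to the elites.
        for t in range(horizon):
            for d in range(action_dim):
                vals = [seq[t][d] for seq in elites]
                m = sum(vals) / num_elites
                var = sum((v - m) ** 2 for v in vals) / num_elites
                mean[t][d] = m
                std[t][d] = max(var ** 0.5, min_std)
    return mean


# Toy "return": highest when every action equals 0.5.
def toy_return(seq):
    return -sum((a - 0.5) ** 2 for step in seq for a in step)


plan = cem_plan(toy_return, horizon=3, action_dim=2, num_samples=512, num_elites=64)
first_action = plan[0]  # in receding-horizon control, only the first action is executed
print([round(a, 2) for a in first_action])
```

Cost grows roughly with `horizon * num_samples` rollouts per iteration, while `num_elites` trades off how aggressively the sampling distribution narrows.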

Factory Overrides

You can pass config overrides directly to create_world_model(...):

from worldflux import create_world_model

model = create_world_model(
    "dreamerv3:size12m",
    obs_shape=(3, 64, 64),
    action_dim=4,
    hidden_dim=320,
    stoch_discrete=16,
    stoch_classes=16,
)

If a field name is not supported by the target config class, model creation fails with a configuration error.
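The validation step can be sketched as a set difference between the override keys and the config's declared fields. `ConfigError`, `SketchDreamerConfig`, and `apply_overrides` are hypothetical names for illustration; the library's real exception class and field set may differ.

```python
from dataclasses import dataclass, fields, replace


class ConfigError(ValueError):
    """Hypothetical error type for this sketch; the library's exception class may differ."""


@dataclass(frozen=True)
class SketchDreamerConfig:
    # Hypothetical subset of fields, for illustration only.
    obs_shape: tuple = (3, 64, 64)
    action_dim: int = 4
    hidden_dim: int = 256
    stoch_discrete: int = 32
    stoch_classes: int = 32


def apply_overrides(cfg, **overrides):
    """Return a copy of cfg with overrides applied, rejecting unknown field names."""
    known = {f.name for f in fields(cfg)}
    unknown = sorted(set(overrides) - known)
    if unknown:
        raise ConfigError(f"{type(cfg).__name__} has no field(s): {', '.join(unknown)}")
    return replace(cfg, **overrides)


cfg = apply_overrides(SketchDreamerConfig(), hidden_dim=320, stoch_discrete=16, stoch_classes=16)
print(cfg.hidden_dim)  # 320

try:
    apply_overrides(cfg, hiden_dim=320)  # typo: no such field
except ConfigError as err:
    print(err)  # SketchDreamerConfig has no field(s): hiden_dim
```

Failing fast on unknown names catches typos such as `hiden_dim` that would otherwise leave the default in place without any warning.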

Training Configuration

from worldflux.training import TrainingConfig

train_cfg = TrainingConfig(
    total_steps=100_000,
    batch_size=16,
    sequence_length=50,
    learning_rate=3e-4,
    grad_clip=100.0,
    device="auto",
    mixed_precision=False,
)
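A large value such as `grad_clip=100.0` is easiest to read as a ceiling on the gradient norm. The sketch assumes global-norm clipping, the behavior of PyTorch's `torch.nn.utils.clip_grad_norm_`; whether WorldFlux clips by global norm or per tensor is an assumption. `clip_by_global_norm` is a hypothetical helper working on plain Python lists.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one common factor so their joint L2 norm is at most max_norm.

    grads: list of gradient vectors (lists of floats), one per parameter tensor.
    Returns (clipped_grads, norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total_norm <= max_norm or total_norm == 0.0:
        return [list(vec) for vec in grads], total_norm
    scale = max_norm / total_norm
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[300.0, 0.0], [0.0, 400.0]]  # joint norm is 500
clipped, norm = clip_by_global_norm(grads, max_norm=100.0)
print(norm, clipped)  # joint norm 500.0, every entry scaled by 0.2 to [[60, 0], [0, 80]]
```

Scaling every tensor by the same factor preserves the update direction; only its length is capped, so a generous limit like 100.0 mostly guards against rare gradient spikes.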

Other useful fields beyond the basic loop settings:

  • weight_decay
  • warmup_steps
  • log_interval
  • eval_interval
  • save_interval
  • gradient_accumulation_steps
  • optimizer
  • scheduler
  • ema_decay
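Several of these fields follow common conventions that can be checked with quick arithmetic. The sketch assumes that gradient accumulation multiplies the sequences per optimizer step, that warmup is linear, and that `ema_decay` is the weight on the previous average; these are common conventions, not confirmed WorldFlux behavior. `optimizer` and `scheduler` are not modeled, and the helper names are hypothetical.

```python
def effective_batch(batch_size, sequence_length, gradient_accumulation_steps):
    """Sequences and individual timesteps contributing to one optimizer step."""
    sequences = batch_size * gradient_accumulation_steps
    return sequences, sequences * sequence_length


def warmup_lr(step, base_lr, warmup_steps):
    """Linear warmup from base_lr / warmup_steps up to base_lr, then constant."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return base_lr
    return base_lr * (step + 1) / warmup_steps


def ema_update(ema_params, params, ema_decay):
    """Exponential moving average: ema <- decay * ema + (1 - decay) * current."""
    return [ema_decay * e + (1.0 - ema_decay) * p for e, p in zip(ema_params, params)]


print(effective_batch(batch_size=16, sequence_length=50, gradient_accumulation_steps=4))  # (64, 3200)
print(warmup_lr(step=499, base_lr=3e-4, warmup_steps=1000))  # halfway through warmup
print(ema_update([0.0], [1.0], ema_decay=0.99))  # moves 1% of the way toward the new value
```

With `batch_size=16` and `sequence_length=50`, four accumulation steps mean each optimizer update sees 3,200 timesteps, which is often a more useful number to compare across runs than `batch_size` alone.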

Save and Load Configs

from worldflux import WorldModelConfig
from worldflux.training import TrainingConfig

# Model config from a saved model directory
model_cfg = WorldModelConfig.load("./my_model")

# Training config JSON round-trip
train_cfg = TrainingConfig(total_steps=10_000)
train_cfg.save("training_config.json")
loaded_train_cfg = TrainingConfig.load("training_config.json")
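A JSON round-trip for a flat config can be sketched with `dataclasses.asdict` and the standard `json` module. `SketchTrainingConfig` and its `save`/`load` methods are hypothetical and only mirror the shape of the calls above; they are not WorldFlux's implementation.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class SketchTrainingConfig:
    # Hypothetical subset of training fields, for illustration only.
    total_steps: int = 100_000
    batch_size: int = 16
    learning_rate: float = 3e-4
    device: str = "auto"

    def save(self, path):
        """Write all fields to a JSON file."""
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(asdict(self), fh, indent=2)

    @classmethod
    def load(cls, path):
        """Rebuild a config from JSON; unknown keys raise TypeError in the constructor."""
        with open(path, encoding="utf-8") as fh:
            return cls(**json.load(fh))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "training_config.json")
    cfg = SketchTrainingConfig(total_steps=10_000)
    cfg.save(path)
    loaded = SketchTrainingConfig.load(path)

print(loaded == cfg)  # True
```

One caveat of plain JSON: it has no tuple type, so a field such as `obs_shape=(3, 64, 64)` comes back as a list unless the loader converts it.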

Environment Notes

  • CUDA_VISIBLE_DEVICES: set this before starting training to choose which GPUs are visible to the process.
  • Device defaults differ by API:
      • create_world_model(...): device defaults to "cpu".
      • TrainingConfig(...): device defaults to "auto", which resolves to "cuda" if available, otherwise "cpu".
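The `"auto"` rule amounts to a small resolution function. In this sketch `resolve_device` is a hypothetical helper, and `cuda_available` is passed in explicitly so the logic is testable without a GPU; in a PyTorch program that flag would typically come from `torch.cuda.is_available()`.

```python
def resolve_device(requested, cuda_available):
    """Map a requested device string to a concrete one.

    "auto" picks "cuda" when a GPU is usable, else "cpu"; other values pass through unchanged.
    """
    if requested == "auto":
        return "cuda" if cuda_available else "cpu"
    return requested


print(resolve_device("auto", cuda_available=False))  # cpu
print(resolve_device("auto", cuda_available=True))   # cuda
print(resolve_device("cpu", cuda_available=True))    # cpu
```

Because CUDA reads `CUDA_VISIBLE_DEVICES` when the runtime initializes, set it in the shell that launches the job, for example `CUDA_VISIBLE_DEVICES=0 python train.py` (where `train.py` stands in for your own script), rather than changing it after the framework has already touched the GPU.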