Configuration Reference

Configuration classes for world models. All configs are Python dataclasses with built-in validation, JSON serialization, and size-preset factories.

from worldflux.core.config import WorldModelConfig, DreamerV3Config, TDMPC2Config
Note: TrainingConfig is documented separately in the Training API Reference.


WorldModelConfig

@dataclass
class WorldModelConfig

Base configuration for all world models. Subclasses (DreamerV3Config, TDMPC2Config) extend this with model-specific parameters.

Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| model_type | str | "base" | Identifier for the model type ("dreamer", "tdmpc2", etc.). |
| model_name | str | "unnamed" | Human-readable name or size preset name. |
| obs_shape | tuple[int, ...] | (3, 64, 64) | Shape of observations (e.g. (3, 64, 64) for images, (39,) for vectors). |
| action_dim | int | 6 | Dimension of the action space. |
| action_type | str | "continuous" | Type of actions: "continuous", "discrete", "token", "latent", "text", or "none". |
| observation_modalities | dict[str, dict[str, Any]] | {} | Multi-modal observation spec. Auto-populated from obs_shape if empty. |
| action_spec | dict[str, Any] | {} | Normalized action specification. Auto-populated from action_type and action_dim. |
| latent_type | LatentType | DETERMINISTIC | Type of latent space: DETERMINISTIC, GAUSSIAN, CATEGORICAL, VQ, or SIMNORM. |
| latent_dim | int | 256 | Dimension of the primary latent space. |
| deter_dim | int | 256 | Dimension of deterministic state (RSSM models). |
| stoch_dim | int | 32 | Dimension of stochastic state (RSSM models). |
| dynamics_type | DynamicsType | MLP | Dynamics model architecture: RSSM, MLP, TRANSFORMER, or SSM. |
| hidden_dim | int | 512 | Hidden dimension for MLPs and other layers. |
| learning_rate | float | 3e-4 | Default learning rate for training. |
| grad_clip | float | 100.0 | Gradient clipping threshold. |
| device | str | "cuda" | Target device ("cuda", "cpu", "auto"). |
| dtype | str | "float32" | Data type: "float32", "float16", or "bfloat16". |

Enums

LatentType

| Value | Description |
| --- | --- |
| DETERMINISTIC | Simple deterministic latent space. |
| GAUSSIAN | Gaussian latent space with mean and variance. |
| CATEGORICAL | Categorical latent space (used by DreamerV3). |
| VQ | Vector-quantized latent space. |
| SIMNORM | SimNorm latent space (used by TD-MPC2). |

DynamicsType

| Value | Description |
| --- | --- |
| RSSM | Recurrent State-Space Model (DreamerV3). |
| MLP | Simple MLP dynamics (TD-MPC2). |
| TRANSFORMER | Transformer-based dynamics. |
| SSM | State-Space Model dynamics. |

Methods

to_dict / from_dict

def to_dict(self) -> dict[str, Any]

@classmethod
def from_dict(cls, d: dict[str, Any]) -> WorldModelConfig

Convert to/from dictionary for serialization.

save / load

def save(self, path: str | Path) -> None

@classmethod
def load(cls, path: str | Path) -> WorldModelConfig

Save/load configuration to/from JSON file.

Example

# Channel-first image shape, matching the documented (3, 64, 64) default layout
config = WorldModelConfig(obs_shape=(3, 84, 84), action_dim=4)
config.save("config.json")
loaded = WorldModelConfig.load("config.json")
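JSON has no tuple type, so a value such as obs_shape comes back from a JSON file as a list unless the loader converts it. The sketch below shows one common way to structure the to_dict / from_dict / save / load pattern. MiniConfig is a hypothetical stand-in written for this illustration, not the worldflux class, and the library's own implementation may differ.

```python
from __future__ import annotations

import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any


@dataclass
class MiniConfig:
    """Hypothetical stand-in for a config dataclass (not part of worldflux)."""

    obs_shape: tuple[int, ...] = (3, 64, 64)
    action_dim: int = 6

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> MiniConfig:
        d = dict(d)  # copy so the caller's dict is not mutated
        # JSON stores tuples as lists, so restore the tuple here.
        d["obs_shape"] = tuple(d["obs_shape"])
        return cls(**d)

    def save(self, path: str | Path) -> None:
        Path(path).write_text(json.dumps(self.to_dict(), indent=2))

    @classmethod
    def load(cls, path: str | Path) -> MiniConfig:
        return cls.from_dict(json.loads(Path(path).read_text()))


# Round trip through a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    original = MiniConfig(obs_shape=(3, 84, 84), action_dim=4)
    original.save(path)
    loaded = MiniConfig.load(path)
```

Comparing loaded with original after the round trip is a cheap check that no field changed type on the way through JSON.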

DreamerV3Config

@dataclass
class DreamerV3Config(WorldModelConfig)

DreamerV3 world model configuration. Uses an RSSM (Recurrent State-Space Model) with categorical latent variables.

Size Presets

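| Preset | Params | deter_dim | stoch | hidden_dim | cnn_depth |
| --- | --- | --- | --- | --- | --- |
| ci | ~0.1M | 64 | 4x4 | 32 | 8 |
| size12m | ~12M | 2048 | 16x16 | 256 | 48 |
| size25m | ~25M | 4096 | 32x16 | 512 | 48 |
| size50m | ~50M | 4096 | 32x32 | 640 | 48 |
| size100m | ~100M | 8192 | 32x32 | 768 | 48 |
| size200m | ~200M | 8192 | 32x32 | 1024 | 48 |
| official_xl | ~200-300M | 8192 | 32x64 | 1024 | 64 |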

The ci preset is intended for quick validation and scaffold workflows; it is not the canonical proof-grade parity preset. DreamerV3 parity proofs use official_xl.
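To get a feel for how the presets scale, it can help to compute the size of the combined RSSM state. The sketch below assumes the stoch column reads as stoch_discrete x stoch_classes and that the model state is the deterministic vector concatenated with the flattened categorical latents, as is typical for DreamerV3-style models. Both points are assumptions, and feature_dim is a hypothetical helper, not a worldflux function.

```python
# Values copied from the preset table: (deter_dim, stoch_discrete, stoch_classes).
# Reading "AxB" as stoch_discrete x stoch_classes is an assumption.
PRESETS = {
    "ci": (64, 4, 4),
    "size12m": (2048, 16, 16),
    "size25m": (4096, 32, 16),
    "size50m": (4096, 32, 32),
    "size100m": (8192, 32, 32),
    "size200m": (8192, 32, 32),
    "official_xl": (8192, 32, 64),
}


def feature_dim(deter_dim: int, stoch_discrete: int, stoch_classes: int) -> int:
    """Size of the deterministic state concatenated with flattened one-hot latents."""
    return deter_dim + stoch_discrete * stoch_classes


sizes = {name: feature_dim(*values) for name, values in PRESETS.items()}

for name, size in sizes.items():
    print(f"{name:12s} feature_dim={size}")
```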

DreamerV3-Specific Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| stoch_discrete | int | 32 | Number of categorical distributions. |
| stoch_classes | int | 32 | Number of classes per categorical distribution. |
| encoder_type | str | "cnn" | Type of encoder: "cnn" or "mlp". |
| decoder_type | str | "cnn" | Type of decoder: "cnn" or "mlp". |
| cnn_depth | int | 48 | Base depth multiplier for CNN encoder/decoder. |
| cnn_kernels | tuple[int, ...] | (4, 4, 4, 4) | Kernel sizes for CNN layers. |
| learning_rate | float | 1e-4 | DreamerV3 paper uses 1e-4 for world model. |
| grad_clip | float | 1000.0 | DreamerV3 paper uses 1000 for grad clip. |
| kl_free | float | 1.0 | Free nats for KL divergence (prevents posterior collapse). |
| loss_scales | dict[str, float] | see below | Weights for each loss component. |
| use_symlog | bool | True | Whether to use symlog transformation for predictions. |
| use_twohot | bool | True | Whether to use twohot categorical reward prediction. |
| reward_num_bins | int | 255 | Number of bins for twohot reward prediction. |
| reward_bin_min | float | -20.0 | Minimum value for reward bins (symlog space). |
| reward_bin_max | float | 20.0 | Maximum value for reward bins (symlog space). |
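The use_symlog, use_twohot, and reward_* fields fit together as follows: a reward is squashed with symlog, then encoded as weights on the two nearest bins. The sketch below is a plain-Python illustration that assumes bins spaced uniformly in symlog space between reward_bin_min and reward_bin_max. The helper names (symlog, symexp, twohot, decode) are chosen for this example and are not taken from the library.

```python
from __future__ import annotations

import math


def symlog(x: float) -> float:
    """sign(x) * ln(1 + |x|): compresses large magnitudes, keeps the sign."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp(y: float) -> float:
    """Inverse of symlog: sign(y) * (exp(|y|) - 1)."""
    return math.copysign(math.expm1(abs(y)), y)


def twohot(value: float, num_bins: int = 255,
           lo: float = -20.0, hi: float = 20.0) -> list[float]:
    """Spread a symlog-space value over its two neighbouring bins."""
    value = min(max(value, lo), hi)           # clamp into the bin range
    step = (hi - lo) / (num_bins - 1)         # spacing between bin centres
    pos = (value - lo) / step                 # fractional bin index
    lower = min(int(math.floor(pos)), num_bins - 2)
    frac = pos - lower
    weights = [0.0] * num_bins
    weights[lower] = 1.0 - frac
    weights[lower + 1] = frac
    return weights


def decode(weights: list[float], lo: float = -20.0, hi: float = 20.0) -> float:
    """Expected bin value in symlog space, mapped back with symexp."""
    step = (hi - lo) / (len(weights) - 1)
    expected = sum(w * (lo + i * step) for i, w in enumerate(weights))
    return symexp(expected)


reward = 7.5
target = twohot(symlog(reward))
recovered = decode(target)
```

Because the two weights interpolate linearly between neighbouring bin centres, decoding the target reproduces the original reward up to floating-point error.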

Actor-Critic Fields (gated by actor_critic=True):

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| actor_critic | bool | False | Enable actor-critic training. |
| imagination_horizon | int | 15 | Imagination rollout horizon. |
| actor_lr | float | 3e-5 | Actor learning rate. |
| critic_lr | float | 3e-5 | Critic learning rate. |
| gamma | float | 0.997 | Discount factor. |
| lambda_ | float | 0.95 | GAE lambda. |
| slow_critic_fraction | float | 0.02 | EMA fraction for slow critic target. |
| actor_entropy_coef | float | 3e-4 | Entropy regularization coefficient. |
| return_normalization | bool | True | Whether to normalize returns. |
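The roles of gamma, lambda_, and slow_critic_fraction can be illustrated with the lambda-return recursion commonly used in Dreamer-style imagination training and a simple exponential moving average for the slow critic. This is a hedged sketch: lambda_returns and ema_update are hypothetical helpers, continuation flags are ignored for brevity, and the library may compute these quantities differently.

```python
from __future__ import annotations


def lambda_returns(rewards: list[float], values: list[float], bootstrap: float,
                   gamma: float = 0.997, lam: float = 0.95) -> list[float]:
    """R_t = r_t + gamma * ((1 - lam) * v_{t+1} + lam * R_{t+1}).

    rewards[t] and values[t] cover steps 0..H-1; bootstrap is the value
    estimate for the state after the final step.
    """
    next_values = values[1:] + [bootstrap]
    returns = [0.0] * len(rewards)
    running = bootstrap
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * ((1.0 - lam) * next_values[t] + lam * running)
        returns[t] = running
    return returns


def ema_update(slow: list[float], online: list[float],
               fraction: float = 0.02) -> list[float]:
    """Move slow-critic parameters a small step toward the online critic."""
    return [(1.0 - fraction) * s + fraction * o for s, o in zip(slow, online)]


# A 15-step imagined rollout with constant reward and a flat value estimate.
horizon = 15
returns = lambda_returns([1.0] * horizon, [10.0] * horizon, bootstrap=10.0)
```

With lam=1.0 the recursion reduces to a discounted Monte Carlo return plus a bootstrapped tail, and with lam=0.0 it reduces to a one-step TD target.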

Default loss_scales

{
    "reconstruction": 1.0,
    "kl_dynamics": 0.5,  # beta_dyn
    "kl_representation": 0.1,  # beta_rep
    "reward": 1.0,
    "continue": 1.0,
}
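One plausible way these weights combine with kl_free is a weighted sum in which each KL term is clipped from below at the free-nats threshold before scaling. The sketch below uses plain floats and a hypothetical total_loss helper. It is an illustration of the idea under that assumption, not the library's loss code.

```python
from __future__ import annotations

LOSS_SCALES = {
    "reconstruction": 1.0,
    "kl_dynamics": 0.5,        # beta_dyn
    "kl_representation": 0.1,  # beta_rep
    "reward": 1.0,
    "continue": 1.0,
}


def total_loss(losses: dict[str, float],
               scales: dict[str, float] = LOSS_SCALES,
               kl_free: float = 1.0) -> float:
    """Weighted sum of loss terms, with free nats applied to the KL terms."""
    total = 0.0
    for name, scale in scales.items():
        value = losses[name]
        if name.startswith("kl_"):
            # Below kl_free the KL term is held constant, so it contributes
            # no gradient pressure (assumed behaviour for this sketch).
            value = max(value, kl_free)
        total += scale * value
    return total


example = {
    "reconstruction": 2.0,
    "kl_dynamics": 0.4,        # below 1.0, so clipped up to 1.0
    "kl_representation": 3.0,  # above 1.0, so used as-is
    "reward": 0.5,
    "continue": 0.1,
}
combined = total_loss(example)
```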

Methods

from_size

@classmethod
def from_size(cls, size: str, **kwargs: Any) -> DreamerV3Config

Create configuration from a size preset.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| size | str | required | Size preset name: "ci", "size12m", "size25m", "size50m", "size100m", "size200m", "official_xl". |
| **kwargs | Any | | Override any preset parameters. |

Returns: DreamerV3Config with the specified size preset.

Raises: ValueError if the size preset is not recognized.
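The behaviour described here (look up a preset, let keyword arguments override it, reject unknown names) can be sketched with a small stand-in class. MiniDreamerConfig and _PRESETS are hypothetical names for this illustration, and only a subset of the presets and fields is included. The real from_size may be implemented differently.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

# Subset of the preset table, for illustration only.
_PRESETS = {
    "ci": {"deter_dim": 64, "hidden_dim": 32, "cnn_depth": 8},
    "size12m": {"deter_dim": 2048, "hidden_dim": 256, "cnn_depth": 48},
    "size50m": {"deter_dim": 4096, "hidden_dim": 640, "cnn_depth": 48},
}


@dataclass
class MiniDreamerConfig:
    """Hypothetical stand-in, not the worldflux DreamerV3Config."""

    action_dim: int = 6
    deter_dim: int = 256
    hidden_dim: int = 512
    cnn_depth: int = 48

    @classmethod
    def from_size(cls, size: str, **kwargs: Any) -> MiniDreamerConfig:
        if size not in _PRESETS:
            raise ValueError(
                f"Unknown size preset {size!r}; expected one of {sorted(_PRESETS)}"
            )
        # Later keys win, so explicit kwargs override preset values.
        return cls(**{**_PRESETS[size], **kwargs})


cfg = MiniDreamerConfig.from_size("size50m", action_dim=18, cnn_depth=64)
```

Merging the preset dictionary first and the keyword arguments second keeps the override rule easy to reason about: anything the caller passes explicitly takes precedence.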

Examples

# Create from size preset
config = DreamerV3Config.from_size("size12m")

# Create with custom parameters
config = DreamerV3Config(
    obs_shape=(3, 64, 64),
    action_dim=4,
    deter_dim=1024,
    stoch_discrete=16,
    stoch_classes=16,
)

# Override preset values
config = DreamerV3Config.from_size("size50m", action_dim=18, cnn_depth=64)

TDMPC2Config

@dataclass
class TDMPC2Config(WorldModelConfig)

TD-MPC2 world model configuration. TD-MPC2 is an implicit world model that uses SimNorm latent space and learns value functions for planning. It is particularly effective for continuous control tasks.
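SimNorm (simplicial normalization) splits the latent vector into fixed-size groups and applies a softmax within each group, so every group lies on a probability simplex. The plain-Python sketch below illustrates that operation. simnorm is a hypothetical helper for this page, and the group size of 8 mirrors the simnorm_dim default of this config.

```python
from __future__ import annotations

import math


def simnorm(latent: list[float], simnorm_dim: int = 8) -> list[float]:
    """Apply a softmax independently to each group of simnorm_dim entries."""
    if len(latent) % simnorm_dim != 0:
        raise ValueError("latent length must be divisible by simnorm_dim")
    out: list[float] = []
    for start in range(0, len(latent), simnorm_dim):
        group = latent[start:start + simnorm_dim]
        peak = max(group)  # subtract the max for numerical stability
        exps = [math.exp(v - peak) for v in group]
        total = sum(exps)
        out.extend(e / total for e in exps)
    return out


z = simnorm([float(i) for i in range(16)])
group_sums = [sum(z[0:8]), sum(z[8:16])]
```

This grouping is why the latent dimension has to be a multiple of the group size: a leftover partial group would have no well-defined simplex.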

Size Presets

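| Preset | Params | latent_dim | hidden_dim |
| --- | --- | --- | --- |
| ci | ~0.1M | 32 | 32 |
| 5m | ~5M | 256 | 256 |
| proof_5m | ~5M | 256 | 256 |
| 5m_legacy | ~5M | 256 | 256 |
| 19m | ~19M | 512 | 512 |
| 48m | ~48M | 512 | 1024 |
| 317m | ~317M | 1024 | 2048 |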

The ci preset is intended for quick validation and scaffold workflows; it is not the canonical proof-grade parity preset. TD-MPC2 parity proofs use proof_5m. 5m is the compatibility preset, and 5m_legacy is the legacy compatibility preset.

TD-MPC2-Specific Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| simnorm_dim | int | 8 | Dimension for SimNorm grouping. latent_dim must be divisible by this value. |
| num_hidden_layers | int | 2 | Number of hidden layers in MLPs. |
| task_dim | int | 96 | Dimension of task embedding for multi-task learning. |
| num_tasks | int | 1 | Number of tasks for multi-task learning. |
| num_q_networks | int | 5 | Number of Q-networks in the ensemble. |
| horizon | int | 5 | Planning horizon for MPC. |
| num_samples | int | 512 | Number of action samples for CEM planning. |
| num_elites | int | 64 | Number of elite samples for CEM planning. Must not exceed num_samples. |
| temperature | float | 0.5 | Temperature for action sampling. Must be positive. |
| momentum | float | 0.1 | Momentum for CEM mean update. Must be in [0, 1]. |
| gamma | float | 0.99 | Discount factor. Must be in (0, 1]. |
| use_decoder | bool | False | Whether to use a decoder. TD-MPC2 is typically implicit (no decoder). |
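The constraints listed in the table map naturally onto dataclass validation in __post_init__. The sketch below shows how such checks might look. MiniTDMPC2Config is a hypothetical stand-in covering only the constrained fields, and the error messages are invented for this example, so the library's actual checks and wording may differ.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class MiniTDMPC2Config:
    """Hypothetical stand-in, not the worldflux TDMPC2Config."""

    latent_dim: int = 256
    simnorm_dim: int = 8
    num_samples: int = 512
    num_elites: int = 64
    temperature: float = 0.5
    momentum: float = 0.1
    gamma: float = 0.99

    def __post_init__(self) -> None:
        # Each check mirrors one constraint from the field table.
        if self.latent_dim % self.simnorm_dim != 0:
            raise ValueError("latent_dim must be divisible by simnorm_dim")
        if self.num_elites > self.num_samples:
            raise ValueError("num_elites must not exceed num_samples")
        if self.temperature <= 0:
            raise ValueError("temperature must be positive")
        if not 0.0 <= self.momentum <= 1.0:
            raise ValueError("momentum must be in [0, 1]")
        if not 0.0 < self.gamma <= 1.0:
            raise ValueError("gamma must be in (0, 1]")


def is_valid(**overrides) -> bool:
    """Return True if the overrides produce a config that passes validation."""
    try:
        MiniTDMPC2Config(**overrides)
        return True
    except ValueError:
        return False


checks = {
    "defaults": is_valid(),
    "latent_dim=100": is_valid(latent_dim=100),    # 100 % 8 != 0
    "num_elites=1024": is_valid(num_elites=1024),  # exceeds num_samples
    "gamma=0.0": is_valid(gamma=0.0),              # outside (0, 1]
}
```

Validating at construction time means an invalid combination fails immediately rather than partway through planning or training.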

Methods

from_size

@classmethod
def from_size(cls, size: str, **kwargs: Any) -> TDMPC2Config

Create configuration from a size preset.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| size | str | required | Size preset name: "ci", "5m", "proof_5m", "5m_legacy", "19m", "48m", "317m". |
| **kwargs | Any | | Override any preset parameters. |

Returns: TDMPC2Config with the specified size preset.

Raises: ValueError if the size preset is not recognized.

Examples

# Create from size preset
config = TDMPC2Config.from_size("5m", obs_shape=(39,), action_dim=6)

# Create with custom parameters
config = TDMPC2Config(
    obs_shape=(39,),
    action_dim=6,
    latent_dim=256,
    hidden_dim=256,
)

# Large model with custom Q ensemble
config = TDMPC2Config.from_size("317m", num_q_networks=10, horizon=8)