Config API (Autogenerated)
This page is generated from Python docstrings via mkdocstrings.
WorldModelConfig
Base configuration for all world models.
This class defines the common configuration parameters shared across all world model implementations. Subclasses (DreamerV3Config, TDMPC2Config) extend this with model-specific parameters.
Attributes:
| Name | Type | Description |
|---|---|---|
| model_type | str | Identifier for the model type ("dreamer", "tdmpc2", etc.). |
| model_name | str | Human-readable name or size preset name. |
| obs_shape | tuple[int, ...] | Shape of observations (e.g., (3, 64, 64) for images). |
| action_dim | int | Dimension of the action space. |
| action_type | str | Type of actions ("continuous" or "discrete"). |
| latent_type | LatentType | Type of latent space representation. |
| latent_dim | int | Dimension of the primary latent space. |
| deter_dim | int | Dimension of deterministic state (RSSM models). |
| stoch_dim | int | Dimension of stochastic state (RSSM models). |
| dynamics_type | DynamicsType | Type of dynamics model architecture. |
| hidden_dim | int | Hidden dimension for MLPs and other layers. |
| learning_rate | float | Default learning rate for training. |
| grad_clip | float | Gradient clipping threshold. |
| device | str | Target device ("cuda", "cpu", "auto"). |
| dtype | str | Data type ("float32", "float16", "bfloat16"). |
Example
config = WorldModelConfig(obs_shape=(84, 84, 3), action_dim=4)
config.save("config.json")
loaded = WorldModelConfig.load("config.json")
to_dict()
Convert configuration to dictionary for serialization.
from_dict(d) classmethod
Create configuration from dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| d | dict[str, Any] | Dictionary with configuration parameters. | required |
Returns:
| Type | Description |
|---|---|
| WorldModelConfig | Configuration instance. |
Raises:
| Type | Description |
|---|---|
| ConfigurationError | If the configuration is invalid. |
save(path)
Save configuration to JSON file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Path to save the configuration. | required |
load(path) classmethod
Load configuration from JSON file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Path to the configuration file. | required |
Returns:
| Type | Description |
|---|---|
| WorldModelConfig | Configuration instance (DreamerV3Config, TDMPC2Config, or base). |
Raises:
| Type | Description |
|---|---|
| FileNotFoundError | If the file doesn't exist. |
| JSONDecodeError | If the file is not valid JSON. |
| ConfigurationError | If the configuration is invalid. |
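The subclass dispatch described for load() can be pictured with a small stand-in. This sketch does not use the library's classes: `BaseCfg`, `DreamerCfg`, `TdmpcCfg`, and the `_REGISTRY` lookup keyed on `model_type` are hypothetical names, and the real implementation may select the subclass differently and may validate the data before constructing it.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class BaseCfg:
    """Hypothetical stand-in for a base config; not the library class."""
    model_type: str = "base"
    action_dim: int = 1

    def save(self, path):
        # Serialize all fields to a JSON file.
        Path(path).write_text(json.dumps(asdict(self), indent=2))

    @classmethod
    def load(cls, path):
        # A missing file raises FileNotFoundError; malformed JSON raises
        # json.JSONDecodeError.
        data = json.loads(Path(path).read_text())
        # Assumed dispatch rule: pick the subclass from "model_type",
        # falling back to the base class for unknown values.
        target = _REGISTRY.get(data.get("model_type"), BaseCfg)
        return target(**data)


@dataclass
class DreamerCfg(BaseCfg):
    model_type: str = "dreamer"
    deter_dim: int = 2048


@dataclass
class TdmpcCfg(BaseCfg):
    model_type: str = "tdmpc2"
    latent_dim: int = 256


_REGISTRY = {"dreamer": DreamerCfg, "tdmpc2": TdmpcCfg}


def roundtrip(cfg):
    """Save cfg to a temporary file and load it back through the base class."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "config.json"
        cfg.save(path)
        return BaseCfg.load(path)
```

Calling `roundtrip(TdmpcCfg(action_dim=6))` goes through `BaseCfg.load` yet returns a `TdmpcCfg`, which mirrors the "DreamerV3Config, TDMPC2Config, or base" return description above.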
DreamerV3Config
Bases: WorldModelConfig
DreamerV3 world model configuration.
DreamerV3 uses an RSSM (Recurrent State-Space Model) with categorical latent variables. This configuration supports multiple size presets matching the original paper.
Size Presets
- size12m: 12M params - deter=2048, stoch=16x16, hidden=256
- size25m: 25M params - deter=4096, stoch=32x16, hidden=512
- size50m: 50M params - deter=4096, stoch=32x32, hidden=640
- size100m: 100M params - deter=8192, stoch=32x32, hidden=768
- size200m: 200M params - deter=8192, stoch=32x32, hidden=1024
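To make the preset notation concrete: reading `stoch=AxB` as A categorical distributions with B classes each (an assumption based on the `stoch_discrete` and `stoch_classes` attributes), the flattened stochastic state has A*B entries. The sketch below tabulates that, plus the width of a concatenated deterministic-plus-stochastic feature vector. Concatenation is the usual RSSM convention but is not confirmed by this page. `PRESETS` and `feature_dims` are illustrative names, not library API.

```python
# Values copied from the preset list above:
# (deter, stoch_discrete, stoch_classes, hidden).
PRESETS = {
    "size12m": (2048, 16, 16, 256),
    "size25m": (4096, 32, 16, 512),
    "size50m": (4096, 32, 32, 640),
    "size100m": (8192, 32, 32, 768),
    "size200m": (8192, 32, 32, 1024),
}


def feature_dims(size):
    """Return (flattened stochastic dim, deter + stochastic dim) for a preset."""
    deter, discrete, classes, _hidden = PRESETS[size]
    stoch_flat = discrete * classes
    return stoch_flat, deter + stoch_flat


for name in PRESETS:
    stoch_flat, feat = feature_dims(name)
    print(f"{name}: stoch_flat={stoch_flat}, feature_dim={feat}")
```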
Attributes:
| Name | Type | Description |
|---|---|---|
| stoch_discrete | int | Number of categorical distributions. |
| stoch_classes | int | Number of classes per categorical distribution. |
| encoder_type | str | Type of encoder ("cnn" or "mlp"). |
| decoder_type | str | Type of decoder ("cnn" or "mlp"). |
| cnn_depth | int | Base depth multiplier for CNN encoder/decoder. |
| cnn_kernels | tuple[int, ...] | Kernel sizes for CNN layers. |
| kl_free | float | Free nats for KL divergence (prevents posterior collapse). |
| kl_balance | float | Balance between prior and posterior in KL loss. |
| loss_scales | dict[str, float] | Weights for each loss component. |
| use_symlog | bool | Whether to use symlog transformation for predictions. |
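Two of the attributes above name techniques that are easy to show in isolation. The sketch below gives the symlog transform and its inverse as defined in the DreamerV3 paper, and a free-nats clamp of the kind `kl_free` refers to. The function names and the `free_nats=1.0` default are illustrative assumptions; this page does not state the library's default or show how it applies either technique internally.

```python
import math


def symlog(x):
    """Symmetric log: sign(x) * ln(|x| + 1). Compresses large magnitudes."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp(y):
    """Inverse of symlog: sign(y) * (exp(|y|) - 1)."""
    return math.copysign(math.expm1(abs(y)), y)


def apply_free_nats(kl, free_nats=1.0):
    """Clamp a KL value from below.

    Below the threshold the loss is constant, so the KL term stops
    pushing the posterior toward the prior.
    """
    return max(kl, free_nats)
```

Because symlog is odd and invertible, targets of very different scales (rewards, values, vector observations) can be regressed in the compressed space and mapped back with `symexp`.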
Example
# Create from size preset
config = DreamerV3Config.from_size("size12m")
# Create with custom parameters
config = DreamerV3Config(
    obs_shape=(3, 64, 64),
    action_dim=4,
    deter_dim=1024,
    stoch_discrete=16,
    stoch_classes=16,
)
from_dict(d) classmethod
Create configuration from dictionary.
from_size(size, **kwargs) classmethod
Create configuration from a size preset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| size | str | Size preset name (size12m, size25m, size50m, size100m, size200m). | required |
| **kwargs | Any | Override any preset parameters. | {} |
Returns:
| Type | Description |
|---|---|
| DreamerV3Config | DreamerV3Config with the specified size preset. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the size preset is not recognized. |
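The contract documented for from_size (look up a preset, let keyword arguments override it, raise ValueError for unknown names) can be sketched with a stand-in class. `SketchConfig` and `_SIZE_PRESETS` are hypothetical and only cover two presets; mapping `deter=` and `hidden=` from the preset list to `deter_dim` and `hidden_dim` is an assumption based on the attribute names, and the library's actual preset storage and error message may differ.

```python
from dataclasses import dataclass

# Two presets copied from the size list on this page; the rest are
# omitted for brevity.
_SIZE_PRESETS = {
    "size12m": {"deter_dim": 2048, "stoch_discrete": 16,
                "stoch_classes": 16, "hidden_dim": 256},
    "size25m": {"deter_dim": 4096, "stoch_discrete": 32,
                "stoch_classes": 16, "hidden_dim": 512},
}


@dataclass
class SketchConfig:
    """Hypothetical stand-in; not the library's DreamerV3Config."""
    deter_dim: int
    stoch_discrete: int
    stoch_classes: int
    hidden_dim: int
    action_dim: int = 1

    @classmethod
    def from_size(cls, size, **kwargs):
        if size not in _SIZE_PRESETS:
            known = ", ".join(sorted(_SIZE_PRESETS))
            raise ValueError(
                f"Unknown size preset {size!r}; expected one of: {known}"
            )
        # Later keys win, so caller overrides replace preset values.
        return cls(**{**_SIZE_PRESETS[size], **kwargs})
```

With this merge order, `SketchConfig.from_size("size12m", hidden_dim=128)` keeps the preset's `deter_dim` while replacing only `hidden_dim`.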
TDMPC2Config
Bases: WorldModelConfig
TD-MPC2 world model configuration.
TD-MPC2 is an implicit world model that uses SimNorm latent space and learns value functions for planning. It's particularly effective for continuous control tasks.
Size Presets
- 5m: 5M params - latent=256, hidden=256
- 19m: 19M params - latent=512, hidden=512
- 48m: 48M params - latent=512, hidden=1024
- 317m: 317M params - latent=1024, hidden=2048
Attributes:
| Name | Type | Description |
|---|---|---|
| simnorm_dim | int | Dimension for SimNorm grouping (latent_dim must be divisible by simnorm_dim). |
| num_hidden_layers | int | Number of hidden layers in MLPs. |
| task_dim | int | Dimension of task embedding for multi-task learning. |
| num_tasks | int | Number of tasks for multi-task learning. |
| num_q_networks | int | Number of Q-networks in the ensemble. |
| horizon | int | Planning horizon for MPC. |
| num_samples | int | Number of action samples for planning. |
| num_elites | int | Number of elite samples for CEM planning. |
| temperature | float | Temperature for action sampling. |
| momentum | float | Momentum for CEM mean update. |
| use_decoder | bool | Whether to use a decoder (TD-MPC2 is typically implicit). |
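The divisibility requirement on `simnorm_dim` follows from how SimNorm works: the latent vector is split into consecutive groups of `simnorm_dim` entries and a softmax is applied within each group. The sketch below is a plain-Python illustration of that idea as described in the TD-MPC2 paper, with the softmax temperature fixed at 1. The function name and list-based interface are illustrative, not the library's implementation.

```python
import math


def simnorm(latent, simnorm_dim):
    """Apply a softmax within consecutive groups of simnorm_dim entries."""
    if len(latent) % simnorm_dim != 0:
        raise ValueError("latent length must be divisible by simnorm_dim")
    out = []
    for start in range(0, len(latent), simnorm_dim):
        group = latent[start:start + simnorm_dim]
        peak = max(group)  # subtract the max for numerical stability
        exps = [math.exp(v - peak) for v in group]
        total = sum(exps)
        out.extend(e / total for e in exps)
    return out
```

Each group of the output is non-negative and sums to 1, so a latent of size `latent_dim` becomes `latent_dim / simnorm_dim` points on a simplex.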
Example
# Create from size preset
config = TDMPC2Config.from_size("5m", obs_shape=(39,), action_dim=6)
# Create with custom parameters
config = TDMPC2Config(
    obs_shape=(39,),
    action_dim=6,
    latent_dim=256,
    hidden_dim=256,
)
from_dict(d) classmethod
Create configuration from dictionary.
from_size(size, **kwargs) classmethod
Create configuration from a size preset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| size | str | Size preset name (5m, 19m, 48m, 317m). | required |
| **kwargs | Any | Override any preset parameters. | {} |
Returns:
| Type | Description |
|---|---|
| TDMPC2Config | TDMPC2Config with the specified size preset. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the size preset is not recognized. |
TrainingConfig
TrainingConfig lives in the training API reference: worldflux.training.config.TrainingConfig.