
Factory Functions

The main entry points for creating and managing world models.

For docstring-derived details, see Factory API (Autogenerated).

create_world_model

Create a world model from a preset, alias, or saved path.

from worldflux import create_world_model

model = create_world_model(
    model="dreamerv3:size12m",
    obs_shape=(3, 64, 64),
    action_dim=4,
    device="cpu",
)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model | str | required | Model preset, alias, or path to a saved model |
| obs_shape | tuple[int, ...] | None | Observation shape (required when creating a new model) |
| action_dim | int | None | Action dimension (required when creating a new model) |
| device | str | "cpu" | Device to place the model on |
| api_version | str | "v3" | API compatibility mode |
| **kwargs | | | Config overrides for the selected model family |

Model Specifiers

Presets (type:size)

model = create_world_model("dreamerv3:size12m", ...)
model = create_world_model("tdmpc2:5m", ...)
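A preset specifier has the form type:size. The sketch below shows one way such a string could be split, assuming the library simply splits on the first colon (its actual parsing is not documented here). parse_preset and PresetSpec are hypothetical names for illustration, not worldflux APIs.

```python
from typing import NamedTuple


class PresetSpec(NamedTuple):
    family: str  # e.g. "dreamerv3" or "tdmpc2"
    size: str    # e.g. "size12m" or "5m"


def parse_preset(spec: str) -> PresetSpec:
    """Split a 'type:size' specifier into its parts (hypothetical helper)."""
    family, sep, size = spec.partition(":")
    if not sep or not family or not size:
        raise ValueError(f"expected 'type:size', got {spec!r}")
    return PresetSpec(family=family, size=size)


print(parse_preset("dreamerv3:size12m"))  # PresetSpec(family='dreamerv3', size='size12m')
print(parse_preset("tdmpc2:5m"))          # PresetSpec(family='tdmpc2', size='5m')
```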

Aliases

# DreamerV3 aliases
model = create_world_model("dreamer", ...)        # dreamerv3:size12m
model = create_world_model("dreamer-small", ...)  # dreamerv3:size12m
model = create_world_model("dreamer-medium", ...) # dreamerv3:size50m
model = create_world_model("dreamer-large", ...)  # dreamerv3:size200m

# TD-MPC2 aliases
model = create_world_model("tdmpc", ...)         # tdmpc2:5m
model = create_world_model("tdmpc-small", ...)   # tdmpc2:5m
model = create_world_model("tdmpc-medium", ...)  # tdmpc2:48m
model = create_world_model("tdmpc-large", ...)   # tdmpc2:317m
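The aliases above amount to a lookup table. This minimal resolution sketch is built from exactly those eight entries; ALIASES and resolve_alias are hypothetical names, and the library may resolve aliases differently internally.

```python
# Alias -> preset mapping, copied from the alias listing in this section.
# ALIASES and resolve_alias are hypothetical, not worldflux APIs.
ALIASES: dict[str, str] = {
    "dreamer": "dreamerv3:size12m",
    "dreamer-small": "dreamerv3:size12m",
    "dreamer-medium": "dreamerv3:size50m",
    "dreamer-large": "dreamerv3:size200m",
    "tdmpc": "tdmpc2:5m",
    "tdmpc-small": "tdmpc2:5m",
    "tdmpc-medium": "tdmpc2:48m",
    "tdmpc-large": "tdmpc2:317m",
}


def resolve_alias(name: str) -> str:
    """Return the preset an alias points to; pass any other string through unchanged."""
    return ALIASES.get(name, name)


assert resolve_alias("dreamer-medium") == "dreamerv3:size50m"
assert resolve_alias("tdmpc2:19m") == "tdmpc2:19m"  # already a preset, left as-is
```

In both families, the plain alias and the -small alias point to the same preset.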

Load from Path

model = create_world_model("./my_saved_model")
model = create_world_model("/path/to/checkpoint")
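obs_shape and action_dim are only required when a new model is created, which is why the path-loading calls can omit them. The check below sketches that rule under one loud assumption: a specifier that exists on the filesystem is treated as a saved model. needs_shapes and check_create_args are hypothetical helpers, not part of worldflux, and the library's real path detection may differ.

```python
import os
from typing import Optional, Sequence


def needs_shapes(model: str) -> bool:
    # Assumption: anything that exists on disk is a saved model;
    # everything else (presets, aliases) creates a new model.
    return not os.path.exists(model)


def check_create_args(
    model: str,
    obs_shape: Optional[Sequence[int]] = None,
    action_dim: Optional[int] = None,
) -> None:
    """Raise if a new model is requested without obs_shape and action_dim."""
    if needs_shapes(model) and (obs_shape is None or action_dim is None):
        raise ValueError(
            f"{model!r} creates a new model: obs_shape and action_dim are required"
        )


check_create_args("dreamerv3:size12m", obs_shape=(3, 64, 64), action_dim=4)  # passes
check_create_args(os.getcwd())  # existing directory: treated as a saved model
```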

Config Overrides

model = create_world_model(
    "tdmpc2:19m",
    obs_shape=(39,),
    action_dim=6,
    hidden_dim=768,      # TDMPC2Config field; preset value is 512
    num_q_networks=7,    # TDMPC2Config field; preset value is 5
)

Override names must match fields on the selected family's config class (for example, TDMPC2Config for tdmpc2 presets).
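One way to enforce that rule is to compare keyword names against the config's dataclass fields before applying them. This sketch assumes the config is a dataclass; TDMPC2ConfigSketch is a hypothetical stand-in holding only the three fields from the TD-MPC2 preset table (defaults taken from the tdmpc2:19m row), not the real TDMPC2Config, and apply_overrides is likewise hypothetical.

```python
from dataclasses import dataclass, fields, replace
from typing import TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class TDMPC2ConfigSketch:
    # Hypothetical stand-in; values mirror the tdmpc2:19m preset row.
    latent_dim: int = 512
    hidden_dim: int = 512
    num_q_networks: int = 5


def apply_overrides(config: T, **overrides: object) -> T:
    """Return a copy of a dataclass config with overrides applied, rejecting unknown names."""
    valid = {f.name for f in fields(config)}
    unknown = sorted(set(overrides) - valid)
    if unknown:
        raise TypeError(f"unknown config fields: {unknown}; valid fields: {sorted(valid)}")
    return replace(config, **overrides)


cfg = apply_overrides(TDMPC2ConfigSketch(), hidden_dim=768, num_q_networks=7)
print(cfg)  # TDMPC2ConfigSketch(latent_dim=512, hidden_dim=768, num_q_networks=7)
```

Validating up front produces an error that lists the valid field names instead of failing deep inside model construction.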

Returns

A world model instance implementing the WorldModel protocol.


list_models

List available model presets.

from worldflux import list_models

# Simple list
models = list_models()

# With catalog metadata
models = list_models(verbose=True)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| verbose | bool | False | Return detailed metadata instead of only names |
| maturity | str \| None | None | Optional maturity filter: reference, experimental, or skeleton |

Returns

  • verbose=False: list[str] of model names
  • verbose=True: dict[str, dict] with catalog metadata
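The two return shapes, combined with the maturity filter, can be mimicked with a small function over a catalog dictionary. This is a sketch of the documented contract only: list_models_sketch and TOY_CATALOG are hypothetical, and the catalog entries are placeholders rather than real worldflux presets or metadata.

```python
from typing import Optional, Union

# Placeholder catalog: names and metadata are illustrative, not real worldflux entries.
TOY_CATALOG: dict[str, dict] = {
    "family-a:small": {"maturity": "reference"},
    "family-b:small": {"maturity": "experimental"},
    "family-c:stub": {"maturity": "skeleton"},
}


def list_models_sketch(
    catalog: dict[str, dict],
    verbose: bool = False,
    maturity: Optional[str] = None,
) -> Union[list[str], dict[str, dict]]:
    """Return names only (verbose=False) or a name -> metadata dict (verbose=True)."""
    selected = {
        name: meta
        for name, meta in catalog.items()
        if maturity is None or meta.get("maturity") == maturity
    }
    return selected if verbose else list(selected)


print(list_models_sketch(TOY_CATALOG, maturity="reference"))  # ['family-a:small']
```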

Reference Presets

DreamerV3

| Preset | Approx. params | deter_dim | stoch_discrete | stoch_classes |
|---|---|---|---|---|
| dreamerv3:size12m | ~12M | 2048 | 16 | 16 |
| dreamerv3:size25m | ~25M | 4096 | 32 | 16 |
| dreamerv3:size50m | ~50M | 4096 | 32 | 32 |
| dreamerv3:size100m | ~100M | 8192 | 32 | 32 |
| dreamerv3:size200m | ~200M | 8192 | 32 | 32 |
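To relate these columns to latent sizes, the sketch below assumes a DreamerV3-style RSSM, in which the stochastic state consists of stoch_discrete categorical variables with stoch_classes classes each, flattened and concatenated with the deter_dim deterministic state. Whether worldflux builds its features exactly this way is an assumption; the dimensions come from the table, and DREAMERV3_PRESETS and latent_sizes are hypothetical names.

```python
# (deter_dim, stoch_discrete, stoch_classes) per preset, copied from the table above.
DREAMERV3_PRESETS = {
    "dreamerv3:size12m": (2048, 16, 16),
    "dreamerv3:size25m": (4096, 32, 16),
    "dreamerv3:size50m": (4096, 32, 32),
    "dreamerv3:size100m": (8192, 32, 32),
    "dreamerv3:size200m": (8192, 32, 32),
}


def latent_sizes(deter_dim: int, stoch_discrete: int, stoch_classes: int) -> tuple[int, int]:
    """Return (flattened stochastic size, deterministic + stochastic feature size).

    Assumes stoch_discrete one-hot categoricals of stoch_classes classes each,
    flattened and concatenated with the deterministic state.
    """
    stoch_flat = stoch_discrete * stoch_classes
    return stoch_flat, deter_dim + stoch_flat


for name, dims in DREAMERV3_PRESETS.items():
    stoch_flat, feat = latent_sizes(*dims)
    print(f"{name}: stoch={stoch_flat}, feature={feat}")
```

Under that assumption, dreamerv3:size12m yields a 256-dimensional stochastic vector (16 × 16) and a 2304-dimensional feature (2048 + 256).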

TD-MPC2

| Preset | Approx. params | latent_dim | hidden_dim | num_q_networks |
|---|---|---|---|---|
| tdmpc2:5m | ~5M | 256 | 256 | 5 |
| tdmpc2:19m | ~19M | 512 | 512 | 5 |
| tdmpc2:48m | ~48M | 512 | 1024 | 5 |
| tdmpc2:317m | ~317M | 1024 | 2048 | 5 |