Quick Start
Minimal execution path for WorldFlux.
1) Create a Model
from worldflux import create_world_model
model = create_world_model(
"dreamerv3:size12m",
obs_shape=(3, 64, 64),
action_dim=4,
)
2) Run Imagination
import torch
obs = torch.randn(1, 3, 64, 64)
state = model.encode(obs)
actions = torch.randn(15, 1, 4) # [horizon, batch, action_dim]
trajectory = model.rollout(state, actions)
print(trajectory.rewards.shape) # [15, 1]
print(trajectory.continues.shape) # [15, 1]
3) Save and Reload
model.save_pretrained("./my_model")
reloaded = create_world_model("./my_model")
4) Choosing a Model in worldflux init
worldflux init now recommends a model and asks you to choose explicitly between:
dreamer:citdmpc2:ci
How recommendation works:
- If your environment is image/spatial (
3,64,64style), it recommendsdreamer:ci - If your environment is vector/state (
39style), it recommendstdmpc2:ci
How to choose:
- Start with the recommended model unless you have a clear reason to switch
- Switch to
tdmpc2:cifor compact vector tasks when you want a lighter baseline - Switch to
dreamer:cifor image-heavy tasks when latent imagination quality is the priority
If you are unsure about input/output dimensions, read: Observation Shape and Action Dim.
5) Comparing DreamerV3 and TD-MPC2
The two reference-family implementations are designed for different observation spaces. Here is how to create both and compare rollout outputs on the same task:
import torch
from worldflux import create_world_model
# --- Image-based environment (e.g. Atari, DMControl vision) ---
obs_shape_img = (3, 64, 64)
action_dim = 4
dreamer = create_world_model(
"dreamerv3:size12m",
obs_shape=obs_shape_img,
action_dim=action_dim,
device="cpu",
)
# --- Vector-based environment (e.g. MuJoCo state, robotics) ---
obs_shape_vec = (39,)
action_dim_vec = 6
tdmpc = create_world_model(
"tdmpc2:5m",
obs_shape=obs_shape_vec,
action_dim=action_dim_vec,
device="cpu",
)
# Rollout with DreamerV3
obs_img = torch.randn(1, *obs_shape_img)
state_d = dreamer.encode(obs_img)
actions_d = torch.randn(10, 1, action_dim)
traj_d = dreamer.rollout(state_d, actions_d)
# Rollout with TD-MPC2
obs_vec = torch.randn(1, *obs_shape_vec)
state_t = tdmpc.encode(obs_vec)
actions_t = torch.randn(10, 1, action_dim_vec)
traj_t = tdmpc.rollout(state_t, actions_t)
print(f"DreamerV3 rewards shape: {traj_d.rewards.shape}") # [10, 1]
print(f"TD-MPC2 rewards shape: {traj_t.rewards.shape}") # [10, 1]
Use DreamerV3 for image/pixel observations with its RSSM latent dynamics. Use TD-MPC2 for low-dimensional vector states where fast MLP planning shines. See Unified Comparison for a contract-level comparison guide.
6) Configuration Customization
create_world_model accepts keyword arguments that override the preset defaults.
You can also inspect or modify a config object before model creation with get_config.
Override via create_world_model
from worldflux import create_world_model
# Override DreamerV3 architecture parameters
model = create_world_model(
"dreamerv3:size12m",
obs_shape=(3, 64, 64),
action_dim=4,
deter_dim=1024, # smaller deterministic state
hidden_dim=128, # narrower MLPs
kl_free=0.5, # tighter KL constraint
learning_rate=1e-4, # lower learning rate
device="cpu",
)
# Override TD-MPC2 planning parameters
model = create_world_model(
"tdmpc2:5m",
obs_shape=(39,),
action_dim=6,
horizon=10, # longer planning horizon
num_samples=1024, # more CEM samples
num_elites=128, # more elite trajectories
num_q_networks=3, # smaller Q ensemble
device="cpu",
)
Inspect config before creation with get_config
from worldflux import get_config
config = get_config("dreamerv3:size12m")
print(config.deter_dim) # 2048
print(config.stoch_discrete) # 16
print(config.hidden_dim) # 256
config = get_config("tdmpc2:5m", obs_shape=(39,), action_dim=6)
print(config.latent_dim) # 256
print(config.horizon) # 5
print(config.num_samples) # 512
List all available models
from worldflux import list_models
# Simple list of supported MVP preset identifiers
print(list_models())
# ['dreamer:ci', 'dreamerv3:size12m', 'dreamerv3:size25m', ...]
# Supported MVP catalog with descriptions and parameter counts
catalog = list_models(verbose=True)
for name, info in catalog.items():
print(f"{name:25s} {info['params']:>8s} {info['description']}")
# Include advanced proof-oriented presets explicitly
public_catalog = list_models(verbose=True, surface="public")
print("Advanced available:", "dreamerv3:official_xl" in public_catalog)
# Filter experimental families explicitly
experimental = list_models(maturity="experimental", surface="all")
print(experimental)
# ['jepa:base', 'vjepa2:ci', 'vjepa2:tiny', 'vjepa2:base', ...]
7) Batch Processing
WorldFlux models natively support batched observations. Pass a batch dimension as the first axis of your tensors:
import torch
from worldflux import create_world_model
model = create_world_model(
"dreamerv3:size12m",
obs_shape=(3, 64, 64),
action_dim=4,
device="cpu",
)
batch_size = 8
# Encode a batch of observations
obs_batch = torch.randn(batch_size, 3, 64, 64)
states = model.encode(obs_batch)
print(f"Batch size in state: {states.batch_size}") # 8
# Rollout with batched actions: [horizon, batch, action_dim]
horizon = 15
actions = torch.randn(horizon, batch_size, 4)
trajectory = model.rollout(states, actions)
print(f"Rewards shape: {trajectory.rewards.shape}") # [15, 8]
print(f"Continues shape: {trajectory.continues.shape}") # [15, 8]
print(f"Actions shape: {trajectory.actions.shape}") # [15, 8, 4]
# Access individual time-step states
for t in range(trajectory.horizon):
step_state = trajectory.states[t + 1] # states[0] is initial
print(f" Step {t}: batch_size={step_state.batch_size}")
Observations use [batch, ...] layout.
Action sequences for rollout use [horizon, batch, action_dim] layout.
Rewards and continues are returned as [horizon, batch].
8) Error Handling
Common errors and how to fix them.
Shape mismatch on encode
import torch
from worldflux import create_world_model
model = create_world_model(
"dreamerv3:size12m",
obs_shape=(3, 64, 64),
action_dim=4,
)
# Wrong: observation shape does not match obs_shape
try:
bad_obs = torch.randn(1, 84, 84, 3) # HWC instead of CHW
model.encode(bad_obs)
except Exception as e:
print(f"Error: {e}")
# Fix: transpose to [batch, C, H, W]
good_obs = torch.randn(1, 3, 64, 64)
state = model.encode(good_obs)
Action dimension mismatch in rollout
# Wrong: action_dim=4 but passing 6-dim actions
try:
state = model.encode(torch.randn(1, 3, 64, 64))
bad_actions = torch.randn(10, 1, 6) # 6 != expected 4
model.rollout(state, bad_actions)
except Exception as e:
print(f"Error: {e}")
# Fix: match action_dim to model configuration
good_actions = torch.randn(10, 1, 4)
trajectory = model.rollout(state, good_actions)
CNN encoder requires 3D obs_shape
from worldflux import create_world_model
# Wrong: vector obs with CNN encoder (DreamerV3 defaults to CNN)
try:
model = create_world_model(
"dreamerv3:size12m",
obs_shape=(39,), # vector, but encoder_type defaults to "cnn"
action_dim=6,
)
except Exception as e:
print(f"Error: {e}")
# Fix: use encoder_type="mlp" for vector observations
model = create_world_model(
"dreamerv3:size12m",
obs_shape=(39,),
action_dim=6,
encoder_type="mlp",
decoder_type="mlp",
)
DreamerV3 defaults to CNN encoder/decoder, which requires obs_shape with 3
dimensions (C, H, W). For vector observations, explicitly set
encoder_type="mlp" and decoder_type="mlp".
Unknown model identifier
from worldflux import create_world_model
try:
model = create_world_model("nonexistent:model")
except Exception as e:
print(f"Error: {e}")
# Fix: check available models
from worldflux import list_models
print("Available:", list_models())