WorldModel Base Class

All world models implement the WorldModel base class.

For docstring-derived details, see Protocol API (Autogenerated).

Interface (v3 default)

class WorldModel(nn.Module, ABC):
    def encode(self, obs: Tensor | dict[str, Tensor] | WorldModelInput, deterministic: bool = False) -> State: ...
    def transition(
        self,
        state: State,
        action: ActionPayload | Tensor | None,
        conditions: ConditionPayload | None = None,
        deterministic: bool = False,
    ) -> State: ...
    def update(
        self,
        state: State,
        action: ActionPayload | Tensor | None,
        obs: Tensor | dict[str, Tensor] | WorldModelInput,
        conditions: ConditionPayload | None = None,
    ) -> State: ...
    def decode(self, state: State, conditions: ConditionPayload | None = None) -> ModelOutput: ...
    def rollout(
        self,
        initial_state: State,
        action_sequence: ActionSequence | ActionPayload | Tensor | None,
        conditions: ConditionPayload | None = None,
        deterministic: bool = False,
        mode: str = "autoregressive",
    ) -> Trajectory: ...
    def loss(self, batch: Batch) -> LossOutput: ...

Legacy calls (encode(obs), transition(state, action_tensor)) still work in v0.2.

rollout(..., mode=...) is deprecated in v0.2 and removed in v0.3. Use planner strategies (worldflux.planners) for re-planning and tree-search behaviors.

create_world_model() now defaults to api_version="v3". Use api_version="v0.2" only for explicit migration bridging.
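
The method names in the interface above can be exercised with a toy stand-in. The following is a minimal sketch, not the real base class: ToyState, ToyTrajectory, and ToyWorldModel are hypothetical names, the state is a single float instead of a tensor, and the choice to include the initial state in the trajectory is an assumption. The deprecated mode argument is left out on purpose.

```python
from dataclasses import dataclass, field


@dataclass
class ToyState:
    # Hypothetical stand-in for State: a single scalar latent.
    latent: float


@dataclass
class ToyTrajectory:
    # Hypothetical stand-in for Trajectory; only the states are tracked.
    states: list = field(default_factory=list)


class ToyWorldModel:
    """Toy model that mirrors the method names above; not the real WorldModel."""

    def encode(self, obs, deterministic=False):
        return ToyState(latent=float(obs))

    def transition(self, state, action, conditions=None, deterministic=False):
        # Toy dynamics: add the action to the latent; None is treated as a no-op.
        delta = 0.0 if action is None else float(action)
        return ToyState(latent=state.latent + delta)

    def decode(self, state, conditions=None):
        return {"obs": state.latent}

    def rollout(self, initial_state, action_sequence, conditions=None, deterministic=False):
        # Open-loop rollout: call transition once per action, in order.
        trajectory = ToyTrajectory(states=[initial_state])
        state = initial_state
        for action in action_sequence:
            state = self.transition(state, action, conditions, deterministic)
            trajectory.states.append(state)
        return trajectory


model = ToyWorldModel()
start = model.encode(1.0)
traj = model.rollout(start, [0.5, 0.5, -1.0])
final_prediction = model.decode(traj.states[-1])
```

Here rollout is just repeated transition, which is why re-planning and tree search are better expressed as separate planner strategies than as rollout modes.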

Key Payload Types

ActionPayload(kind, tensor=None, tokens=None, latent=None, text=None, extras={})
ConditionPayload(text_condition=None, goal=None, spatial=None, camera_pose=None, extras={})
WorldModelInput(observations, context, action, conditions)

Planner payload metadata:

  • canonical key: extras["wf.planner.horizon"] (int >= 1)
  • legacy key: extras["wf.planner.sequence"] (deprecated in v0.2, removed in v0.3)
  • helper APIs: normalize_planned_action(...)
  • v0.2 legacy: first_action(...) (available from worldflux.core.payloads)
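
The constraint on the canonical key (an int >= 1) can be checked with a few lines of plain Python. This is an illustrative sketch only: read_planner_horizon is a hypothetical helper, not part of worldflux, and it does not reproduce what normalize_planned_action(...) does.

```python
PLANNER_HORIZON_KEY = "wf.planner.horizon"


def read_planner_horizon(extras: dict) -> int:
    """Return the planner horizon stored in extras, enforcing int >= 1."""
    if PLANNER_HORIZON_KEY not in extras:
        raise KeyError(f"missing {PLANNER_HORIZON_KEY!r} in extras")
    horizon = extras[PLANNER_HORIZON_KEY]
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(horizon, bool) or not isinstance(horizon, int):
        raise TypeError(f"{PLANNER_HORIZON_KEY!r} must be an int")
    if horizon < 1:
        raise ValueError(f"{PLANNER_HORIZON_KEY!r} must be >= 1")
    return horizon


horizon = read_planner_horizon({"wf.planner.horizon": 5})
```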

Condition extras in strict mode:

  • keys must be namespaced (wf.<domain>.<name>)
  • keys must be declared by each model's io_contract().condition_spec.allowed_extra_keys or io_contract().condition_extras_schema
  • declared schema entries are validated for dtype/shape in strict mode
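
The first two rules above can be sketched as a standalone check. Everything here is an assumption made for illustration: check_condition_extras is a hypothetical function, the regular expression is one plausible reading of wf.<domain>.<name> (lowercase letters, digits, underscores), and the allowed keys are passed in directly rather than read from io_contract(). The dtype/shape validation is not covered.

```python
import re

# Assumed reading of "wf.<domain>.<name>": two lowercase identifier parts after "wf".
_NAMESPACED_KEY = re.compile(r"^wf\.[a-z0-9_]+\.[a-z0-9_]+$")


def check_condition_extras(extras: dict, allowed_keys: set) -> list:
    """Return one error message per key that breaks the strict-mode rules."""
    errors = []
    for key in extras:
        if not _NAMESPACED_KEY.match(key):
            errors.append(f"{key!r} is not namespaced as wf.<domain>.<name>")
        elif key not in allowed_keys:
            errors.append(f"{key!r} is not declared by the model contract")
    return errors


# Hypothetical keys, used only to exercise the check.
errors = check_condition_extras(
    {"wf.camera.fov": 60, "fov": 60, "wf.audio.gain": 0.5},
    allowed_keys={"wf.camera.fov"},
)
```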

Action contract in strict mode:

  • io_contract().action_union_spec may declare multiple valid action variants
  • payloads must match at least one declared action variant
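
"Match at least one declared variant" is an any-of check. The sketch below shows the idea under loud assumptions: the real structure of action_union_spec is not documented here, so ToyActionVariant (a kind plus one required field) and the kind names "continuous" and "token" are hypothetical. ToyActionPayload only mirrors the field names of the ActionPayload signature listed earlier.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyActionPayload:
    # Field names follow the ActionPayload signature; the class itself is a stand-in.
    kind: str
    tensor: Any = None
    tokens: Any = None
    latent: Any = None
    text: Any = None
    extras: dict = field(default_factory=dict)


@dataclass(frozen=True)
class ToyActionVariant:
    # Hypothetical variant description: a kind and the field that must be populated.
    kind: str
    required_field: str


def matches_any_variant(payload, variants) -> bool:
    """True if the payload satisfies at least one declared variant."""
    return any(
        payload.kind == v.kind and getattr(payload, v.required_field) is not None
        for v in variants
    )


variants = [ToyActionVariant("continuous", "tensor"), ToyActionVariant("token", "tokens")]
ok = matches_any_variant(ToyActionPayload(kind="token", tokens=[3, 7]), variants)
rejected = matches_any_variant(ToyActionPayload(kind="text", text="go left"), variants)
```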

ModelOutput

ModelOutput uses predictions as the canonical field; preds is kept as a compatibility alias.

  • predictions: model predictions (obs, reward, continue, q_values, ...)
  • state: optional state
  • uncertainty: optional uncertainty tensor
  • aux: optional metadata
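
One common way to keep a compatibility alias is a read-only property that returns the canonical field. The sketch below assumes that pattern; ToyModelOutput is a hypothetical stand-in, and the real ModelOutput may implement the alias differently.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyModelOutput:
    # Canonical field, holding entries such as "obs" or "reward".
    predictions: dict = field(default_factory=dict)
    state: Any = None
    uncertainty: Any = None
    aux: dict = field(default_factory=dict)

    @property
    def preds(self) -> dict:
        # Compatibility alias: returns the same dict object as predictions.
        return self.predictions


out = ToyModelOutput(predictions={"reward": 1.0})
legacy_view = out.preds
```

Because the alias returns the same object, code that still reads preds sees every update made through predictions.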

Capabilities

Use capability helpers to branch safely:

if model.supports_reward:
    ...
if model.supports_continue:
    ...

Batch Format

Batch supports both legacy and universal forms in v0.2:

batch = Batch(
    # legacy
    obs=...,
    actions=...,
    rewards=...,
    terminations=...,
    # universal
    inputs={...},
    targets={...},
    conditions={...},
    lengths={...},  # optional per-field variable lengths
    masks={...},  # optional per-field variable masks
    extras={...},
)

Variable-length validation can be declared in the model contract via io_contract().sequence_field_spec.
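
Per-field lengths and masks are two views of the same padding information. As a rough illustration, a list of sequence lengths can be expanded into boolean masks. Both lengths_to_mask and the convention that True marks a valid (non-padded) step are assumptions; the source does not state how Batch relates the two fields.

```python
def lengths_to_mask(lengths, max_len):
    """Build one boolean row per sequence; True marks a valid (non-padded) step."""
    return [[t < n for t in range(max_len)] for n in lengths]


# Two sequences padded to 3 steps, with 3 and 1 valid steps respectively.
mask = lengths_to_mask([3, 1], max_len=3)
```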

Trajectory

trajectory.states     # list[State]
trajectory.rewards    # Tensor[T, B] | None
trajectory.continues  # Tensor[T, B] | None

Serialization Contract

save_pretrained(path) writes:

  • config.json
  • model.pt
  • worldflux_meta.json

worldflux_meta.json includes compatibility fields:

  • save_format_version
  • worldflux_version
  • api_version
  • model_type
  • contract_fingerprint
  • created_at_utc
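
A saved directory can be sanity-checked by confirming that worldflux_meta.json carries every field listed above. This sketch uses only the standard library; missing_meta_fields is a hypothetical helper, not a worldflux API, and the file written below contains placeholder values purely to exercise the check.

```python
import json
import tempfile
from pathlib import Path

META_FIELDS = (
    "save_format_version",
    "worldflux_version",
    "api_version",
    "model_type",
    "contract_fingerprint",
    "created_at_utc",
)


def missing_meta_fields(model_dir) -> list:
    """Return the compatibility fields absent from worldflux_meta.json."""
    meta_path = Path(model_dir) / "worldflux_meta.json"
    meta = json.loads(meta_path.read_text(encoding="utf-8"))
    return [name for name in META_FIELDS if name not in meta]


with tempfile.TemporaryDirectory() as tmp:
    # Write a deliberately incomplete metadata file with placeholder values.
    partial = {name: "placeholder" for name in META_FIELDS[:4]}
    (Path(tmp) / "worldflux_meta.json").write_text(json.dumps(partial), encoding="utf-8")
    missing = missing_meta_fields(tmp)
```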