WorldModel Base Class

All world models implement the WorldModel base class.

For docstring-derived details, see Protocol API (Autogenerated).

Interface (v3 default)

class WorldModel(nn.Module, ABC):
    def encode(self, obs: Tensor | dict[str, Tensor] | WorldModelInput, deterministic: bool = False) -> State: ...
    def transition(
        self,
        state: State,
        action: ActionPayload | Tensor | None,
        conditions: ConditionPayload | None = None,
        deterministic: bool = False,
    ) -> State: ...
    def update(
        self,
        state: State,
        action: ActionPayload | Tensor | None,
        obs: Tensor | dict[str, Tensor] | WorldModelInput,
        conditions: ConditionPayload | None = None,
    ) -> State: ...
    def decode(self, state: State, conditions: ConditionPayload | None = None) -> ModelOutput: ...
    def rollout(
        self,
        initial_state: State,
        action_sequence: ActionSequence | ActionPayload | Tensor | None,
        conditions: ConditionPayload | None = None,
        deterministic: bool = False,
        mode: str = "autoregressive",
    ) -> Trajectory: ...
    def loss(self, batch: Batch) -> LossOutput: ...

Legacy calls (encode(obs), transition(state, action_tensor)) still work in v0.2.

rollout(..., mode=...) is deprecated in v0.2 and removed in v0.3. Use planner strategies (worldflux.planners) for re-planning and tree-search behaviors.

create_world_model() now defaults to api_version="v3". Use api_version="v0.2" only for explicit migration bridging.
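
To make the call order of the interface concrete, here is a minimal sketch of the encode, transition, decode, and rollout flow. ToyState and ToyWorldModel are hypothetical stand-ins written in plain Python; they only mirror the method names above and are not WorldFlux classes. Whether a real Trajectory includes the initial state is an assumption of this sketch.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyState:
    """Hypothetical stand-in for State: a single float instead of latent tensors."""
    value: float


class ToyWorldModel:
    """Hypothetical stand-in that mirrors the method names of the interface."""

    def encode(self, obs: float, deterministic: bool = False) -> ToyState:
        # The toy latent state is just the observation itself.
        return ToyState(value=obs)

    def transition(self, state: ToyState, action: float | None,
                   conditions=None, deterministic: bool = False) -> ToyState:
        # action=None is treated as a no-op step in this sketch.
        step = 0.0 if action is None else action
        return ToyState(value=state.value + step)

    def decode(self, state: ToyState, conditions=None) -> dict:
        # A real model returns ModelOutput; a plain dict is used here.
        return {"obs": state.value}

    def rollout(self, initial_state: ToyState, action_sequence) -> list[ToyState]:
        # Assumption: the returned states start with the initial state.
        states = [initial_state]
        for action in action_sequence:
            states.append(self.transition(states[-1], action))
        return states


model = ToyWorldModel()
states = model.rollout(model.encode(1.0), [0.5, None, 2.0])
last_output = model.decode(states[-1])
```

The same loop shape applies when the state is a tensor-backed State and the actions are ActionPayload objects.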

Key Payload Types

ActionPayload(kind, tensor=None, tokens=None, latent=None, text=None, extras={})
ConditionPayload(text_condition=None, goal=None, spatial=None, camera_pose=None, extras={})
WorldModelInput(observations, context, action, conditions)

Planner payload metadata:

  • canonical key: extras["wf.planner.horizon"] (int >= 1)
  • legacy key: extras["wf.planner.sequence"] (deprecated in v0.2, removed in v0.3)
  • helper APIs: normalize_planned_action(...), first_action(...)
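
The constraint on the canonical key (an int >= 1) can be checked as in the sketch below. read_planner_horizon is a hypothetical helper written for illustration; it is not one of the helper APIs listed above, whose signatures are not shown on this page.

```python
def read_planner_horizon(extras: dict) -> int:
    """Return extras["wf.planner.horizon"] after checking it is an int >= 1.

    Hypothetical helper for illustration only.
    """
    key = "wf.planner.horizon"
    if key not in extras:
        raise KeyError(f"missing {key!r} in extras")
    horizon = extras[key]
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(horizon, bool) or not isinstance(horizon, int):
        raise TypeError(f"{key!r} must be an int, got {type(horizon).__name__}")
    if horizon < 1:
        raise ValueError(f"{key!r} must be >= 1, got {horizon}")
    return horizon


horizon = read_planner_horizon({"wf.planner.horizon": 5})
```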

Condition extras in strict mode:

  • keys must be namespaced (wf.<domain>.<name>)
  • keys must be declared by each model's io_contract().condition_spec.allowed_extra_keys or io_contract().condition_extras_schema
  • declared schema entries are validated for dtype/shape in strict mode
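
The first two rules can be illustrated with a small validator. This is a sketch under assumptions: check_condition_extras is a hypothetical function, the allowed character set in the pattern (lowercase letters, digits, underscores) is a guess, and the key wf.camera.fov is made up for the example. Dtype and shape validation is not covered.

```python
from __future__ import annotations

import re

# Assumed shape of a namespaced key: wf.<domain>.<name>
NAMESPACED_KEY = re.compile(r"^wf\.[a-z0-9_]+\.[a-z0-9_]+$")


def check_condition_extras(extras: dict, allowed_keys: set[str]) -> list[str]:
    """Return a list of human-readable problems; an empty list means the keys pass."""
    problems = []
    for key in extras:
        if not isinstance(key, str) or not NAMESPACED_KEY.match(key):
            problems.append(f"{key!r} is not namespaced as wf.<domain>.<name>")
        elif key not in allowed_keys:
            problems.append(f"{key!r} is not declared by the contract")
    return problems


declared = {"wf.camera.fov"}  # hypothetical declared key
ok = check_condition_extras({"wf.camera.fov": 60.0}, declared)
bad = check_condition_extras({"fov": 60.0, "wf.camera.zoom": 2.0}, declared)
```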

Action contract in strict mode:

  • io_contract().action_union_spec may declare multiple valid action variants
  • payloads must match at least one declared action variant
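
The "at least one variant" rule can be pictured as follows. ToyVariant and matches_any_variant are hypothetical; the real structure of action_union_spec is not shown on this page, so this sketch assumes a variant is identified by a kind plus the payload field that must be populated.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyVariant:
    """Hypothetical action variant: a kind and the field that must be set."""
    kind: str
    required_field: str  # e.g. "tensor", "tokens", "latent", or "text"


def matches_any_variant(payload: dict, variants: list) -> bool:
    """True if the payload satisfies at least one declared variant."""
    return any(
        payload.get("kind") == v.kind and payload.get(v.required_field) is not None
        for v in variants
    )


variants = [ToyVariant("continuous", "tensor"), ToyVariant("token", "tokens")]
accepted = matches_any_variant({"kind": "token", "tokens": [3, 7]}, variants)
rejected = matches_any_variant({"kind": "token", "tensor": [0.1]}, variants)
```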

ModelOutput

ModelOutput now uses predictions as the canonical field; preds is kept as a compatibility alias.

  • predictions: model predictions (obs, reward, continue, q_values, ...)
  • state: optional state
  • uncertainty: optional uncertainty tensor
  • aux: optional metadata
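
One common way to keep an alias such as preds in sync with a canonical field is a read-only property. The sketch below shows that pattern on ToyModelOutput, a hypothetical stand-in; it does not claim to be how ModelOutput is implemented.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyModelOutput:
    """Hypothetical stand-in with the same field names as ModelOutput."""
    predictions: dict = field(default_factory=dict)
    state: object | None = None
    uncertainty: object | None = None
    aux: dict | None = None

    @property
    def preds(self) -> dict:
        # Compatibility alias: returns the same dict object as predictions.
        return self.predictions


out = ToyModelOutput(predictions={"obs": [0.1, 0.2], "reward": 1.0})
reward = out.preds["reward"]
```

Because the alias returns the same object, code that reads out.preds and code that reads out.predictions see identical values.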

Capabilities

Use capability helpers to branch safely:

if model.supports_reward:
    ...
if model.supports_continue:
    ...

Batch Format

Batch supports both legacy and universal forms in v0.2:

batch = Batch(
    # legacy
    obs=...,
    actions=...,
    rewards=...,
    terminations=...,
    # universal
    inputs={...},
    targets={...},
    conditions={...},
    lengths={...},  # optional per-field variable lengths
    masks={...},    # optional per-field variable masks
    extras={...},
)

Variable-length validation can be declared in the contract via io_contract().sequence_field_spec.
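
Per-field lengths and masks usually carry the same information in two forms. The sketch below derives a boolean mask from lengths using nested lists. lengths_to_mask is a hypothetical helper, and the [batch, time] layout with True marking valid steps is an assumption, not a documented WorldFlux convention.

```python
from __future__ import annotations


def lengths_to_mask(lengths: list[int], max_len: int) -> list[list[bool]]:
    """Build a [batch, time] mask where True marks a valid (non-padded) step."""
    for n in lengths:
        if not 0 <= n <= max_len:
            raise ValueError(f"length {n} is outside [0, {max_len}]")
    return [[t < n for t in range(max_len)] for n in lengths]


mask = lengths_to_mask([3, 1], max_len=3)
```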

Trajectory

trajectory.states    # list[State]
trajectory.rewards   # Tensor[T, B] | None
trajectory.continues # Tensor[T, B] | None

Serialization Contract

save_pretrained(path) writes:

  • config.json
  • model.pt
  • worldflux_meta.json

worldflux_meta.json includes compatibility fields:

  • save_format_version
  • worldflux_version
  • api_version
  • model_type
  • contract_fingerprint
  • created_at_utc
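
Before loading a checkpoint, a caller may want to confirm that the metadata file carries all of the fields listed above. The sketch below uses only the standard library. load_meta and missing_meta_fields are hypothetical helpers, and the sketch makes no assumption about the values stored in each field.

```python
from __future__ import annotations

import json
from pathlib import Path

# Field names taken from the list above.
REQUIRED_META_FIELDS = (
    "save_format_version",
    "worldflux_version",
    "api_version",
    "model_type",
    "contract_fingerprint",
    "created_at_utc",
)


def load_meta(checkpoint_dir: str | Path) -> dict:
    """Read worldflux_meta.json from a directory written by save_pretrained(path)."""
    meta_path = Path(checkpoint_dir) / "worldflux_meta.json"
    return json.loads(meta_path.read_text(encoding="utf-8"))


def missing_meta_fields(meta: dict) -> list[str]:
    """Return the required field names that are absent from the metadata."""
    return [name for name in REQUIRED_META_FIELDS if name not in meta]
```

An empty list from missing_meta_fields(load_meta(path)) means every compatibility field is present; it does not check that the versions are compatible.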