WorldModel Base Class
All world models implement the WorldModel base class.
For docstring-derived details, see Protocol API (Autogenerated).
Interface (v3 default)
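```python
class WorldModel(nn.Module, ABC):
    # Map an observation (or structured input) to a latent State.
    def encode(
        self,
        obs: Tensor | dict[str, Tensor] | WorldModelInput,
        deterministic: bool = False,
    ) -> State: ...

    # Predict the next State from the current State and an action, without a new observation.
    def transition(
        self,
        state: State,
        action: ActionPayload | Tensor | None,
        conditions: ConditionPayload | None = None,
        deterministic: bool = False,
    ) -> State: ...

    # Advance the State with an action and incorporate the newly received obs.
    def update(
        self,
        state: State,
        action: ActionPayload | Tensor | None,
        obs: Tensor | dict[str, Tensor] | WorldModelInput,
        conditions: ConditionPayload | None = None,
    ) -> State: ...

    # Read predictions (obs, reward, continue, ...) out of a State.
    def decode(self, state: State, conditions: ConditionPayload | None = None) -> ModelOutput: ...

    # Multi-step prediction from initial_state over an action sequence.
    def rollout(
        self,
        initial_state: State,
        action_sequence: ActionSequence | ActionPayload | Tensor | None,
        conditions: ConditionPayload | None = None,
        deterministic: bool = False,
        mode: str = "autoregressive",  # deprecated in v0.2
    ) -> Trajectory: ...

    # Compute training losses for a Batch.
    def loss(self, batch: Batch) -> LossOutput: ...
```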
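To make the division of labor between these methods concrete, here is a minimal pure-Python sketch. ToyWorldModel and ToyState are hypothetical stand-ins (no nn.Module, no tensors, no WorldFlux imports): a scalar latent replaces State, rollout returns a plain list instead of a Trajectory, and treating update as "transition, then blend toward the observation" is an assumption chosen for illustration, not WorldFlux's actual update rule. The point is the shape of the contract: rollout only calls transition, so predicted states feed back into the model, while update is where real observations enter.

```python
from dataclasses import dataclass


@dataclass
class ToyState:
    # Hypothetical stand-in for State: a single scalar latent.
    latent: float


class ToyWorldModel:
    """Illustrates encode/transition/update/decode/rollout; not a WorldModel subclass."""

    def encode(self, obs: float) -> ToyState:
        # Map an observation to a latent state (identity here).
        return ToyState(latent=obs)

    def transition(self, state: ToyState, action: float) -> ToyState:
        # Predict the next latent from the current latent and the action only.
        return ToyState(latent=state.latent + action)

    def update(self, state: ToyState, action: float, obs: float) -> ToyState:
        # Assumed rule: predict with transition, then average with the new observation.
        prior = self.transition(state, action)
        return ToyState(latent=0.5 * prior.latent + 0.5 * obs)

    def decode(self, state: ToyState) -> dict[str, float]:
        # Read predictions back out of the latent.
        return {"obs": state.latent, "reward": -abs(state.latent)}

    def rollout(self, initial_state: ToyState, actions: list[float]) -> list[ToyState]:
        # Autoregressive: each predicted state is the input to the next transition.
        states = [initial_state]
        for action in actions:
            states.append(self.transition(states[-1], action))
        return states
```

With these toy dynamics, rolling out from encode(1.0) with actions [0.5, 0.25] visits latents 1.0, 1.5 and 1.75.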
Legacy calls (encode(obs), transition(state, action_tensor)) still work in v0.2.
The mode argument of rollout(..., mode=...) is deprecated in v0.2 and removed in v0.3. Use planner strategies (worldflux.planners) for re-planning and tree-search behavior.
create_world_model() now defaults to api_version="v3". Use api_version="v0.2" only for explicit migration bridging.
Key Payload Types
ActionPayload(kind, tensor=None, tokens=None, latent=None, text=None, extras={})
ConditionPayload(text_condition=None, goal=None, spatial=None, camera_pose=None, extras={})
WorldModelInput(observations, context, action, conditions)
Planner payload metadata:
- canonical key: extras["wf.planner.horizon"] (int >= 1)
- legacy key: extras["wf.planner.sequence"] (deprecated in v0.2, removed in v0.3)
- helper APIs: normalize_planned_action(...), first_action(...)
Condition extras in strict mode:
- keys must be namespaced (wf.<domain>.<name>)
- keys must be declared by each model's io_contract().condition_spec.allowed_extra_keys or io_contract().condition_extras_schema
- declared schema entries are validated for dtype/shape in strict mode
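The first two strict-mode rules can be sketched as a key check. The check_condition_extras helper and the regular expression's character set for <domain> and <name> (lowercase letters, digits, underscores) are assumptions; real strict mode also validates dtype/shape against the declared schema, which this sketch leaves out.

```python
import re

# wf.<domain>.<name>; the allowed characters per segment are an assumption.
NAMESPACED_KEY = re.compile(r"wf\.[a-z][a-z0-9_]*\.[a-z][a-z0-9_]*")


def check_condition_extras(extras: dict, allowed_keys: set[str]) -> list[str]:
    """Hypothetical strict-mode key check; returns problems (empty list if valid)."""
    problems = []
    for key in extras:
        if not NAMESPACED_KEY.fullmatch(key):
            problems.append(f"{key!r} is not namespaced as wf.<domain>.<name>")
        elif key not in allowed_keys:
            problems.append(f"{key!r} is not declared by the model's io_contract()")
    return problems
```

Collecting every problem instead of raising on the first one lets a caller report all offending keys at once.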
Action contract in strict mode:
- io_contract().action_union_spec may declare multiple valid action variants
- payloads must match at least one declared action variant
ModelOutput
ModelOutput now uses predictions as the canonical field and keeps preds as a compatibility alias.
- predictions: model predictions (obs, reward, continue, q_values, ...)
- state: optional state
- uncertainty: optional uncertainty tensor
- aux: optional metadata
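One common way to keep a renamed field readable under its old name is a read-only property. OutputSketch below is a hypothetical stand-in for ModelOutput, not its actual definition, and the dict type for aux is an assumption. Because preds returns the very same object, legacy code reading out.preds["reward"] sees exactly what out.predictions holds.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class OutputSketch:
    """Stand-in for ModelOutput showing a compatibility alias."""

    predictions: dict[str, Any] = field(default_factory=dict)
    state: Any = None
    uncertainty: Any = None
    aux: dict[str, Any] = field(default_factory=dict)

    @property
    def preds(self) -> dict[str, Any]:
        # Legacy name: returns the same object as the canonical field.
        return self.predictions
```

New code should read predictions directly; the alias exists only so older callers keep working.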
Capabilities
Use capability helpers to branch safely:
Batch Format
Batch supports both legacy and universal forms in v0.2:
```python
batch = Batch(
    # legacy fields
    obs=...,
    actions=...,
    rewards=...,
    terminations=...,
    # universal fields
    inputs={...},
    targets={...},
    conditions={...},
    lengths={...},  # optional per-field sequence lengths
    masks={...},    # optional per-field validity masks for variable-length fields
    extras={...},
)
```
Variable-length validation can be declared in the model's contract via io_contract().sequence_field_spec.
Trajectory
```python
trajectory.states     # list[State]
trajectory.rewards    # Tensor[T, B] | None
trajectory.continues  # Tensor[T, B] | None
```
Serialization Contract
save_pretrained(path) writes:
- config.json
- model.pt
- worldflux_meta.json
worldflux_meta.json includes compatibility fields:
- save_format_version
- worldflux_version
- api_version
- model_type
- contract_fingerprint
- created_at_utc
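A sketch of writing and checking a metadata file with these six fields, using only the standard library. write_meta, check_meta, and the policy of rejecting unknown save_format_version values are assumptions for illustration; the values in the demo ("1", "0.0.0-example", "toy", "example") are placeholders, not real WorldFlux values.

```python
import json
import tempfile
from datetime import datetime, timezone
from pathlib import Path

# Field names as listed for worldflux_meta.json.
META_FIELDS = (
    "save_format_version",
    "worldflux_version",
    "api_version",
    "model_type",
    "contract_fingerprint",
    "created_at_utc",
)


def write_meta(directory: Path, **fields: str) -> Path:
    """Write a worldflux_meta.json-style file (hypothetical helper)."""
    # Stamp the creation time in UTC unless the caller supplies one.
    meta = {"created_at_utc": datetime.now(timezone.utc).isoformat(), **fields}
    missing = [name for name in META_FIELDS if name not in meta]
    if missing:
        raise ValueError(f"missing meta fields: {missing}")
    path = directory / "worldflux_meta.json"
    path.write_text(json.dumps(meta, indent=2, sort_keys=True))
    return path


def check_meta(path: Path, supported_formats: set[str]) -> dict:
    """Load the file and reject save formats this loader does not understand."""
    meta = json.loads(path.read_text())
    version = meta.get("save_format_version")
    if version not in supported_formats:
        raise ValueError(f"unsupported save_format_version: {version!r}")
    return meta


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        meta_path = write_meta(
            Path(tmp),
            save_format_version="1",            # placeholder
            worldflux_version="0.0.0-example",  # placeholder
            api_version="v3",
            model_type="toy",                   # placeholder
            contract_fingerprint="example",     # placeholder
        )
        print(check_meta(meta_path, supported_formats={"1"})["api_version"])
```

Checking the format version before touching model.pt lets a loader fail early with a clear message instead of a deserialization error.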