Protocol & Data Types Reference

Core protocol classes and data containers used throughout the WorldFlux framework.


WorldModel

class WorldModel(nn.Module, ABC)

Abstract base class for all world models in the WorldFlux framework. Inherits from torch.nn.Module and exposes a composable component architecture where each stage of the observe-predict-decode pipeline can be overridden independently.

Pipeline Components

| Component | Description |
| --- | --- |
| observation_encoder | Encodes raw observations into a latent State. |
| action_conditioner | Fuses action and condition information into the dynamics input. |
| dynamics_model | Predicts the next latent state given current state and conditioned inputs. |
| decoder_module | Maps a latent state back to observable predictions (observations, rewards, continuation). |
| rollout_executor | Executes multi-step open-loop rollouts by chaining transition and decode. |

Attributes

| Attribute | Type | Description |
| --- | --- | --- |
| capabilities | set[Capability] | Capability flags advertised by this model (e.g. REWARD_PRED, PLANNING). |
| observation_encoder | ObservationEncoder \| None | Pluggable encoder component. |
| action_conditioner | ActionConditioner \| None | Pluggable action/condition fusion component. |
| dynamics_model | DynamicsModel \| None | Pluggable latent dynamics component. |
| decoder_module | Decoder \| None | Pluggable decoder component. |
| rollout_executor | RolloutExecutor \| None | Pluggable rollout executor component. |
| composable_support | set[str] | Component slot names effective in runtime execution paths for this model. |

Methods

encode

def encode(
    self,
    obs: Tensor | dict[str, Tensor] | WorldModelInput,
    deterministic: bool = False,
) -> State

Encode observations into a latent State. Delegates to the attached observation_encoder component.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| obs | Tensor \| dict[str, Tensor] \| WorldModelInput | required | Raw observation tensor, a dict of named modality tensors, or a WorldModelInput. |
| deterministic | bool | False | If True, use deterministic encoding (e.g. posterior mean). |

Returns: State -- Latent representation.

Raises: NotImplementedError if no observation_encoder is attached.
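
The delegation described above can be pictured with a small stand-in. This is only a sketch of the pattern, not WorldFlux's implementation: `MiniWorldModel` and `ToyEncoder` are hypothetical names, and a plain dict stands in for a real State.

```python
from typing import Callable, Optional


class ToyEncoder:
    """Hypothetical encoder: maps a list of floats to a dict-based 'state'."""

    def __call__(self, obs, deterministic=False):
        # Stand-in for a latent State: just the mean of the observation.
        return {"latent": sum(obs) / len(obs), "deterministic": deterministic}


class MiniWorldModel:
    """Hypothetical stand-in that mimics only the delegation pattern."""

    def __init__(self, observation_encoder: Optional[Callable] = None):
        self.observation_encoder = observation_encoder

    def encode(self, obs, deterministic: bool = False):
        # Fail loudly when the slot is empty, as the reference describes.
        if self.observation_encoder is None:
            raise NotImplementedError("no observation_encoder is attached")
        return self.observation_encoder(obs, deterministic=deterministic)


model = MiniWorldModel(observation_encoder=ToyEncoder())
state = model.encode([1.0, 2.0, 3.0], deterministic=True)
```

Keeping the encoder in a replaceable slot is what lets one stage of the pipeline be swapped without touching the others.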

transition

def transition(
    self,
    state: State,
    action: ActionPayload | Tensor | None,
    conditions: ConditionPayload | None = None,
    deterministic: bool = False,
) -> State

Predict the next latent state given current state and action. Performs a single imagination step through the dynamics model.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| state | State | required | Current latent state. |
| action | ActionPayload \| Tensor \| None | required | Action to condition on. Accepts a raw tensor, an ActionPayload, or None for unconditional transition. |
| conditions | ConditionPayload \| None | None | Optional auxiliary condition signals (e.g. goal embeddings). |
| deterministic | bool | False | If True, use deterministic dynamics. |

Returns: State -- Predicted next latent state.

Raises: NotImplementedError if no dynamics_model is attached.

decode

def decode(
    self,
    state: State,
    conditions: ConditionPayload | None = None,
) -> ModelOutput

Decode a latent state into observable predictions.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| state | State | required | Latent state to decode. |
| conditions | ConditionPayload \| None | None | Optional auxiliary condition signals. |

Returns: ModelOutput -- Contains the predictions dict and the originating state.

Raises: CapabilityError if no decoder_module is attached.

rollout

def rollout(
    self,
    initial_state: State,
    action_sequence: ActionSequence | ActionPayload | Tensor | None,
    conditions: ConditionPayload | None = None,
    deterministic: bool = False,
    mode: str = "autoregressive",
) -> Trajectory

Execute a multi-step open-loop rollout from an initial state.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| initial_state | State | required | Starting latent state. |
| action_sequence | ActionSequence \| ActionPayload \| Tensor \| None | required | Sequence of actions to apply. |
| conditions | ConditionPayload \| None | None | Optional auxiliary condition signals applied at each step. |
| deterministic | bool | False | If True, use deterministic transitions. |
| mode | str | "autoregressive" | Rollout mode. Only "autoregressive" is supported in v3. |

Returns: Trajectory -- Collected states, actions, rewards, and continuation flags.
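
In an autoregressive open-loop rollout, each predicted state is fed back into the next transition and no new observations are consulted along the way. The sketch below shows that loop with floats standing in for latent states; `toy_transition` and `open_loop_rollout` are hypothetical helpers, not WorldFlux functions.

```python
def toy_transition(state, action):
    """Hypothetical dynamics: decay the state by half and add the action."""
    return 0.5 * state + action


def open_loop_rollout(initial_state, action_sequence, transition=toy_transition):
    """Chain single-step transitions, feeding each prediction back in."""
    states = [initial_state]
    for action in action_sequence:
        # The next step depends only on the previous prediction and the action.
        states.append(transition(states[-1], action))
    return states


actions = [1.0, 0.0, 2.0]
states = open_loop_rollout(1.0, actions)
```

For T actions the loop produces T+1 states, because the initial state is kept as the first entry.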

loss (abstract)

@abstractmethod
def loss(self, batch: Batch) -> LossOutput

Compute the training loss. Subclasses must implement this method.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| batch | Batch | required | Training batch containing observations, actions, rewards, etc. |

Returns: LossOutput -- Contains the total loss tensor, component losses, and metrics.
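
Because loss is declared with @abstractmethod, Python refuses to instantiate any subclass that does not override it. The sketch below demonstrates that contract with hypothetical stand-ins (`ToyWorldModel`, `MeanSquaredModel`); a plain dict replaces LossOutput, and the batch keys "pred" and "target" are invented for illustration.

```python
from abc import ABC, abstractmethod


class ToyWorldModel(ABC):
    """Hypothetical base class with the same abstract-method contract."""

    @abstractmethod
    def loss(self, batch):
        """Subclasses must return the training loss for a batch."""


class MeanSquaredModel(ToyWorldModel):
    """Hypothetical concrete model: mean squared error over a toy batch."""

    def loss(self, batch):
        errors = [(p - t) ** 2 for p, t in zip(batch["pred"], batch["target"])]
        return {"loss": sum(errors) / len(errors)}


# The abstract base cannot be instantiated directly.
try:
    ToyWorldModel()
    base_instantiated = True
except TypeError:
    base_instantiated = False

result = MeanSquaredModel().loss({"pred": [1.0, 2.0], "target": [1.0, 4.0]})
```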

supports

def supports(self, capability: Capability) -> bool

Return True if the model advertises the given capability.

Convenience Properties

| Property | Type | Description |
| --- | --- | --- |
| supports_reward | bool | Whether the model predicts rewards. |
| supports_continue | bool | Whether the model predicts continuation flags. |
| supports_planning | bool | Whether the model supports planning. |
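
One plausible reading is that these properties are thin wrappers around supports(). The sketch below assumes exactly that; the `Capability` enum here is a hypothetical stand-in, and only REWARD_PRED and PLANNING are member names that appear in this reference (CONTINUE_PRED is an assumed name).

```python
from enum import Enum, auto


class Capability(Enum):
    """Hypothetical stand-in for the framework's capability flags."""

    REWARD_PRED = auto()
    CONTINUE_PRED = auto()  # assumed name, not confirmed by this reference
    PLANNING = auto()


class CapabilityHolder:
    """Hypothetical class showing set-membership capability checks."""

    def __init__(self, capabilities):
        self.capabilities = set(capabilities)

    def supports(self, capability):
        return capability in self.capabilities

    @property
    def supports_reward(self):
        return self.supports(Capability.REWARD_PRED)

    @property
    def supports_continue(self):
        return self.supports(Capability.CONTINUE_PRED)

    @property
    def supports_planning(self):
        return self.supports(Capability.PLANNING)


model = CapabilityHolder({Capability.REWARD_PRED})
```

Checking a flag before reading an optional prediction avoids relying on a head the model may not have.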

Example

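import torch
from worldflux import create_world_model

model = create_world_model("dreamerv3:size12m", obs_shape=(3, 64, 64), action_dim=6)

# Placeholder inputs; the leading batch dimension of 1 is an assumption.
obs = torch.randn(1, 3, 64, 64)
action = torch.zeros(1, 6)

state = model.encode(obs)
next_state = model.transition(state, action)
output = model.decode(next_state)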

ActionPayload

@dataclass
class ActionPayload

Polymorphic action container that supports multiple control modalities.

Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| kind | ActionKind | "none" | Action modality. One of "none", "continuous", "discrete", "token", "latent", "text". |
| tensor | Tensor \| None | None | Primary tensor for continuous or discrete actions. |
| tokens | Tensor \| None | None | Token tensor for token-based actions. |
| latent | Tensor \| None | None | Latent tensor for latent-space actions. |
| text | list[str] \| None | None | Text strings for text-conditioned actions. |
| extras | dict[str, Any] | {} | Additional metadata (e.g. planner horizon). |

Methods

primary

def primary(self) -> Tensor | None

Return the primary tensor representation, checking tensor, tokens, and latent in order.

validate

def validate(self, *, api_version: str = "v0.2") -> None

Validate payload consistency. Ensures only one primary representation is set and that kind="none" payloads carry no data.
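
The two rules above, together with the lookup order of primary(), can be sketched with a stand-in dataclass. `ToyActionPayload` is hypothetical: lists replace tensors, the api_version argument is omitted, raising ValueError is an assumption, and treating text as data for the kind="none" check is also an assumption.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyActionPayload:
    """Hypothetical stand-in; lists replace tensors to stay dependency-free."""

    kind: str = "none"
    tensor: Optional[list] = None
    tokens: Optional[list] = None
    latent: Optional[list] = None
    text: Optional[list] = None
    extras: dict = field(default_factory=dict)

    def primary(self):
        # Check tensor, tokens, latent in that order; the first one set wins.
        for candidate in (self.tensor, self.tokens, self.latent):
            if candidate is not None:
                return candidate
        return None

    def validate(self):
        set_fields = [
            name for name in ("tensor", "tokens", "latent")
            if getattr(self, name) is not None
        ]
        if len(set_fields) > 1:
            raise ValueError(f"multiple primary representations set: {set_fields}")
        if self.kind == "none" and (set_fields or self.text is not None):
            raise ValueError('kind="none" payloads must carry no data')


payload = ToyActionPayload(kind="token", tokens=[42, 7, 13])
payload.validate()  # passes: exactly one primary representation
```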

Example

import torch

# Continuous action
action = ActionPayload(kind="continuous", tensor=torch.randn(6))

# Discrete action
action = ActionPayload(kind="discrete", tensor=torch.tensor([3]))

# Token-based action
action = ActionPayload(kind="token", tokens=torch.tensor([42, 7, 13]))

ConditionPayload

@dataclass
class ConditionPayload

Optional side-conditions for conditional world modeling.

Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| text_condition | Tensor \| list[str] \| None | None | Text condition embedding or raw text strings. |
| goal | Tensor \| None | None | Goal state tensor. |
| spatial | Tensor \| None | None | Spatial condition tensor (e.g. map, layout). |
| camera_pose | Tensor \| None | None | Camera pose tensor for 3D-conditioned models. |
| extras | dict[str, Any] | {} | Additional condition signals. Keys must follow namespaced format "wf.<domain>.<name>". |
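
A check for the namespaced key format on extras could look like the sketch below. The character set allowed in each segment is not specified here, so the pattern assumes lowercase letters, digits, and underscores; `invalid_extra_keys` and the key "wf.nav.waypoint" are hypothetical.

```python
import re

# Assumed grammar for "wf.<domain>.<name>": exactly three dot-separated parts,
# with lowercase letters, digits, or underscores in <domain> and <name>.
NAMESPACED_KEY = re.compile(r"wf\.[a-z0-9_]+\.[a-z0-9_]+")


def invalid_extra_keys(extras):
    """Return the keys that do not match the assumed namespaced format."""
    return [key for key in extras if not NAMESPACED_KEY.fullmatch(key)]


bad = invalid_extra_keys({"wf.nav.waypoint": 1, "waypoint": 2, "wf.nav": 3})
```

Here "waypoint" lacks the namespace entirely and "wf.nav" is missing its name segment, so both are reported.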

WorldModelInput

@dataclass
class WorldModelInput

Unified model input object wrapping observations, context, actions, and conditions.

Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| observations | dict[str, Tensor] | {} | Named observation tensors keyed by modality name. |
| context | dict[str, Tensor] | {} | Additional context tensors. |
| action | ActionPayload \| None | None | Action payload for conditioned inference. |
| conditions | ConditionPayload | ConditionPayload() | Side-condition payload. |
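
Since encode accepts a raw tensor, a dict of modality tensors, or a WorldModelInput, an encoder may want to normalize all three into the unified form first. The sketch below shows one way that could be done; `ToyWorldModelInput` and `as_model_input` are hypothetical, lists replace tensors, and the default modality key "obs" is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class ToyWorldModelInput:
    """Hypothetical stand-in with only the observation and context fields."""

    observations: dict = field(default_factory=dict)
    context: dict = field(default_factory=dict)


def as_model_input(obs, default_key="obs"):
    """Coerce any of the three accepted observation forms into one container."""
    if isinstance(obs, ToyWorldModelInput):
        return obs  # already unified; pass through unchanged
    if isinstance(obs, dict):
        return ToyWorldModelInput(observations=dict(obs))  # copy the mapping
    # A bare observation is filed under a single default modality name.
    return ToyWorldModelInput(observations={default_key: obs})


from_raw = as_model_input([0.1, 0.2])
from_dict = as_model_input({"image": [0.1], "proprio": [0.2]})
```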

ModelOutput

@dataclass
class ModelOutput

Standardized model output container returned by WorldModel.decode().

Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| predictions | dict[str, Tensor] | {} | Predicted tensors keyed by name (e.g. "obs", "reward", "continue"). |
| state | State \| None | None | Latent state that produced these predictions. |
| uncertainty | Tensor \| None | None | Optional uncertainty estimate. |
| aux | dict[str, Any] | {} | Auxiliary outputs (e.g. attention maps, intermediate activations). |
| prediction_spec | PredictionSpec \| None | None | Spec describing expected prediction keys. |
| sequence_layout | SequenceLayout \| None | None | Axis layout metadata for prediction tensors. |

Example

output = model.decode(state)
obs_pred = output.predictions["obs"]
reward_pred = output.predictions.get("reward")

LossOutput

@dataclass
class LossOutput

Standardized loss container returned by WorldModel.loss().

Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| loss | Tensor | required | Total scalar loss for backpropagation. |
| components | dict[str, Tensor] | {} | Individual loss components (e.g. "reconstruction", "kl", "reward"). |
| metrics | dict[str, float] | {} | Scalar metrics for logging (e.g. gradient norms, latent statistics). |

Example

loss_out = model.loss(batch)
loss_out.loss.backward()

# Log individual components
for name, value in loss_out.components.items():
    print(f"{name}: {value.item():.4f}")

Trajectory

@dataclass
class Trajectory

Imagination rollout trajectory in latent space. Returned by WorldModel.rollout().

The trajectory maintains the invariant that len(states) == actions.shape[0] + 1, representing the initial state plus one state per action taken.

Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| states | list[State] | required | List of latent states [T+1] (initial + T steps). |
| actions | Tensor | required | Action tensor [T, batch, action_dim]. |
| rewards | Tensor \| None | None | Predicted rewards [T, batch]. |
| values | Tensor \| None | None | Predicted values [T+1, batch]. |
| continues | Tensor \| None | None | Continue probabilities [T, batch]. |
| state_spec | StateSpec \| None | None | Spec describing state tensor keys. |
| sequence_layout | SequenceLayout \| None | None | Axis layout metadata. |

Properties

| Property | Type | Description |
| --- | --- | --- |
| horizon | int | Prediction horizon (number of actions). |
| batch_size | int | Batch size from the first state. |
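
The length invariant and the two properties can be sketched with a stand-in container. `ToyTrajectory` is hypothetical: nested lists replace tensors, each state is a dict of per-batch lists, and checking the invariant in `__post_init__` is an assumption about where such a check might live.

```python
from dataclasses import dataclass


@dataclass
class ToyTrajectory:
    """Hypothetical stand-in; nested lists replace tensors."""

    states: list   # T+1 state dicts, each mapping a key to a per-batch list
    actions: list  # T actions; the outer list plays the role of the time axis

    def __post_init__(self):
        # Initial state plus one state per action taken.
        if len(self.states) != len(self.actions) + 1:
            raise ValueError("expected len(states) == len(actions) + 1")

    @property
    def horizon(self):
        return len(self.actions)

    @property
    def batch_size(self):
        # Read the batch size off any entry of the first state.
        first_value = next(iter(self.states[0].values()))
        return len(first_value)


traj = ToyTrajectory(
    states=[{"deter": [0.0, 0.0]}, {"deter": [0.1, 0.2]}, {"deter": [0.3, 0.4]}],
    actions=[[1.0, 1.0], [0.0, 1.0]],
)
```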

Methods

to_tensor

def to_tensor(self, key: str) -> Tensor

Stack a specific state tensor key across time [T+1, batch, ...].
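
Stacking a key across time amounts to collecting the same entry from every state along a new leading axis. The sketch below uses nested lists in place of tensors; `stack_state_key` is a hypothetical helper, and with real tensors the same step would presumably be a `torch.stack(..., dim=0)` over the per-step values.

```python
def stack_state_key(states, key):
    """Collect one key from each state; the outer list is the time axis."""
    missing = [t for t, state in enumerate(states) if key not in state]
    if missing:
        raise KeyError(f"key {key!r} missing at time steps {missing}")
    return [state[key] for state in states]


states = [
    {"deter": [0.0, 0.0]},  # initial state, batch of 2
    {"deter": [0.1, 0.2]},
    {"deter": [0.3, 0.4]},
]
stacked = stack_state_key(states, "deter")  # nested shape: [T+1][batch]
```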

to

def to(self, device: torch.device) -> Trajectory

Move all tensors to the specified device.

detach

def detach(self) -> Trajectory

Detach all tensors from the computation graph.

Example

trajectory = model.rollout(initial_state, action_sequence)
print(f"Horizon: {trajectory.horizon}")
if trajectory.rewards is not None:  # rewards is optional (Tensor | None)
    print(f"Rewards shape: {trajectory.rewards.shape}")

# Stack deterministic state across time
deter_stack = trajectory.to_tensor("deter")