# State

`State` is the core representation shared by all world models.
## Overview

`State` is a lightweight container with two fields:

- `tensors: dict[str, Tensor]`, holding model-specific latent tensors
- `meta: dict[str, Any]`, for optional metadata

```python
from worldflux.core.state import State
```
## Creating a State

Most users obtain a `State` via `model.encode()` or `model.update()`:

```python
state = model.encode(obs)
```
## Accessing Tensors

`State` does not fix a schema; each model defines its own tensor keys.
### DreamerV3

- `deter`: deterministic GRU state
- `stoch`: stochastic categorical samples
- `prior_logits` / `posterior_logits`: logits for the KL term
```python
features = torch.cat(
    [state.tensors["deter"], state.tensors["stoch"].flatten(1)],
    dim=-1,
)
```
### TD-MPC2

- `latent`: SimNorm embedding

```python
latent = state.tensors["latent"]
```
### JEPA

- `rep`: encoder representation

```python
rep = state.tensors["rep"]
```
## Metadata

Use `meta` for non-tensor bookkeeping:

```python
state.meta["latent_type"] = "simnorm"
```
## State Operations

### Device Transfer

```python
# Move state to GPU
gpu_state = state.to("cuda")

# Move back to CPU
cpu_state = gpu_state.to("cpu")
```
### Detach and Clone

```python
# Detach from the computation graph (e.g., for rollout targets)
detached = state.detach()

# Deep copy with independent tensors
cloned = state.clone()
```
### Validation

```python
# Verify all tensors share a consistent batch dimension
state.validate()  # raises StateError if inconsistent
```
### Batch Size and Device Inspection

```python
print(state.batch_size)  # e.g., 32
print(state.device)      # e.g., device(type='cuda', index=0)
```
### Safe Tensor Access

```python
# Returns None instead of raising KeyError when the key is missing
latent = state.get("latent")
deter = state.get("deter", default=torch.zeros(1, 256))
```
## Serialization

`State` supports binary serialization for checkpointing and IPC:

```python
# Serialize to bytes
data = state.serialize(version="v1", format="binary")

# Deserialize from bytes
restored = State.deserialize(data)
```
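The actual wire format of `serialize()` is internal to the library. Purely as an illustration of what a versioned binary format for a dict of named arrays involves, here is a minimal standard-library sketch with hypothetical `pack`/`unpack` helpers (not the library's format):

```python
import struct

def pack(tensors: dict[str, list[float]], version: int = 1) -> bytes:
    """Illustrative only: version byte, entry count, then (key, values) records."""
    out = [struct.pack("<BI", version, len(tensors))]
    for key, values in tensors.items():
        kb = key.encode("utf-8")
        out.append(struct.pack("<I", len(kb)) + kb)
        out.append(struct.pack("<I", len(values)))
        out.append(struct.pack(f"<{len(values)}f", *values))
    return b"".join(out)

def unpack(data: bytes) -> dict[str, list[float]]:
    version, n = struct.unpack_from("<BI", data, 0)
    if version != 1:
        raise ValueError(f"unsupported version: {version}")
    offset = struct.calcsize("<BI")
    tensors = {}
    for _ in range(n):
        (klen,) = struct.unpack_from("<I", data, offset); offset += 4
        key = data[offset:offset + klen].decode("utf-8"); offset += klen
        (vlen,) = struct.unpack_from("<I", data, offset); offset += 4
        tensors[key] = list(struct.unpack_from(f"<{vlen}f", data, offset))
        offset += 4 * vlen
    return tensors

blob = pack({"latent": [0.5, 1.5]})
assert unpack(blob) == {"latent": [0.5, 1.5]}
```

A version header like this is what lets `deserialize()` reject or migrate checkpoints written by older code.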
### Shared Memory (Zero-Copy IPC)

For multi-process training pipelines:

```python
# Producer process
descriptor = state.to_shared_memory(namespace="my-state")

# Consumer process
attached = State.from_shared_memory(descriptor, copy=False)

# Clean up
attached.close_shared_memory(unlink=True)
```
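The descriptor hand-off above is the standard named-shared-memory pattern: the producer publishes a small descriptor, the consumer attaches by name without copying the payload. A standalone sketch of that pattern (not the library's implementation) using Python's `multiprocessing.shared_memory`:

```python
from multiprocessing import shared_memory

# Producer side: write a payload into a named shared-memory block and
# publish a small descriptor that other processes can use to attach.
payload = b"latent-bytes"
shm = shared_memory.SharedMemory(create=True, size=len(payload))
shm.buf[:len(payload)] = payload
descriptor = {"name": shm.name, "size": len(payload)}

# Consumer side: attach by name; the buffer itself is not copied.
attached = shared_memory.SharedMemory(name=descriptor["name"])
view = bytes(attached.buf[:descriptor["size"]])
assert view == payload

# Clean up: every process closes its handle; exactly one unlinks.
attached.close()
shm.close()
shm.unlink()
```

The `unlink=True` flag in the `State` API plays the same role as `shm.unlink()` here: closing releases one process's handle, unlinking destroys the underlying block.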
## Implementation

```python
@dataclass
class State:
    tensors: dict[str, Tensor] = field(default_factory=dict)
    meta: dict[str, Any] = field(default_factory=dict)
```
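The dataclass above omits the methods documented in the API reference. As a simplified, runnable sketch of how a few of them (`get`, `batch_size`, `validate`) can be built on the two dict fields, using a stand-in tensor type so it runs without PyTorch (the real class also covers device transfer, cloning, and serialization):

```python
from dataclasses import dataclass, field
from typing import Any

@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor: only the .shape attribute is needed here."""
    shape: tuple[int, ...]

class StateError(ValueError):
    """Raised by validate() when tensors disagree on the batch dimension."""

@dataclass
class State:
    tensors: dict[str, Any] = field(default_factory=dict)
    meta: dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default=None):
        # Safe access: missing keys return the default instead of raising.
        return self.tensors.get(key, default)

    @property
    def batch_size(self) -> int:
        # Batch dimension is read from the first tensor.
        first = next(iter(self.tensors.values()))
        return first.shape[0]

    def validate(self) -> None:
        sizes = {t.shape[0] for t in self.tensors.values()}
        if len(sizes) > 1:
            raise StateError(f"inconsistent batch sizes: {sizes}")

state = State(tensors={"deter": FakeTensor((32, 512)),
                       "stoch": FakeTensor((32, 32))})
assert state.batch_size == 32
assert state.get("missing") is None
state.validate()  # passes: both tensors share batch dim 32
```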
## API Reference

| Method / Property | Signature | Description |
|---|---|---|
| `get` | `get(key: str, default=None) -> Tensor \| None` | Safe tensor access by key. |
| `batch_size` | `property -> int` | Batch dimension from the first tensor. |
| `device` | `property -> torch.device` | Device of the first tensor. |
| `to` | `to(device) -> State` | Move all tensors to a device. |
| `detach` | `detach() -> State` | Detach all tensors from the computation graph. |
| `clone` | `clone() -> State` | Deep copy with independent tensors. |
| `validate` | `validate() -> None` | Check batch-dimension consistency; raises `StateError`. |
| `serialize` | `serialize(version="v1", format="binary") -> bytes` | Binary serialization for checkpointing. |
| `deserialize` | `State.deserialize(data: bytes) -> State` | Restore from serialized bytes. |
| `to_shared_memory` | `to_shared_memory(namespace: str) -> dict` | Zero-copy IPC via shared memory. |
| `from_shared_memory` | `State.from_shared_memory(descriptor, copy=False) -> State` | Attach to a shared-memory state. |
| `close_shared_memory` | `close_shared_memory(unlink=True) -> None` | Release shared-memory resources. |