Architectures

Temporal U-Net

class campd.architectures.diffusion.temporal_unet.TemporalUnetCfg

Bases: BaseModel

n_support_points: int
state_dim: int
unet_input_dim: int
dim_mults: Tuple[int, ...]
time_emb_dim: int
enable_conditioning: bool
conditioning_embed_dim: int
attention_num_heads: int
attention_dim_head: int
add_time_emb_to_conditioning: bool
model_config: ClassVar[ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
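As an illustration of the fields above, a configuration might be written as a plain dictionary before validation; the values below are purely illustrative, not library defaults:

```python
# Hypothetical example values for a TemporalUnetCfg.
# Field names come from the config class above; values are illustrative only.
temporal_unet_cfg = {
    "n_support_points": 64,        # trajectory length (sequence dimension)
    "state_dim": 7,                # per-timestep state dimension
    "unet_input_dim": 32,          # base channel width of the U-Net
    "dim_mults": (1, 2, 4),        # channel multipliers per U-Net level
    "time_emb_dim": 32,            # diffusion-time embedding size
    "enable_conditioning": True,
    "conditioning_embed_dim": 128,
    "attention_num_heads": 4,
    "attention_dim_head": 32,
    "add_time_emb_to_conditioning": True,
}
```

Since TemporalUnetCfg is a pydantic BaseModel, such a dictionary could presumably be validated with `TemporalUnetCfg(**temporal_unet_cfg)` or passed directly to `TemporalUnet.from_config`.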

class campd.architectures.diffusion.temporal_unet.TemporalUnet

Bases: ReverseDiffusionNetwork

conditioning_key = 'all'
__init__(config)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:

config (TemporalUnetCfg)

classmethod from_config(config)
Parameters:

config (TemporalUnetCfg | dict)

forward(x, t, embedded_context_batch)

Parameters:
  • x (Tensor) – [ batch x horizon x state_dim ] batched noisy trajectories.

  • t (Tensor) – [ batch ] batched diffusion timesteps.

  • embedded_context_batch (EmbeddedContext) – assumed to have a key called "all" that contains all the embeddings stacked on the second dimension.

Return type:

Tensor

Reverse Diffusion Base

class campd.architectures.diffusion.base.ReverseDiffusionNetwork

Bases: Module, ABC

Abstract base class for neural networks that learn the reverse diffusion process.

abstractmethod forward(x, t, embedded_context_batch)

Forward pass of the reverse diffusion network.

Parameters:
  • x (Tensor) – The batched noisy input data.

  • t (Tensor) – The batched diffusion timestep(s).

  • embedded_context_batch (EmbeddedContext) – The embedded context for the batch.

Returns:

The batched predicted noise or denoised data.

Return type:

Tensor
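To illustrate the contract, here is a minimal stand-alone sketch of a concrete reverse-diffusion network with the same forward signature. EpsilonMLP and the optional stand-in for EmbeddedContext are hypothetical, not part of campd:

```python
import torch
import torch.nn as nn

class EpsilonMLP(nn.Module):
    """Toy noise-prediction network: flattens the trajectory, concatenates a
    scalar time feature, and predicts noise with the same shape as x."""

    def __init__(self, horizon: int, state_dim: int, hidden: int = 64):
        super().__init__()
        self.horizon, self.state_dim = horizon, state_dim
        in_dim = horizon * state_dim + 1  # +1 for the timestep feature
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.Mish(),
            nn.Linear(hidden, horizon * state_dim),
        )

    def forward(self, x, t, embedded_context_batch=None):
        b = x.shape[0]
        feats = torch.cat([x.reshape(b, -1), t.float().view(b, 1)], dim=-1)
        return self.net(feats).reshape(b, self.horizon, self.state_dim)

net = EpsilonMLP(horizon=16, state_dim=3)
x = torch.randn(4, 16, 3)          # batched noisy input data
t = torch.randint(0, 100, (4,))    # batched diffusion timesteps
noise_pred = net(x, t)             # same shape as x
```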

Context Encoder

class campd.architectures.context.encoder.KeyNetModule

Bases: Protocol

__init__(*args, **kwargs)
class campd.architectures.context.encoder.ContextEncoderCfg

Bases: BaseModel

key_networks: Mapping[str, Spec[KeyNetModule]]
concat_config: Dict[str, List[str] | None] | None
include_start_goal: bool
model_config: ClassVar[ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class campd.architectures.context.encoder.ContextEncoder

Bases: Module

Encodes TrajectoryContext into EmbeddedContext using a dedicated network for each context key.

__init__(config)
Parameters:

config (ContextEncoderCfg) – ContextEncoderCfg object or dictionary.

key_networks: nn.ModuleDict[str, KeyNetModule]
property context_keys: list[str]
property context_dims: dict[str, int]
classmethod from_config(config)

Factory method to create ContextEncoder from config.

Return type:

ContextEncoder

Parameters:

config (ContextEncoderCfg | dict)

forward(context)
Parameters:

context (TrajectoryContext) – TrajectoryContext to encode.

Return type:

EmbeddedContext

Returns:

EmbeddedContext containing the encoded context.
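The per-key design described above (one dedicated network per context key) can be sketched with a plain nn.ModuleDict; the key names, input dimensions, and Linear encoders here are hypothetical stand-ins, not campd's types:

```python
import torch
import torch.nn as nn

class PerKeyEncoder(nn.Module):
    """Sketch of encoding a context dict with a dedicated network per key,
    mirroring the ContextEncoder design (stand-in types, not campd's)."""

    def __init__(self, key_dims: dict, embed_dim: int = 32):
        super().__init__()
        self.key_networks = nn.ModuleDict(
            {key: nn.Linear(in_dim, embed_dim) for key, in_dim in key_dims.items()}
        )

    def forward(self, context: dict) -> dict:
        # Each key is embedded by its own network.
        return {key: self.key_networks[key](val) for key, val in context.items()}

enc = PerKeyEncoder({"start": 7, "goal": 7, "obstacles": 16})
ctx = {
    "start": torch.randn(4, 7),
    "goal": torch.randn(4, 7),
    "obstacles": torch.randn(4, 16),
}
embedded = enc(ctx)  # dict of (4, 32) tensors, one per key
```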

Layers

Core neural network layers and building blocks for CAMPD architectures.

Includes standard MLP implementations, Temporal U-Net residual blocks, attention mechanisms, and various normalizations/activations.

campd.architectures.layers.layers.ACTIVATIONS = {'elu': <class 'torch.nn.modules.activation.ELU'>, 'identity': <class 'torch.nn.modules.linear.Identity'>, 'leaky_relu': <class 'torch.nn.modules.activation.LeakyReLU'>, 'mish': <class 'torch.nn.modules.activation.Mish'>, 'prelu': <class 'torch.nn.modules.activation.PReLU'>, 'relu': <class 'torch.nn.modules.activation.ReLU'>, 'sigmoid': <class 'torch.nn.modules.activation.Sigmoid'>, 'softplus': <class 'torch.nn.modules.activation.Softplus'>, 'tanh': <class 'torch.nn.modules.activation.Tanh'>}

Dictionary mapping activation function names to their PyTorch module classes.
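The registry maps lowercase names to module classes so activations can be instantiated from a config string. A minimal equivalent of the lookup pattern (subset of the entries shown above):

```python
import torch
import torch.nn as nn

# Same name -> class registry pattern as ACTIVATIONS (subset shown).
activations = {"relu": nn.ReLU, "mish": nn.Mish, "identity": nn.Identity}

def make_activation(name: str) -> nn.Module:
    # Look up the class by name and instantiate it.
    return activations[name]()

act = make_activation("mish")
y = act(torch.tensor([-1.0, 0.0, 1.0]))  # Mish applied elementwise
```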

class campd.architectures.layers.layers.MLP1DCfg

Bases: BaseModel

Configuration for a 1D Multi-Layer Perceptron.

in_dim: int

Input feature dimension.

out_dim: int

Output feature dimension.

hidden_dim: int

Dimension of hidden layers.

n_layers: int

Number of hidden layers.

act: Literal['relu', 'sigmoid', 'tanh', 'leaky_relu', 'elu', 'prelu', 'softplus', 'mish', 'identity']

Activation function name (e.g., ‘relu’, ‘mish’, ‘elu’).

layer_norm: bool

Whether to apply layer normalization after linear layers.

model_config: ClassVar[ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class campd.architectures.layers.layers.MLP1D

Bases: Module

A standard 1D Multi-Layer Perceptron (MLP) module.

Constructs a sequence of Linear -> [LayerNorm] -> Activation layers.

__init__(config)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:

config (MLP1DCfg)

classmethod from_config(config)

Instantiates an MLP1D from a configuration object or dictionary.

Parameters:

config (MLP1DCfg | dict) – Configuration for the MLP1D.

Returns:

An instantiated MLP1D module.

Return type:

MLP1D

forward(x)

Forward pass through the MLP.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Output tensor representing the processed features.

Return type:

torch.Tensor
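Following the config fields and the Linear -> [LayerNorm] -> Activation recipe above, the stack could be assembled roughly like this; a sketch, not campd's implementation (the final projection to out_dim is an assumption):

```python
import torch
import torch.nn as nn

def build_mlp1d(in_dim, out_dim, hidden_dim, n_layers, act=nn.ReLU, layer_norm=False):
    """Sketch of an MLP built as Linear -> [LayerNorm] -> Activation blocks,
    with a final Linear projection to out_dim."""
    layers, d = [], in_dim
    for _ in range(n_layers):
        layers.append(nn.Linear(d, hidden_dim))
        if layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(act())
        d = hidden_dim
    layers.append(nn.Linear(d, out_dim))
    return nn.Sequential(*layers)

mlp = build_mlp1d(in_dim=10, out_dim=3, hidden_dim=64, n_layers=2, layer_norm=True)
out = mlp(torch.randn(8, 10))  # -> shape (8, 3)
```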

class campd.architectures.layers.layers.Residual

Bases: Module

Applies a residual connection around a given function/module.

Parameters:

fn (nn.Module) – The module to wrap.

__init__(fn)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, *args, **kwargs)

Computes the wrapped module's output and adds the input: fn(x, *args, **kwargs) + x.

class campd.architectures.layers.layers.PreNorm

Bases: Module

Applies LayerNorm before a given function/module.

Parameters:
  • dim (int) – Feature dimension for normalization.

  • fn (nn.Module) – The module to wrap.

__init__(dim, fn)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Applies LayerNorm to the input, then the wrapped module: fn(norm(x)).

class campd.architectures.layers.layers.LayerNorm

Bases: Module

Custom LayerNorm implementation avoiding standard PyTorch constraints.

Parameters:
  • dim (int) – Feature dimension.

  • eps (float) – Small value to avoid division by zero.

__init__(dim, eps=1e-05)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Normalizes x over its feature dimension.

class campd.architectures.layers.layers.TimeEncoder

Bases: Module

Encodes time steps using sinusoidal embeddings followed by an MLP.

Parameters:
  • dim (int) – Base embedding dimension.

  • dim_out (int) – Output embedding dimension.

__init__(dim, dim_out)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Maps timesteps to embeddings: sinusoidal encoding followed by the MLP.

class campd.architectures.layers.layers.SinusoidalPosEmb

Bases: Module

Sinusoidal positional embeddings for time/position encoding.

Parameters:

dim (int) – Embedding dimension. Must be an even number.

__init__(dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Computes sinusoidal embeddings of the input timesteps/positions.
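The standard sinusoidal embedding this class describes can be sketched as follows (even dim assumed, as the class requires; the exact frequency spacing is the common convention, which campd may vary):

```python
import math
import torch

def sinusoidal_pos_emb(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Classic transformer-style sinusoidal embedding of scalar positions.

    x: (batch,) positions or timesteps; dim must be even.
    Returns: (batch, dim) embeddings, [sin | cos] halves concatenated.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/10000.
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / (half - 1))
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([args.sin(), args.cos()], dim=-1)

emb = sinusoidal_pos_emb(torch.arange(4), dim=16)  # -> shape (4, 16)
```

At position 0 the sine half is all zeros and the cosine half all ones, which makes the embedding easy to sanity-check.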

class campd.architectures.layers.layers.Downsample1d

Bases: Module

Downsamples a 1D sequence using a strided convolution.

Parameters:

dim (int) – Number of channels.

__init__(dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Downsamples the input sequence with the strided convolution.

class campd.architectures.layers.layers.Upsample1d

Bases: Module

Upsamples a 1D sequence using a transposed convolution.

Parameters:

dim (int) – Number of channels.

__init__(dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Upsamples the input sequence with the transposed convolution.

class campd.architectures.layers.layers.Conv1dBlock

Bases: Module

A convolutional block applying Conv1d -> GroupNorm -> Mish.

Parameters:
  • inp_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel_size (int) – Size of the convolving kernel.

  • padding (int, optional) – Zero-padding added to both sides of the input.

  • n_groups (int) – Number of groups for GroupNorm. Defaults to 8.

__init__(inp_channels, out_channels, kernel_size, padding=None, n_groups=8)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Applies Conv1d -> GroupNorm -> Mish to the input.

class campd.architectures.layers.layers.ResidualTemporalBlock

Bases: Module

A residual temporal block with conditioning for diffusion models.

Parameters:
  • inp_channels (int) – Input channel dimension.

  • out_channels (int) – Output channel dimension.

  • cond_embed_dim (int) – Conditioning embedding dimension.

  • n_support_points (int) – Number of support points (sequence length).

  • kernel_size (int) – Size of the convolving kernel. Defaults to 5.

__init__(inp_channels, out_channels, cond_embed_dim, n_support_points, kernel_size=5)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, c)

Parameters:
  • x – [ batch_size x inp_channels x n_support_points ]

  • c – [ batch_size x embed_dim ] conditioning embedding.

Returns:

out – [ batch_size x out_channels x n_support_points ]
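The conditioning pattern described here (the embedding projected to per-channel features and added between conv blocks, plus a residual path) is the usual temporal residual block from diffusion planners. A compact stand-alone sketch, not necessarily matching campd's exact layer order or normalization:

```python
import torch
import torch.nn as nn

class TinyResidualTemporalBlock(nn.Module):
    """Sketch: Conv1d block -> add projected conditioning -> Conv1d block,
    with a 1x1-conv residual connection when channel counts differ."""

    def __init__(self, inp_channels, out_channels, cond_embed_dim, kernel_size=5):
        super().__init__()
        pad = kernel_size // 2
        self.block1 = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=pad), nn.Mish())
        self.block2 = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, kernel_size, padding=pad), nn.Mish())
        self.cond = nn.Sequential(nn.Mish(), nn.Linear(cond_embed_dim, out_channels))
        self.residual = (nn.Conv1d(inp_channels, out_channels, 1)
                         if inp_channels != out_channels else nn.Identity())

    def forward(self, x, c):
        # x: (batch, inp_channels, n_support_points); c: (batch, cond_embed_dim)
        h = self.block1(x) + self.cond(c)[:, :, None]  # broadcast over time
        return self.block2(h) + self.residual(x)

blk = TinyResidualTemporalBlock(8, 16, cond_embed_dim=32)
out = blk(torch.randn(2, 8, 24), torch.randn(2, 32))  # -> (2, 16, 24)
```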

campd.architectures.layers.layers.group_norm_n_groups(n_channels, target_n_groups=8)

Safely computes the number of groups for GroupNorm based on channels.

Finds a valid number of groups (a divisor of n_channels) close to the target.

Parameters:
  • n_channels (int) – Number of channels.

  • target_n_groups (int) – Target number of groups. Defaults to 8.

Returns:

Realized number of groups for GroupNorm.

Return type:

int
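The divisor search described above can be sketched in pure Python; the tie-breaking rule here is an assumption, and campd's exact choice may differ:

```python
def group_norm_n_groups(n_channels: int, target_n_groups: int = 8) -> int:
    """Pick a group count that divides n_channels and is closest to the
    target. Ties are broken toward the larger divisor here (an assumption)."""
    divisors = [d for d in range(1, n_channels + 1) if n_channels % d == 0]
    return min(divisors, key=lambda d: (abs(d - target_n_groups), -d))

group_norm_n_groups(64)   # -> 8 (8 divides 64 and matches the target exactly)
group_norm_n_groups(12)   # -> 6 (closest divisor of 12 to the target 8)
group_norm_n_groups(7)    # -> 7 (only divisors are 1 and 7)
```

This matters because nn.GroupNorm requires num_groups to divide num_channels; a fixed num_groups=8 would fail for channel counts like 12.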

Attention

class campd.architectures.layers.attention.GEGLU

Bases: Module

__init__(dim_in, dim_out)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Projects the input and applies the GELU-gated linear unit.
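GEGLU is the GELU-gated linear unit from the "GLU Variants Improve Transformer" paper: project to twice the output width, split, and gate one half with GELU of the other. A sketch of the standard definition, which this class presumably follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    """GELU-gated linear unit: Linear to 2 * dim_out, then x * gelu(gate)."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = self.proj(x).chunk(2, dim=-1)  # split the doubled projection
        return x * F.gelu(gate)

out = GEGLU(16, 32)(torch.randn(4, 10, 16))  # -> (4, 10, 32)
```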

class campd.architectures.layers.attention.FeedForward

Bases: Module

__init__(dim, dim_out=None, mult=4, glu=False, dropout=0.0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Applies the feed-forward network to the input.

campd.architectures.layers.attention.zero_module(module)

Zero out the parameters of a module and return it.

campd.architectures.layers.attention.Normalize(in_channels)
class campd.architectures.layers.attention.CrossAttention

Bases: Module

Cross-attention implemented with PyTorch SDPA. This will use FlashAttention / mem-efficient kernels on CUDA when available.

__init__(query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, context=None, mask=None)

Parameters:
  • x – (b, n, query_dim) query sequence.

  • context – (b, m, context_dim), or None for self-attention.

  • mask – (b, m) boolean mask where True means the key position is valid.
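The SDPA-backed attention described here can be sketched with F.scaled_dot_product_attention (PyTorch >= 2.0); the head splitting and mask broadcasting below are illustrative, not campd's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPACrossAttention(nn.Module):
    """Sketch of cross-attention on top of PyTorch SDPA (self-attention when
    context is None), following the shapes documented above."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner = heads * dim_head
        context_dim = context_dim if context_dim is not None else query_dim
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner, bias=False)
        self.to_k = nn.Linear(context_dim, inner, bias=False)
        self.to_v = nn.Linear(context_dim, inner, bias=False)
        self.to_out = nn.Linear(inner, query_dim)

    def forward(self, x, context=None, mask=None):
        context = context if context is not None else x
        b, n, _ = x.shape
        # Project and split into heads: (b, heads, seq, dim_head).
        q, k, v = (t.view(b, -1, self.heads, t.shape[-1] // self.heads).transpose(1, 2)
                   for t in (self.to_q(x), self.to_k(context), self.to_v(context)))
        # Boolean attn_mask: True means the key position may be attended to.
        attn_mask = mask[:, None, None, :] if mask is not None else None
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return self.to_out(out.transpose(1, 2).reshape(b, n, -1))

attn = SDPACrossAttention(query_dim=32, context_dim=48, heads=4, dim_head=8)
y = attn(torch.randn(2, 10, 32), context=torch.randn(2, 7, 48))  # -> (2, 10, 32)
```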

class campd.architectures.layers.attention.BasicTransformerBlock

Bases: Module

__init__(dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, context=None, mask=None)

Applies self-attention, cross-attention over the context (when given), and the feed-forward block, each with normalization and a residual connection.

class campd.architectures.layers.attention.SpatialTransformer

Bases: Module

Transformer block for trajectory-like data. First projects the input embedding and reshapes it to (b, t, d), then applies standard transformer blocks, and finally reshapes back to the trajectory layout.

__init__(in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, context=None, mask=None)

Projects the input, reshapes to (b, t, d), applies the transformer blocks with optional context and mask, and reshapes back to the trajectory layout with a residual connection.

Layer Utilities

campd.architectures.layers.utils.prob_mask_like(shape, prob, device)

Returns a boolean tensor of the given shape whose entries are True with probability prob.

campd.architectures.layers.utils.exists(val)

Returns True if val is not None.

campd.architectures.layers.utils.uniq(arr)

Removes duplicates from arr while preserving order.

campd.architectures.layers.utils.default(val, d)

Returns val if it is not None, otherwise the default d.

campd.architectures.layers.utils.max_neg_value(t)

Returns the most negative representable value for t's dtype (useful for masking attention logits).

campd.architectures.layers.utils.init_(tensor)

Initializes tensor in place.
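These small helpers follow a pattern common in diffusion codebases; pure-Python sketches of the likely behavior of a few of them (the exact campd implementations may differ):

```python
def exists(val):
    """True when the value is not None."""
    return val is not None

def default(val, d):
    """Return val when it exists, otherwise the fallback d."""
    return val if exists(val) else d

def uniq(arr):
    """Deduplicate while preserving first-seen order."""
    return list(dict.fromkeys(arr))

default(None, 3)        # -> 3
uniq([1, 2, 1, 3, 2])   # -> [1, 2, 3]
```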