Add New Network Architecture
- Create a network file in mip/networks/
- Inherit from BaseNetwork
- Implement forward(x, s, t, condition) returning (action, scalar)
- Register it in network_utils.get_network()
- Create a config file in examples/configs/network/
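The steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's code: the local BaseNetwork class is a stand-in for the real one in mip/networks/, and TinyMLPNetwork, its constructor arguments, and the way s and t are concatenated onto the input are all hypothetical.

```python
import torch
from torch import nn


class BaseNetwork(nn.Module):
    """Local stand-in for the real BaseNetwork in mip/networks/ (assumption)."""

    def forward(self, x, s, t, condition):
        raise NotImplementedError


class TinyMLPNetwork(BaseNetwork):
    """Hypothetical example network; names and layer sizes are illustrative."""

    def __init__(self, action_dim: int, horizon: int, cond_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.action_dim = action_dim
        self.horizon = horizon
        # Input: flattened action trajectory, the two times (s, t), and the condition.
        in_dim = horizon * action_dim + 2 + cond_dim
        self.trunk = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )
        self.action_head = nn.Linear(hidden_dim, horizon * action_dim)
        self.scalar_head = nn.Linear(hidden_dim, 1)

    def forward(self, x, s, t, condition):
        batch = x.shape[0]
        features = torch.cat(
            [x.reshape(batch, -1), s[:, None], t[:, None], condition], dim=-1
        )
        hidden = self.trunk(features)
        # Primary output: same shape as the input trajectory (B, T, action_dim).
        action = self.action_head(hidden).reshape(batch, self.horizon, self.action_dim)
        # Secondary output: one scalar per batch element (B,).
        scalar = self.scalar_head(hidden).squeeze(-1)
        return action, scalar


net = TinyMLPNetwork(action_dim=2, horizon=4, cond_dim=8)
action, scalar = net(
    torch.randn(3, 4, 2), torch.rand(3), torch.rand(3), torch.randn(3, 8)
)
```

In the real project the class would still need to be registered in network_utils.get_network() and given a config file under examples/configs/network/. Those steps depend on repository details not shown here.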
Base Network Interface
All networks in mip/networks/ inherit from BaseNetwork and implement:
```python
import torch
from torch import nn


class BaseNetwork(nn.Module):
    def forward(
        self,
        x: torch.Tensor,          # Action trajectory (B, T, action_dim)
        s: torch.Tensor,          # Source time (B,)
        t: torch.Tensor,          # Target time (B,)
        condition: torch.Tensor,  # Encoded observations (B, cond_dim)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            action: Predicted action (B, T, action_dim)
            scalar: Auxiliary scalar prediction (B,)
        """
```
Dual Output Design:
- Primary output: Action prediction
- Secondary output: Scalar value (can be used for value estimation, uncertainty, etc.)
- Enhances learning signal without additional architectural complexity
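One way to make the dual-output contract concrete is a small shape check that can be run against any candidate network. The helper check_dual_output and the ZeroNetwork placeholder below are hypothetical and not part of mip. They only exercise the (action, scalar) shapes stated in the interface.

```python
import torch
from torch import nn


def check_dual_output(net, batch, horizon, action_dim, cond_dim):
    """Call net with dummy inputs and verify the (action, scalar) output shapes."""
    x = torch.randn(batch, horizon, action_dim)  # action trajectory (B, T, action_dim)
    s = torch.zeros(batch)                       # source time (B,)
    t = torch.ones(batch)                        # target time (B,)
    condition = torch.randn(batch, cond_dim)     # encoded observations (B, cond_dim)
    with torch.no_grad():
        action, scalar = net(x, s, t, condition)
    assert action.shape == x.shape, f"action shape {tuple(action.shape)}"
    assert scalar.shape == (batch,), f"scalar shape {tuple(scalar.shape)}"
    return action, scalar


class ZeroNetwork(nn.Module):
    """Placeholder network that satisfies the interface with all-zero outputs."""

    def forward(self, x, s, t, condition):
        return torch.zeros_like(x), torch.zeros(x.shape[0])


action, scalar = check_dual_output(
    ZeroNetwork(), batch=5, horizon=16, action_dim=7, cond_dim=32
)
```

A check like this catches the most common interface mistake, returning a scalar of shape (B, 1) instead of (B,), before the network is used in training.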
Available Networks
MLP (mip/networks/mlp.py):
- Multi-layer perceptron with timestep embeddings
- Simple, fast, works well for state observations
- Configuration: emb_dim, num_layers, dropout
Vanilla MLP (mip/networks/mlp.py):
- Even simpler MLP without special embeddings
- Baseline for ablation studies
ChiUNet (mip/networks/chiunet.py):
- U-Net architecture from Diffusion Policy (Chi et al.)
- Temporal convolutions with skip connections
- Excellent for action trajectory modeling
JannerUNet (mip/networks/jannerunet.py):
- U-Net from Diffuser (Janner et al.)
- Alternative U-Net design
ChiTransformer (mip/networks/chitfm.py):
- Transformer architecture from Diffusion Policy
- Self-attention over action sequence
- Better for long horizons
SudeepDiT (mip/networks/sudeepdit.py):
- Diffusion Transformer (DiT) architecture
- State-of-the-art generative modeling
- Scalable to large models
RNN (mip/networks/rnn.py):
- LSTM/GRU-based recurrent networks
- Sequential processing of action trajectories
- Configuration: rnn_type (LSTM or GRU)
Network Selection
Networks are instantiated via network_utils.get_network():
```python
from mip.network_utils import get_network

net = get_network(
    network_config=config.network,
    task_config=config.task,
)
```
The function automatically handles:
- Input/output dimension inference from task config
- Condition dimension (cond_dim) calculation based on the encoder
- Network-specific parameter initialization
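The sketch below illustrates the kind of dispatch and dimension inference described above. It is not the actual get_network implementation: the registry, the register_network decorator, the config field names (name, encoder_out_dim, kwargs, action_dim, horizon, obs_dim), and the ToyMLP placeholder are all assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional

# Hypothetical registry mapping a config name to a network constructor.
_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_network(name: str):
    """Decorator that records a network class under a config name."""
    def decorator(cls):
        _REGISTRY[name] = cls
        return cls
    return decorator


@dataclass
class TaskConfig:  # hypothetical fields
    action_dim: int
    horizon: int
    obs_dim: int


@dataclass
class NetworkConfig:  # hypothetical fields
    name: str
    encoder_out_dim: Optional[int] = None  # None: condition on raw observations
    kwargs: Dict[str, Any] = field(default_factory=dict)


@register_network("toy_mlp")
class ToyMLP:
    """Placeholder that only stores the dimensions it was built with."""

    def __init__(self, action_dim, horizon, cond_dim, emb_dim=128):
        self.action_dim = action_dim
        self.horizon = horizon
        self.cond_dim = cond_dim
        self.emb_dim = emb_dim


def get_network_sketch(network_config: NetworkConfig, task_config: TaskConfig):
    """Look up the network class and infer its dimensions from the task."""
    if network_config.name not in _REGISTRY:
        raise ValueError(f"Unknown network: {network_config.name!r}")
    # The condition dimension follows the encoder output if one is configured.
    cond_dim = (
        network_config.encoder_out_dim
        if network_config.encoder_out_dim is not None
        else task_config.obs_dim
    )
    return _REGISTRY[network_config.name](
        action_dim=task_config.action_dim,
        horizon=task_config.horizon,
        cond_dim=cond_dim,
        **network_config.kwargs,
    )


net = get_network_sketch(
    NetworkConfig(name="toy_mlp", kwargs={"emb_dim": 64}),
    TaskConfig(action_dim=7, horizon=16, obs_dim=20),
)
```

Keeping dimension inference inside the factory means network config files only need to carry architecture-specific parameters, while task-dependent sizes come from the task config.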