# Usage Guide

This guide shows how to use `ContextualConv1d` and `ContextualConv2d` in your PyTorch models.
## 🔧 1D Example (no context)

```python
from contextual_conv import ContextualConv1d
import torch

layer = ContextualConv1d(in_channels=4, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(2, 4, 32)  # (batch, channels, length)
out = layer(x)  # shape: (2, 8, 32)
```
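The `# shape: (2, 8, 32)` comment follows from the standard `torch.nn.Conv1d` output-length formula, assuming `ContextualConv1d` uses the same convolution arithmetic (stride 1 and dilation 1 by default). With `kernel_size=3` and `padding=1`, the length is preserved. The helper below is hypothetical and not part of `contextual_conv`:

```python
def conv1d_out_len(length, kernel_size, padding=0, stride=1, dilation=1):
    """Output length per the torch.nn.Conv1d formula (assumed to apply here)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Sequence length from the example above: 32 in, 32 out.
print(conv1d_out_len(32, kernel_size=3, padding=1))  # 32
# Without padding, the output is shorter by kernel_size - 1 positions.
print(conv1d_out_len(32, kernel_size=3))  # 30
```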
## 🧠 1D Example with context (bias only)

```python
layer = ContextualConv1d(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    padding=1,
    context_dim=10,
    use_bias=True
)
c = torch.randn(2, 10)  # (batch, context_dim)
out = layer(x, c)  # x from the first example; shape: (2, 8, 32)
```
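Conceptually, the bias-only path maps `c` of shape `(batch, context_dim)` to one offset per output channel and adds it at every position. This NumPy sketch assumes a single linear projection (the case where `h_dim` is omitted); the names `W` and `b` are illustrative, not the library's parameter names, and `conv_out` is a random stand-in for the convolution output:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, context_dim, out_channels, length = 2, 10, 8, 32

conv_out = rng.standard_normal((batch, out_channels, length))  # stands in for conv(x)
c = rng.standard_normal((batch, context_dim))

# Hypothetical linear projection: context_dim -> out_channels.
W = rng.standard_normal((context_dim, out_channels))
b = np.zeros(out_channels)
beta = c @ W + b  # (batch, out_channels): one offset per sample and channel

# A trailing singleton axis lets beta broadcast over the length dimension.
out = conv_out + beta[:, :, None]
print(out.shape)  # (2, 8, 32)
```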
## 🧠 1D Example with context (scale only)

```python
layer = ContextualConv1d(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    padding=1,
    context_dim=10,
    use_scale=True
)
c = torch.randn(2, 10)  # (batch, context_dim)
out = layer(x, c)  # shape: (2, 8, 32)
```
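With scale only, each output channel is multiplied by a context-dependent factor. One way to read this: if the underlying convolution adds no bias of its own (an assumption needed for the identity to hold exactly), scaling output channel `o` is the same as scaling that channel's kernel for that sample. The NumPy sketch checks this with a hand-written cross-correlation; `conv1d` is a hypothetical helper, not part of `contextual_conv`:

```python
import numpy as np


def conv1d(x, w, padding=1):
    """Naive cross-correlation like nn.Conv1d without bias.
    x: (in_ch, length), w: (out_ch, in_ch, k) -> (out_ch, length_out)."""
    x = np.pad(x, ((0, 0), (padding, padding)))  # zero-pad the length axis
    out_ch, _, k = w.shape
    length_out = x.shape[1] - k + 1
    y = np.empty((out_ch, length_out))
    for t in range(length_out):
        # Sum over input channels and kernel taps for every output channel.
        y[:, t] = np.einsum("oik,ik->o", w, x[:, t:t + k])
    return y


rng = np.random.default_rng(0)
x_sample = rng.standard_normal((4, 32))  # one sample, 4 input channels
w = rng.standard_normal((8, 4, 3))       # 8 output channels, kernel size 3
gamma = rng.standard_normal(8)           # per-channel scale for this sample

scaled_output = gamma[:, None] * conv1d(x_sample, w)       # gamma(c) * conv(x)
scaled_kernel = conv1d(x_sample, gamma[:, None, None] * w)  # conv with rescaled kernels
print(np.allclose(scaled_output, scaled_kernel))  # True
```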
## 🔁 1D Example with FiLM (scale + bias)

```python
layer = ContextualConv1d(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    padding=1,
    context_dim=10,
    use_scale=True,
    use_bias=True
)
c = torch.randn(2, 10)
out = layer(x, c)  # y = γ(c) * conv(x) + β(c)
```
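The FiLM formula in the comment above can be written out as a small function. This NumPy sketch takes already-computed `gamma` and `beta` of shape `(batch, out_channels)` (how the layer produces them is not shown); passing `None` for either mirrors leaving `use_scale` or `use_bias` off. The name `film_1d` is illustrative:

```python
import numpy as np


def film_1d(conv_out, gamma=None, beta=None):
    """Apply y = gamma * conv_out + beta.
    conv_out: (batch, channels, length); gamma, beta: (batch, channels) or None."""
    y = conv_out
    if gamma is not None:  # analogous to use_scale=True
        y = gamma[:, :, None] * y
    if beta is not None:  # analogous to use_bias=True
        y = y + beta[:, :, None]
    return y


rng = np.random.default_rng(0)
conv_out = rng.standard_normal((2, 8, 32))
gamma = rng.standard_normal((2, 8))
beta = rng.standard_normal((2, 8))

y = film_1d(conv_out, gamma, beta)
print(y.shape)  # (2, 8, 32)

# gamma = 1 and beta = 0 leave the convolution output unchanged.
identity = film_1d(conv_out, np.ones((2, 8)), np.zeros((2, 8)))
print(np.allclose(identity, conv_out))  # True
```

Scale is applied before bias, so β(c) is not multiplied by γ(c).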
## 🧠 1D Example with an MLP for context

```python
layer = ContextualConv1d(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    padding=1,
    context_dim=10,
    h_dim=16,
    use_bias=True
)
out = layer(x, c)  # shape: (2, 8, 32)
```
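With `h_dim=16`, the context presumably passes through one hidden layer before producing the per-channel bias (`context_dim → 16 → ReLU → out_channels`, by analogy with the multi-layer path in the next section). This NumPy sketch uses random weights with hypothetical names `W1`, `b1`, `W2`, `b2`; note that the output layer has no activation, so the bias can take negative values:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, context_dim, h_dim, out_channels = 2, 10, 16, 8

c = rng.standard_normal((batch, context_dim))

# Hypothetical weights for context_dim -> h_dim -> out_channels.
W1, b1 = rng.standard_normal((context_dim, h_dim)), np.zeros(h_dim)
W2, b2 = rng.standard_normal((h_dim, out_channels)), np.zeros(out_channels)

hidden = np.maximum(c @ W1 + b1, 0.0)  # ReLU on the hidden layer only
beta = hidden @ W2 + b2                # linear output: one bias per channel
print(hidden.shape, beta.shape)  # (2, 16) (2, 8)
```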
## 🧠 Multi-layer MLP for context

You can pass a list to `h_dim` to define multiple hidden layers:

```python
layer = ContextualConv1d(
    in_channels=4,
    out_channels=8,
    kernel_size=3,
    padding=1,
    context_dim=10,
    h_dim=[32, 64, 16],
    use_scale=True
)
```
This creates the following context path:

`context_dim → 32 → ReLU → 64 → ReLU → 16 → ReLU → out_channels`
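The mapping from `h_dim` to the context path can be sketched as a small function that accepts `None`, an `int`, or a list, the three forms used on this page. `context_layer_sizes` and `describe_path` are hypothetical helpers for illustration; the library's internal construction may differ:

```python
def context_layer_sizes(context_dim, h_dim, out_dim):
    """Return layer widths [in, hidden..., out] for the context path."""
    if h_dim is None:
        hidden = []  # single linear layer
    elif isinstance(h_dim, int):
        hidden = [h_dim]  # one hidden layer
    else:
        hidden = list(h_dim)  # one entry per hidden layer
    return [context_dim, *hidden, out_dim]


def describe_path(sizes):
    """Render sizes as 'in → h1 → ReLU → ... → out', with ReLU after each hidden layer."""
    parts = [str(sizes[0])]
    for width in sizes[1:-1]:
        parts += [str(width), "ReLU"]
    parts.append(str(sizes[-1]))
    return " → ".join(parts)


print(describe_path(context_layer_sizes(10, [32, 64, 16], 8)))
# 10 → 32 → ReLU → 64 → ReLU → 16 → ReLU → 8
print(describe_path(context_layer_sizes(10, None, 8)))
# 10 → 8
```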
## 🖼️ 2D Example with FiLM-style context

```python
from contextual_conv import ContextualConv2d

conv = ContextualConv2d(
    in_channels=3,
    out_channels=6,
    kernel_size=3,
    padding=1,
    context_dim=10,
    h_dim=32,
    use_scale=True,
    use_bias=True
)
x = torch.randn(2, 3, 16, 16)  # (batch, channels, height, width)
c = torch.randn(2, 10)
out = conv(x, c)  # shape: (2, 6, 16, 16)
```
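In the 2D case, the same per-channel values are broadcast over both height and width. This NumPy sketch shows only the broadcast; `conv_out`, `gamma`, and `beta` are random stand-ins for the convolution output and the context outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
conv_out = rng.standard_normal((2, 6, 16, 16))  # stands in for conv(x): (B, C, H, W)
gamma = rng.standard_normal((2, 6))             # per-sample, per-channel scale
beta = rng.standard_normal((2, 6))              # per-sample, per-channel bias

# Append one singleton axis per spatial dimension: (B, C) -> (B, C, 1, 1).
spatial = (1,) * (conv_out.ndim - 2)
out = gamma.reshape(gamma.shape + spatial) * conv_out + beta.reshape(beta.shape + spatial)
print(out.shape)  # (2, 6, 16, 16)
```

Computing the singleton axes from `conv_out.ndim` lets the same code serve 1D and 2D outputs.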
## ✅ Notes

- If `context_dim` is set, the context vector `c` is passed through a linear layer or an MLP. The result is used as a per-output-channel scale and/or bias, broadcast over all spatial locations.
- You can enable `use_scale`, `use_bias`, or both.
- To use a single linear layer instead of an MLP, omit `h_dim`.