Connections API

Connections are the "which inputs feed each neuron" half of a LogicDense layer. A connections module takes an input of shape (batch, in_dim) and returns a gathered tensor of shape (batch, lut_rank, out_dim) that the parametrization consumes.

Two variants:

  • fixed — buffered routing indices; a single gather per forward. No trainable parameters.
  • learnable — softmax over a per-slot candidate pool (or all inputs with num_candidates=-1); straight-through argmax on the forward pass.

Usage via LogicDense:

LogicDense(
    in_dim=6272, out_dim=64000,
    parametrization="light", lut_rank=4,
    connections="learnable", num_candidates=8, init_method="random-unique",
)
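
Used standalone, a connections module follows the same shape contract. A minimal sketch (the import path and the dimensions are illustrative assumptions):

import torch
from bitlogic.connections import setup_connections  # import path assumed

conn = setup_connections("fixed", in_dim=128, out_dim=256, lut_rank=4)
x = torch.rand(32, 128)   # (batch, in_dim)
out = conn(x)             # (batch, lut_rank, out_dim)
assert out.shape == (32, 4, 256)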

Factory

setup_connections

setup_connections(kind: str, in_dim: int, out_dim: int, lut_rank: int, device: device | str | None = None, **kwargs: Any) -> Connections

Build a dense connection module of the given kind.

Parameters:

Name Type Description Default
kind str

One of "fixed", "random" (alias of "fixed"), "learnable".

required
in_dim int

Number of input features.

required
out_dim int

Number of output neurons.

required
lut_rank int

Inputs per neuron.

required
device device | str | None

Optional target device for buffers and parameters.

None
**kwargs Any

Extra keyword arguments forwarded to the selected class constructor (e.g. init_method, num_candidates, temperature, num_groups, group_bias).

{}

Returns:

Type Description
Connections

A Connections module ready to use inside LogicDense.

Raises:

Type Description
ValueError

If kind is not one of the recognized names.
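
Example usage (a sketch; the import path, the dimensions, and the kwargs values are illustrative assumptions):

import torch
from bitlogic.connections import setup_connections  # import path assumed

conn = setup_connections(
    "learnable", in_dim=784, out_dim=1024, lut_rank=4,
    num_candidates=8, temperature=0.01, init_method="random-unique",
)
out = conn(torch.rand(16, 784))   # (16, 4, 1024)

setup_connections("typo", in_dim=784, out_dim=1024, lut_rank=4)  # raises ValueError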

Base class

Connections

Connections(lut_rank: int = 2, device: device | str | None = None, init_method: str = 'random-unique', **_: Any)

Bases: Module, ABC

Abstract base for input-routing strategies.

A connections module maps an input tensor of shape (batch, in_dim) to a gathered tensor of shape (batch, lut_rank, out_dim) that a LUTParametrization can consume.

Parameters:

Name Type Description Default
lut_rank int

Number of inputs per output neuron.

2
device device | str | None

Optional target device for buffers and parameters.

None
init_method str

Index-initialization strategy — one of "random-unique" (default), "random", "group-biased". Subclasses decide which are meaningful.

'random-unique'
Source code in bitlogic/connections/base.py
def __init__(
    self,
    lut_rank: int = 2,
    device: torch.device | str | None = None,
    init_method: str = "random-unique",
    **_: Any,
):
    super().__init__()
    self.lut_rank = lut_rank
    self.device = device
    self.init_method = init_method

forward abstractmethod

forward(x: Tensor) -> Tensor

Gather lut_rank inputs per output neuron.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch, in_dim).

required

Returns:

Type Description
Tensor

Tensor of shape (batch, lut_rank, out_dim).

Source code in bitlogic/connections/base.py
@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Gather ``lut_rank`` inputs per output neuron.

    Args:
        x: Input tensor of shape ``(batch, in_dim)``.

    Returns:
        Tensor of shape ``(batch, lut_rank, out_dim)``.
    """

update_temperature

update_temperature(temperature: float) -> None

Scheduler hook; learnable variants override.

Source code in bitlogic/connections/base.py
def update_temperature(self, temperature: float) -> None:
    """Scheduler hook; learnable variants override."""

Concrete strategies

FixedDenseConnections

FixedDenseConnections(in_dim: int, out_dim: int, lut_rank: int = 2, device: device | str | None = None, init_method: str = 'random-unique', num_groups: int | None = None, group_bias: float | None = None, **kwargs: Any)

Bases: Connections

Non-trainable dense routing — buffered indices, one gather per forward.

Parameters:

Name Type Description Default
in_dim int

Number of input features.

required
out_dim int

Number of output neurons.

required
lut_rank int

Inputs per neuron.

2
device device | str | None

Optional target device for the index buffer.

None
init_method str

One of "random-unique" (default), "random", "group-biased".

'random-unique'
num_groups int | None

Required for "group-biased" — partitions input and output dims into num_groups chunks.

None
group_bias float | None

Required for "group-biased" — probability of drawing an input slot from the matching group rather than anywhere.

None
**kwargs Any

Forwarded to Connections.

{}
Source code in bitlogic/connections/fixed.py
def __init__(
    self,
    in_dim: int,
    out_dim: int,
    lut_rank: int = 2,
    device: torch.device | str | None = None,
    init_method: str = "random-unique",
    num_groups: int | None = None,
    group_bias: float | None = None,
    **kwargs: Any,
):
    super().__init__(lut_rank=lut_rank, device=device, init_method=init_method, **kwargs)
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.num_groups = num_groups
    self.group_bias = group_bias
    self.register_buffer("indices", self._build_indices())

LearnableDenseConnections

LearnableDenseConnections(in_dim: int, out_dim: int, lut_rank: int = 2, temperature: float = 0.001, num_candidates: int = -1, forward_sampling: str = 'soft', device: device | str | None = None, init_method: str = 'random-unique', num_groups: int | None = None, group_bias: float | None = None, **kwargs: Any)

Bases: Connections

Learnable dense routing via softmax over candidate inputs.

Hard argmax on the forward pass (straight-through estimator), softmax-weighted gradients on the backward pass. When num_candidates == -1 every input is a candidate for every (slot, neuron) pair and the forward reduces to a matmul. A positive num_candidates restricts the pool to a fixed subset per slot, initialized according to init_method.

Parameters:

Name Type Description Default
in_dim int

Number of input features.

required
out_dim int

Number of output neurons.

required
lut_rank int

Inputs per neuron.

2
temperature float

Softmax temperature for the backward weighting.

0.001
num_candidates int

Candidate inputs per slot. -1 means use every input (fully-connected routing, matmul fast-path).

-1
forward_sampling str

"soft" / "hard" / "gumbel_soft" / "gumbel_hard". Gumbel variants add Gumbel noise before the argmax + softmax.

'soft'
device device | str | None

Optional target device for buffers / parameters.

None
init_method str

"random-unique" (default), "random", or "group-biased" (only consulted when num_candidates != -1).

'random-unique'
num_groups int | None

Required for "group-biased".

None
group_bias float | None

Required for "group-biased"; probability of keeping an in-group candidate, in [0, 1].

None
**kwargs Any

Forwarded to Connections.

{}

Raises:

Type Description
ValueError

If forward_sampling or init_method is unknown.

Source code in bitlogic/connections/learnable.py
def __init__(
    self,
    in_dim: int,
    out_dim: int,
    lut_rank: int = 2,
    temperature: float = 0.001,
    num_candidates: int = -1,
    forward_sampling: str = "soft",
    device: torch.device | str | None = None,
    init_method: str = "random-unique",
    num_groups: int | None = None,
    group_bias: float | None = None,
    **kwargs: Any,
):
    if forward_sampling not in self._VALID_SAMPLING:
        raise ValueError(
            f"forward_sampling must be one of {self._VALID_SAMPLING}, got {forward_sampling!r}"
        )
    super().__init__(lut_rank=lut_rank, device=device, init_method=init_method, **kwargs)
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.temperature = float(temperature)
    self.forward_sampling = forward_sampling
    self.num_groups = num_groups
    self.group_bias = group_bias

    self._dense_identity = num_candidates == -1
    if self._dense_identity:
        num_candidates = in_dim
        indices = (
            torch.arange(in_dim, device=device)
            .view(in_dim, 1, 1)
            .expand(in_dim, lut_rank, out_dim)
            .contiguous()
        )
    else:
        assert num_candidates > 0, "num_candidates must be > 0 or -1"
        indices = self._build_indices(num_candidates)
    self.num_candidates = num_candidates
    self.register_buffer("indices", indices.to(torch.int64))

    self.weights = nn.Parameter(
        torch.empty(num_candidates, lut_rank, out_dim, dtype=torch.float32)
    )
    nn.init.xavier_uniform_(self.weights)
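
The straight-through selection described above can be sketched as follows. This is an illustration of the technique, not the library's exact forward; it assumes the num_candidates == -1 fast path, where the candidate logits cover every input:

import torch
import torch.nn.functional as F

def ste_select(x: torch.Tensor, weights: torch.Tensor, temperature: float) -> torch.Tensor:
    # x: (batch, in_dim); weights: (in_dim, lut_rank, out_dim) candidate logits
    soft = F.softmax(weights / temperature, dim=0)              # gradient path
    hard = F.one_hot(weights.argmax(dim=0), weights.shape[0])   # (lut_rank, out_dim, in_dim)
    hard = hard.permute(2, 0, 1).to(soft.dtype)                 # (in_dim, lut_rank, out_dim)
    probs = hard + soft - soft.detach()                         # argmax value, softmax gradient
    # (batch, in_dim) @ (in_dim, lut_rank * out_dim) -> (batch, lut_rank, out_dim)
    return (x @ probs.reshape(x.shape[1], -1)).view(x.shape[0], *weights.shape[1:])

The hard + soft - soft.detach() trick returns the one-hot argmax on the forward pass while letting gradients flow through the temperature-scaled softmax; the Gumbel forward_sampling variants would add Gumbel noise to the logits before the argmax and softmax.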