Parametrizations API

Parametrizations are the "what is the LUT" half of a LogicDense layer. Each one owns (1) the shape of the per-neuron weight tensor, (2) a forward basis that maps inputs and weights to LUT outputs, and (3) a get_lut method that discretizes weights into {0, 1} truth tables for hardware export.

Pick a parametrization either by string through LogicDense:

from bitlogic.layers import LogicDense

layer = LogicDense(
    in_dim=784, out_dim=4000,
    parametrization="warp", lut_rank=4, temperature=1.0,
)

...or instantiate the concrete class directly:

from bitlogic.parametrizations import WarpLUT
param = WarpLUT(lut_rank=4, temperature=1.0)

Factory

setup_parametrization

setup_parametrization(name: str, lut_rank: int, **kwargs: Any) -> LUTParametrization

Build a parametrization by name (lowercase string).

Extra kwargs are forwarded to the constructor, e.g.:

setup_parametrization("light", lut_rank=4, forward_sampling="hard",
                      weight_init="random")

Raises KeyError for unknown names.

list_parametrizations

list_parametrizations() -> list[str]

Return the names of all registered parametrizations, in the form accepted by setup_parametrization.
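Together, the two factory helpers support name-driven configuration. A minimal sketch, assuming both are importable from bitlogic.parametrizations (the exact factory module is not shown on this page):

from bitlogic.parametrizations import list_parametrizations, setup_parametrization

for name in list_parametrizations():
    # Rank 2 is valid for every parametrization (DiffLogicLUT requires it).
    param = setup_parametrization(name, lut_rank=2)
    print(name, type(param).__name__)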

Base class

LUTParametrization

LUTParametrization(lut_rank: int, forward_sampling: str = 'soft', temperature: float = 1.0, weight_init: str = 'random', residual_probability: float = 0.951, anchor_init: bool = True)

Bases: Module, ABC

Abstract base class for LUT parametrizations.

A parametrization owns (1) the shape of the per-neuron weight tensor, (2) the training-only forward basis that maps inputs and weights to LUT outputs (_forward_train), and (3) a get_lut method that discretizes weights into {0, 1} truth tables for hardware export and for the shared eager-eval lookup. The public forward is concrete on the base and dispatches between the two.

LUT indexing convention: every parametrization returns get_lut(weight) in MSB ordering — LUT entry k corresponds to the input pattern where input j occupies bit rank-1-j of k (i.e. input 0 is the MSB of the address). This matches the Kronecker indicator basis used by forward in eval mode and by LightLUT during training.
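As a concrete illustration of the convention, a tiny address helper (illustrative only, not part of the API):

def lut_address(bits):
    # bits = (x_0, ..., x_{rank-1}); input 0 ends up in the MSB of k.
    k = 0
    for b in bits:
        k = (k << 1) | b
    return k

assert lut_address((1, 0, 1, 1)) == 0b1011  # input 0 contributes the 8s bit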

Class attributes

LUT_INFERENCE_SUPPORTED: Whether get_lut returns a well-defined binary truth table (output at input pattern k matches LUT entry k). False for Walsh–Hadamard / linear bases where get_lut is a thresholded coefficient, not a round-tripping truth table.

Parameters:

Name Type Description Default
lut_rank int

Number of inputs per neuron. Must be one of {2, 4, 6} — basis code is hand-unrolled for these ranks.

required
forward_sampling str

How logits are sampled during training. Supported values depend on the concrete parametrization (commonly "soft", "hard", "gumbel_soft", "gumbel_hard").

'soft'
temperature float

Sigmoid / softmax temperature — lower values give sharper, more discrete outputs.

1.0
weight_init str

Either "random" (default — standard-Gaussian init) or "residual" (identity-function initialization biased toward pass-through).

'random'
residual_probability float

Probability that a residual-initialized LUT entry evaluates to its identity target. Sets the magnitude of the initial logits.

0.951
anchor_init bool

If True, the identity-function weight is stashed in a non-trainable buffer and the learnable parameter starts at zero (so weight decay pulls toward identity).

True

Raises:

Type Description
ValueError

If lut_rank is not in {2, 4, 6}, or if weight_init is not one of "residual" / "random".
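The residual logit magnitude follows from the inverse sigmoid; a sketch of the arithmetic (the library's exact init code is not reproduced here):

import math

p = 0.951                          # residual_probability
magnitude = math.log(p / (1 - p))  # ~= 2.97, since sigmoid(2.97) ~= 0.951
# Entries whose identity target is 1 start near +magnitude, entries
# targeting 0 near -magnitude.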

Source code in bitlogic/parametrizations/base.py
def __init__(
    self,
    lut_rank: int,
    forward_sampling: str = "soft",
    temperature: float = 1.0,
    weight_init: str = "random",
    residual_probability: float = 0.951,
    anchor_init: bool = True,
):
    super().__init__()
    if lut_rank not in VALID_RANKS:
        raise ValueError(f"lut_rank must be one of {VALID_RANKS}, got {lut_rank}")
    if weight_init not in ("residual", "random"):
        raise ValueError(f"weight_init must be 'residual' or 'random', got {weight_init!r}")

    self.lut_rank: int = int(lut_rank)
    self.lut_entries: int = 1 << self.lut_rank
    self.forward_sampling: str = forward_sampling
    self.temperature: float = float(temperature)
    self.weight_init: str = weight_init
    self.residual_probability: float = float(residual_probability)
    self.anchor_init: bool = bool(anchor_init)

    # Set by ``init_weights`` when residual+anchor_init is active.
    self.register_buffer("residual_anchor", None, persistent=True)

init_weights abstractmethod

init_weights(num_neurons: int, device: device | str | None) -> Tensor

Return the initial weight tensor for a layer of num_neurons.

Parameters:

Name Type Description Default
num_neurons int

Number of output neurons in the owning layer.

required
device device | str | None

Target device for the returned tensor.

required

Returns:

Type Description
Tensor

Tensor of shape (num_neurons, num_basis) where num_basis is parametrization-specific (2**lut_rank for Light / Warp / DWN, 16 for DiffLogic, degree-specific for PolyLUT, etc.).

Source code in bitlogic/parametrizations/base.py
@abstractmethod
def init_weights(self, num_neurons: int, device: torch.device | str | None) -> torch.Tensor:
    """Return the initial weight tensor for a layer of ``num_neurons``.

    Args:
        num_neurons: Number of output neurons in the owning layer.
        device: Target device for the returned tensor.

    Returns:
        Tensor of shape ``(num_neurons, num_basis)`` where ``num_basis``
        is parametrization-specific (``2**lut_rank`` for Light / Warp /
        DWN, ``16`` for DiffLogic, degree-specific for PolyLUT, etc.).
    """

forward

forward(x: Tensor, weight: Tensor, training: bool, contraction: str = 'n,bn->bn') -> Tensor

Evaluate LUT outputs on gathered inputs.

Dispatches to _forward_train when training is True. Otherwise contracts the discrete truth table from get_lut with the Kronecker indicator basis — for binary x, this is the same as indexing get_lut(weight)[neuron, addr(x)]. Subclasses override _forward_train, not forward.

Parameters:

Name Type Description Default
x Tensor

Gathered input tensor of shape (batch, lut_rank, num_neurons).

required
weight Tensor

Per-neuron weight tensor returned by init_weights.

required
training bool

Whether to use the stochastic / soft forward (True) or the discrete eval path (False).

required
contraction str

Einsum pattern for the per-neuron reduction (typically "n,bn->bn").

'n,bn->bn'

Returns:

Type Description
Tensor

Tensor of shape (batch, num_neurons).

Source code in bitlogic/parametrizations/base.py
def forward(
    self,
    x: torch.Tensor,
    weight: torch.Tensor,
    training: bool,
    contraction: str = "n,bn->bn",
) -> torch.Tensor:
    """Evaluate LUT outputs on gathered inputs.

    Dispatches to :meth:`_forward_train` when ``training`` is ``True``.
    Otherwise contracts the discrete truth table from :meth:`get_lut` with
    the Kronecker indicator basis — for binary ``x``, this is the same as
    indexing ``get_lut(weight)[neuron, addr(x)]``. Subclasses override
    :meth:`_forward_train`, not :meth:`forward`.

    Args:
        x: Gathered input tensor of shape
            ``(batch, lut_rank, num_neurons)``.
        weight: Per-neuron weight tensor returned by
            :meth:`init_weights`.
        training: Whether to use the stochastic / soft forward
            (``True``) or the discrete eval path (``False``).
        contraction: Einsum pattern for the per-neuron reduction
            (typically ``"n,bn->bn"``).

    Returns:
        Tensor of shape ``(batch, num_neurons)``.
    """
    if training:
        return self._forward_train(x, weight, contraction)
    luts = self.get_lut(weight).to(x.dtype)
    return weighted_indicator_basis_sum(x, luts, contraction, self.lut_rank)
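weighted_indicator_basis_sum is not reproduced on this page, but the rank-2 case of the contraction it performs can be sketched directly (shapes per the Args above; illustrative, not the library code):

import torch

def indicator_eval_rank2(x: torch.Tensor, luts: torch.Tensor) -> torch.Tensor:
    # x: (batch, 2, num_neurons) in [0, 1]; luts: (num_neurons, 4).
    x0, x1 = x[:, 0], x[:, 1]
    # Kronecker product of per-input indicators (1 - x_j, x_j), MSB first,
    # so entry k lines up with the addressing convention above.
    basis = torch.stack(
        [(1 - x0) * (1 - x1), (1 - x0) * x1, x0 * (1 - x1), x0 * x1], dim=1
    )  # (batch, 4, num_neurons)
    return torch.einsum("nk,bkn->bn", luts, basis)

For binary x this equals luts[n, 2*x0 + x1] exactly, which is the claimed equivalence with direct indexing.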

effective_weight

effective_weight(weight: Tensor) -> Tensor

Add the residual anchor (if any) to the trainable weight.

Source code in bitlogic/parametrizations/base.py
def effective_weight(self, weight: torch.Tensor) -> torch.Tensor:
    """Add the residual anchor (if any) to the trainable weight."""
    if self.residual_anchor is None:
        return weight
    return weight + self.residual_anchor
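A small numeric sketch of why this makes weight decay pull toward identity (stand-in values; the real anchor is filled in by init_weights):

import torch

anchor = torch.full((8, 16), 2.97)                  # stand-in identity logits, rank 4
weight = torch.zeros_like(anchor).requires_grad_()  # learnable part starts at zero
effective = weight + anchor                         # == anchor at init
# L2 weight decay penalizes ||weight||^2 == ||effective - anchor||^2,
# i.e. the deviation from the identity function, not the raw logits.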

update_temperature

update_temperature(temperature: float) -> None

Scheduler hook: set the current softmax / sigmoid temperature.

Source code in bitlogic/parametrizations/base.py
def update_temperature(self, temperature: float) -> None:
    """Scheduler hook: set the current softmax / sigmoid temperature."""
    self.temperature = float(temperature)
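A typical driver is an annealing loop in the training script (the schedule below is an assumption; only update_temperature is the documented hook):

from bitlogic.parametrizations import WarpLUT

param = WarpLUT(lut_rank=4, temperature=1.0)
for epoch in range(30):
    param.update_temperature(max(0.1, 0.9 ** epoch))  # anneal toward sharper outputs
    ...  # train one epoch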

Concrete parametrizations

LightLUT

LightLUT(lut_rank: int, forward_sampling: str = 'soft', temperature: float = 1.0, weight_init: str = 'random', residual_probability: float = 0.951, anchor_init: bool = True)

Bases: LUTParametrization

LUT parametrization with Kronecker indicator basis.

forward_sampling controls how logits are mapped to LUT values:

* "soft" — plain sigmoid. Legacy ProbabilisticParam.
* "hard" — sigmoid + straight-through (STE) hard round. Legacy HybridParam (approximate).
* "gumbel_soft" / "gumbel_hard" — Gumbel-sigmoid variants.

Source code in bitlogic/parametrizations/light.py
def __init__(
    self,
    lut_rank: int,
    forward_sampling: str = "soft",
    temperature: float = 1.0,
    weight_init: str = "random",
    residual_probability: float = 0.951,
    anchor_init: bool = True,
):
    if forward_sampling not in _VALID_SAMPLING:
        raise ValueError(
            f"forward_sampling must be one of {_VALID_SAMPLING}, got {forward_sampling!r}"
        )
    super().__init__(
        lut_rank,
        forward_sampling=forward_sampling,
        temperature=temperature,
        weight_init=weight_init,
        residual_probability=residual_probability,
        anchor_init=anchor_init,
    )

WarpLUT

WarpLUT(lut_rank: int, forward_sampling: str = 'soft', temperature: float = 1.0, weight_init: str = 'random', residual_probability: float = 0.951, anchor_init: bool = True)

Bases: LUTParametrization

Walsh–Hadamard-basis LUT parametrization.

Inputs are mapped to {-1, +1} and contracted with Walsh basis vectors to obtain logits z. forward_sampling controls how the logits are squashed during training:

* "soft" — plain σ(z / τ) (deterministic; default).
* "hard" — σ(z / τ) with a Bernoulli-sampled straight-through round to {0, 1}. Matches torchlogix-extended's sigmoid(..., hard=True).
* "gumbel_soft" — Gumbel-sigmoid σ((z + logistic_noise)/τ), matching the Gumbel reparameterization of Gerlach et al. (2025), Sec. 3.
* "gumbel_hard" — same as "gumbel_soft" with a hard straight-through round to {0, 1}.

Eval mode (training=False) collapses to the discrete truth-table lookup (get_lut(weight)[neuron, addr(x)]) so eager model.eval() matches PackedLogicNet and the emitted HDL bit-for-bit. Supports lut_rank ∈ {2, 4, 6}. Weight shape: (num_neurons, 2**lut_rank). See LUTParametrization for the inherited constructor arguments.
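A hedged rank-2 sketch of the Walsh contraction (the library's basis is hand-unrolled and its ordering may differ; illustrative only):

import torch

def walsh_logits_rank2(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # x: (batch, 2) in [0, 1]; w: (4,) Walsh coefficients for one neuron.
    s0, s1 = 2 * x[:, 0] - 1, 2 * x[:, 1] - 1     # map inputs to {-1, +1}
    basis = torch.stack([torch.ones_like(s0), s1, s0, s0 * s1], dim=1)
    return basis @ w                              # logits z; training squashes with σ(z / τ)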

Source code in bitlogic/parametrizations/warp.py
def __init__(
    self,
    lut_rank: int,
    forward_sampling: str = "soft",
    temperature: float = 1.0,
    weight_init: str = "random",
    residual_probability: float = 0.951,
    anchor_init: bool = True,
):
    if forward_sampling not in _VALID_SAMPLING:
        raise ValueError(
            f"forward_sampling must be one of {_VALID_SAMPLING}, got {forward_sampling!r}"
        )
    super().__init__(
        lut_rank,
        forward_sampling=forward_sampling,
        temperature=temperature,
        weight_init=weight_init,
        residual_probability=residual_probability,
        anchor_init=anchor_init,
    )

LinearLUT

LinearLUT(lut_rank: int, forward_sampling: str = 'soft', temperature: float = 1.0, weight_init: str = 'random', residual_probability: float = 0.951, anchor_init: bool = True)

Bases: LUTParametrization

Affine + sigmoid parametrization: y_n = σ(W_n @ x_n + b_n).

The simplest parametrization — mostly a sanity-check primitive. Weight shape is (num_neurons, lut_rank + 1), packing the lut_rank affine weights and a bias in the trailing column. See LUTParametrization for constructor arguments.


PolyLUT

PolyLUT(lut_rank: int, degree: int = _DEFAULT_DEGREE, **kwargs: Any)

Bases: LUTParametrization

Multivariate-monomial LUT parametrization up to total degree D.

y = σ(Σ_a w_a · Π_j x_j^{a_j}). Number of monomials is C(lut_rank + D, D) (including the constant term); weight shape is (num_neurons, num_monomials).
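The monomial count is plain binomial arithmetic:

from math import comb

comb(4 + 2, 2)  # lut_rank=4, degree=2 -> 15 monomials (incl. the constant term)
comb(6 + 3, 3)  # lut_rank=6, degree=3 -> 84 monomials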

Parameters:

Name Type Description Default
lut_rank int

Number of inputs per neuron ({2, 4, 6}).

required
degree int

Maximum total degree of the monomial basis.

_DEFAULT_DEGREE
**kwargs Any

Forwarded to LUTParametrization.

{}
Source code in bitlogic/parametrizations/polylut.py
def __init__(
    self,
    lut_rank: int,
    degree: int = _DEFAULT_DEGREE,
    **kwargs: Any,
):
    super().__init__(lut_rank, **kwargs)
    self.degree = int(degree)
    exps = _monomial_exponents(self.lut_rank, self.degree)
    self.num_monomials = len(exps)
    self.register_buffer(
        "exponent_matrix",
        torch.tensor(exps, dtype=torch.float32),
    )
    # Integer twin used by the hot-path gather in :meth:`_monomials` — keeps
    # ``pow_stack`` indices on the right dtype without a per-step cast.
    self.register_buffer(
        "exponent_matrix_long",
        torch.tensor(exps, dtype=torch.long),
    )
    # Per-input max exponent is a fixed Python-int so the forward loop
    # doesn't pay a GPU→CPU sync per step.
    self._max_exp_per_input: tuple[int, ...] = tuple(
        max((row[j] for row in exps), default=0) for j in range(self.lut_rank)
    )

NeuralLUT

NeuralLUT(lut_rank: int, hidden_width: int = _DEFAULT_HIDDEN_WIDTH, depth: int = _DEFAULT_DEPTH, activation: str = _DEFAULT_ACTIVATION, **kwargs: Any)

Bases: LUTParametrization

Tiny-MLP-per-neuron LUT parametrization.

Each neuron is a small MLP of shape lut_rank → hidden_width → ... → 1. All neuron MLPs share the same architecture and are evaluated in parallel via bmm. Residual init is not supported — weight_init is effectively always random here.
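The per-neuron parameter count implied by this architecture, as a sketch mirroring the _num_weights sum in the source below (assuming all hidden layers share hidden_width):

def neurallut_num_weights(lut_rank: int, hidden_width: int, depth: int) -> int:
    # depth linear layers: lut_rank -> hidden_width -> ... -> hidden_width -> 1
    widths = [lut_rank] + [hidden_width] * (depth - 1) + [1]
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths, widths[1:]))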

Parameters:

Name Type Description Default
lut_rank int

Number of inputs per neuron.

required
hidden_width int

Width of the hidden layers.

_DEFAULT_HIDDEN_WIDTH
depth int

Number of linear layers (depth - 1 hidden layers + 1 output layer).

_DEFAULT_DEPTH
activation str

Hidden-activation name — one of "relu", "sigmoid", "leakyrelu", "tanh".

_DEFAULT_ACTIVATION
**kwargs Any

Forwarded to LUTParametrization.

{}

Raises:

Type Description
ValueError

If activation is not one of the supported names.

Source code in bitlogic/parametrizations/neurallut.py
def __init__(
    self,
    lut_rank: int,
    hidden_width: int = _DEFAULT_HIDDEN_WIDTH,
    depth: int = _DEFAULT_DEPTH,
    activation: str = _DEFAULT_ACTIVATION,
    **kwargs: Any,
):
    super().__init__(lut_rank, **kwargs)
    self.hidden_width = int(hidden_width)
    self.depth = int(depth)
    if activation not in _ACTIVATIONS:
        raise ValueError(f"activation must be one of {list(_ACTIVATIONS)}, got {activation!r}")
    self.activation_name = activation
    self.activation = _ACTIVATIONS[activation]

    self._layer_shapes: list[tuple[int, int]] = self._compute_layer_shapes()
    self._num_weights = sum(w * n + n for (w, n) in self._layer_shapes)

DwnLUT

DwnLUT(lut_rank: int, alpha: float | None = None, efd: bool = True, **kwargs: Any)

Bases: LUTParametrization

Binarized-input LUT with Extended Finite Difference (EFD) gradient.

Forward: binarize inputs at 0.5, then do a direct LUT lookup. Backward: EFD per Bacellar et al. 2025, Sec. 3.1 — a Hamming-weighted sum over all 2**lut_rank LUT entries (default). Setting efd=False falls back to the single-bit-flip finite-difference approximation (legacy code path; closer to what WNN-style papers did before DWN).

Weight shape: (num_neurons, 2**lut_rank); sigmoid(weight) is the LUT value.
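A hedged sketch of the forward path only (the EFD backward is the novel part and is not reproduced; shapes follow the base-class forward docs):

import torch

def dwn_forward(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: (batch, rank, num_neurons); weight: (num_neurons, 2**rank).
    bits = (x > 0.5).long()                       # binarize inputs at 0.5
    addr = torch.zeros(bits.shape[0], bits.shape[2], dtype=torch.long)
    for j in range(bits.shape[1]):                # MSB-first addressing
        addr = (addr << 1) | bits[:, j]
    lut = torch.sigmoid(weight)                   # LUT value per entry
    out = lut.unsqueeze(0).expand(x.shape[0], -1, -1)
    return out.gather(2, addr.unsqueeze(-1)).squeeze(-1)  # (batch, num_neurons)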

Parameters:

Name Type Description Default
lut_rank int

Number of inputs per neuron.

required
alpha float | None

(E)FD backward scale. Defaults to 0.5 · 0.75**(lut_rank - 1) which decays with rank.

None
efd bool

If True (default), use the Hamming-weighted EFD sum. If False, use the simpler Hamming-1 FD — useful for ablations against the paper's reported improvement.

True
**kwargs Any

Forwarded to LUTParametrization.

{}
Source code in bitlogic/parametrizations/dwn.py
def __init__(
    self,
    lut_rank: int,
    alpha: float | None = None,
    efd: bool = True,
    **kwargs: Any,
):
    super().__init__(lut_rank, **kwargs)
    if alpha is None:
        alpha = _ALPHA_BASE * (_ALPHA_DECAY ** (self.lut_rank - 1))
    self.register_buffer("alpha", torch.tensor(float(alpha)))
    self.efd = bool(efd)
    if self.efd:
        # 96 KB at rank 6; not persistent — rebuilt from ``lut_rank`` on load.
        self.register_buffer(
            "_efd_coeff", _build_efd_coefficients(self.lut_rank), persistent=False
        )
    else:
        self._efd_coeff = None

DiffLogicLUT

DiffLogicLUT(lut_rank: int = 2, **kwargs: Any)

Bases: LUTParametrization

DiffLogic-style softmax over the 16 two-input Boolean functions.

Rank 2 only. Each neuron has a (16,) logit vector; the forward computes Σ_f softmax(weights)_f · f(a, b). The 16-way sum collapses algebraically into four coefficient terms (constant, linear-a, linear-b, a·b) — implemented as 4 einsums rather than materializing all 16 op values.
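The collapse works because every two-input Boolean op is multilinear, f(a, b) = c0 + c1·a + c2·b + c3·a·b, so a softmax mixture of ops is multilinear too. A sketch with an illustrative subset of the 16 ops (the library enumerates all of them):

import torch

# (c0, c1, c2, c3) multilinear coefficients for a few of the 16 ops:
C = torch.tensor([
    [0.0, 0.0, 0.0, 0.0],   # FALSE
    [0.0, 0.0, 0.0, 1.0],   # AND: a*b
    [0.0, 1.0, 1.0, -1.0],  # OR:  a + b - a*b
    [0.0, 1.0, 1.0, -2.0],  # XOR: a + b - 2*a*b
])

def collapsed_forward(a, b, logits):
    p = torch.softmax(logits, dim=-1)   # mixture weights over the ops above
    c = p @ C                           # collapse to four coefficients
    return c[0] + c[1] * a + c[2] * b + c[3] * a * b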

Parameters:

Name Type Description Default
lut_rank int

Fixed at 2; other ranks would require enumerating 2^(2^rank) truth tables and are not supported.

2
**kwargs Any

Forwarded to LUTParametrization.

{}

Raises:

Type Description
ValueError

If lut_rank != 2.

Source code in bitlogic/parametrizations/difflogic.py
def __init__(self, lut_rank: int = 2, **kwargs: Any):
    super().__init__(lut_rank, **kwargs)
    if self.lut_rank != 2:
        raise ValueError(
            f"DiffLogicLUT only supports lut_rank == 2 (got {self.lut_rank}). "
            f"Higher ranks have 2^(2^rank) truth tables."
        )