
Models API

Ready-to-train reference architectures built on the bitlogic primitives. These are plain nn.Sequential subclasses you can drop into your own training loop.

FeedForward

Canonical stack: DistributiveThermometer → Flatten → N × LogicDense → GroupSum. Every LogicDense layer has the same width, parametrization, LUT rank, and connection kind.

import torch
from bitlogic import FeedForward

fit_samples = torch.randn(128, 1, 28, 28)
model = FeedForward(
    fit_samples=fit_samples,
    num_classes=10,
    layer_width=64_000,
    num_layers=2,
    num_bits=8,
    lut_rank=4,
    parametrization="light",
    connections="learnable",
    forward_sampling="hard",   # extra kwargs forwarded to LogicDense
)
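The result is a plain nn.Sequential, so inference is an ordinary forward call; the head emits one logit per class (shapes here follow the construction above):

logits = model(fit_samples)   # (128, 10): batch × num_classes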

models

Ready-to-train reference architectures built on bitlogic primitives.

FeedForward

FeedForward(*, fit_samples: Tensor, num_classes: int, layer_width: int = 64000, num_layers: int = 2, num_bits: int = 8, lut_rank: int = 4, num_candidates: int = 8, parametrization: str = 'light', connections: str = 'learnable', tau: float = 150.0, encoder: str = 'distributive_thermometer', head: str = 'groupsum', head_wbits: int = 8, **layer_kwargs: Any)

Bases: Sequential

Encoder → Flatten → N×LogicDense → Head.

The canonical reference architecture for image classification. Every LogicDense layer has the same width, parametrization, LUT rank, and connection kind. Encoder and head are selected by string so the same builder drives the encoder and head sweeps in §4 of the paper.
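For example, the same constructor selects the alternative encoder and head documented in the parameter table below; a sketch, with fit_samples as in the example above and the remaining hyperparameters left at their defaults:

variant = FeedForward(
    fit_samples=fit_samples,
    num_classes=10,
    encoder="thermometer",   # uniform thresholds instead of empirical quantiles
    head="grouped_dsp",      # weighted head; head_wbits sets its weight bit-width
    head_wbits=4,
)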

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fit_samples | Tensor | (N, C, H, W) tensor used to fit the thermometer encoder. The encoder is fit in-place at construction time. | required |
| num_classes | int | Number of output classes (last dim of the head). | required |
| layer_width | int | Neuron count per LogicDense layer. | 64000 |
| num_layers | int | Number of LogicDense layers in the body (>= 1). | 2 |
| num_bits | int | Thermometer bits per input feature. | 8 |
| lut_rank | int | Inputs per LUT. | 4 |
| num_candidates | int | Candidate pool size for learnable connections (-1 uses every input). | 8 |
| parametrization | str | Parametrization name; see bitlogic.parametrizations. | 'light' |
| connections | str | fixed or learnable. | 'learnable' |
| tau | float | Temperature applied by the head. | 150.0 |
| encoder | str | Input-encoder name. One of thermometer (uniform thresholds) or distributive_thermometer (empirical quantiles). | 'distributive_thermometer' |
| head | str | Output-head name. One of groupsum or grouped_dsp. | 'groupsum' |
| head_wbits | int | Weight bit-width used by grouped_dsp; ignored for groupsum. | 8 |
| **layer_kwargs | Any | Extra kwargs forwarded to every LogicDense (e.g. forward_sampling="hard"). | {} |
Example

import torch
from bitlogic import FeedForward

fit = torch.randn(128, 1, 28, 28)
model = FeedForward(fit_samples=fit, num_classes=10)

Source code in bitlogic/models/feedforward.py
def __init__(
    self,
    *,
    fit_samples: torch.Tensor,
    num_classes: int,
    layer_width: int = 64_000,
    num_layers: int = 2,
    num_bits: int = 8,
    lut_rank: int = 4,
    num_candidates: int = 8,
    parametrization: str = "light",
    connections: str = "learnable",
    tau: float = 150.0,
    encoder: str = "distributive_thermometer",
    head: str = "groupsum",
    head_wbits: int = 8,
    **layer_kwargs: Any,
):
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")
    # Only learnable connections take a candidate pool; setdefault keeps any explicit override.
    if connections == "learnable":
        layer_kwargs.setdefault("num_candidates", num_candidates)

    # Build the chosen input encoder and fit it in-place on the reference samples.
    enc = _build_encoder(encoder, num_bits=num_bits)
    enc.fit(fit_samples)

    # First-layer input width: flattened features per sample × thermometer bits.
    input_size = int(np.prod(fit_samples.shape[1:])) * num_bits
    modules: list[nn.Module] = [enc, nn.Flatten()]
    prev = input_size
    # Stack num_layers equal-width LogicDense layers, each consuming the previous width.
    for _ in range(num_layers):
        modules.append(
            LogicDense(
                in_dim=prev,
                out_dim=layer_width,
                parametrization=parametrization,
                connections=connections,
                lut_rank=lut_rank,
                **layer_kwargs,
            )
        )
        prev = layer_width
    modules.append(
        _build_head(
            head,
            num_classes=num_classes,
            tau=tau,
            head_wbits=head_wbits,
        )
    )
    super().__init__(*modules)
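A minimal training-loop sketch built on the class above, assuming a standard cross-entropy classification setup; the random data, labels, optimizer, and learning rate are illustrative stand-ins, not prescribed by the library:

import torch
from torch import nn
from bitlogic import FeedForward

fit = torch.randn(512, 1, 28, 28)        # stand-in images; real code would use a DataLoader
labels = torch.randint(0, 10, (512,))    # stand-in labels

model = FeedForward(fit_samples=fit, num_classes=10)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for _ in range(3):
    opt.zero_grad()
    logits = model(fit)                  # (512, 10) logits from the groupsum head
    loss = loss_fn(logits, labels)
    loss.backward()
    opt.step()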