Quick Start

This guide walks through building and training a LUT network on MNIST using the top-level bitlogic API. A runnable version lives in examples/train_mnist.py — see Examples for invocation details.

graph LR
    A["Input<br/>(B, 1, 28, 28)"]
    B["DistributiveThermometer<br/>(B, 8, 28, 28)"]
    C["Flatten<br/>(B, 6272)"]
    D["LogicDense<br/>(B, 2000)"]
    E["LogicDense<br/>(B, 2000)"]
    F["GroupSum<br/>(B, 10)"]

    A --> B --> C --> D --> E --> F

1. Install

Pick a hardware extra and sync:

git clone https://github.com/aplesner/bitlogic.git
cd bitlogic
uv sync --extra cpu         # or: --extra cu128

See the Installation Guide for CUDA / pip details.

2. Imports

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from bitlogic import DistributiveThermometer, LogicDense, GroupSum

Everything the user typically needs is re-exported from the bitlogic top-level module. Factories live under bitlogic.parametrizations and bitlogic.connections for advanced use.
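
If you want to see what those submodules expose without guessing names, plain introspection works. This sketch only assumes the module paths mentioned above; the factory contents themselves are not covered here:

import bitlogic.parametrizations as parametrizations
import bitlogic.connections as connections

# List the public names each factory submodule exports.
print([n for n in dir(parametrizations) if not n.startswith("_")])
print([n for n in dir(connections) if not n.startswith("_")])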

3. Build the model

BitLogic is just torch.nn — you assemble encoders, layers, and heads with nn.Sequential (or any other container).

model = nn.Sequential(
    DistributiveThermometer(num_bits=8),            # (B, 1, 28, 28) → (B, 8, 28, 28)
    nn.Flatten(),                                   # → (B, 6272)
    LogicDense(
        in_dim=6272, out_dim=2000,
        parametrization="light", lut_rank=4, temperature=1.0,
        connections="learnable", num_candidates=8,
    ),
    LogicDense(
        in_dim=2000, out_dim=2000,
        parametrization="light", lut_rank=4, temperature=1.0,
        connections="learnable", num_candidates=8,
    ),
    GroupSum(k=10, tau=30.0),
)

Anatomy:

  • DistributiveThermometer(num_bits=8), the encoder: turns each scalar input into 8 bits using quantile thresholds fit on training data. The default encode_axis=1 keeps NCHW layout: a (B, 1, 28, 28) image becomes (B, 8, 28, 28).
  • LogicDense — each output neuron reads lut_rank=4 inputs (selected by the learnable connections module with 8 candidate inputs per slot) and evaluates a LUT through the "light" parametrization.
  • GroupSum — partitions the final 2000 features into 10 groups of 200, sums each group, and divides by tau to produce the class logits (see the sketch after this list).
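
To make the head concrete, here is a minimal sketch of the arithmetic GroupSum(k=10, tau=30.0) performs according to the description above. It is an illustration, not the library's implementation, and the contiguous grouping is an assumption:

import torch

def group_sum(x: torch.Tensor, k: int = 10, tau: float = 30.0) -> torch.Tensor:
    B, D = x.shape                                 # here D = 2000
    return x.view(B, k, D // k).sum(dim=-1) / tau  # (B, k) logits

logits = group_sum(torch.rand(128, 2000))
print(logits.shape)                                # torch.Size([128, 10])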

4. Fit the encoder

Thermometer encoders need one pass over training data to compute thresholds. Do this once, before training.

transform = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST("./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=128, shuffle=False)

# Collect enough training samples to estimate quantiles, then fit the encoder.
with torch.no_grad():
    sample_batches = []
    for imgs, _ in train_loader:
        sample_batches.append(imgs)
        if sum(b.shape[0] for b in sample_batches) >= 10000:
            break
    fit_sample = torch.cat(sample_batches, dim=0)

model[0].fit(fit_sample)   # DistributiveThermometer is the first module
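
A quick sanity check after fitting. This assumes the encoder behaves like any other nn.Module and emits binary values, per the anatomy notes above:

with torch.no_grad():
    encoded = model[0](fit_sample[:4])
print(encoded.shape)     # expected: torch.Size([4, 8, 28, 28])
print(encoded.unique())  # expected: only 0. and 1.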

5. Train

Standard PyTorch — nothing BitLogic-specific:

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

opt = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(5):
    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        opt.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, labels)
        loss.backward()
        opt.step()

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f"epoch {epoch}: test accuracy = {100 * correct / total:.2f}%")

6. Inspect the learned LUTs

Every LogicDense layer can emit discretized truth tables and input-id routing — the inputs to an HDL export backend:

for i, mod in enumerate(model):
    if isinstance(mod, LogicDense):
        luts, ids = mod.get_luts_and_ids()
        print(f"layer {i}: luts {tuple(luts.shape)}  ids {tuple(ids.shape)}")

luts is (out_dim, 2**lut_rank) of {0, 1} entries; ids is (lut_rank, out_dim) naming which input feeds each slot of each neuron.
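
To read one neuron's function as a full truth table, enumerate the 2**lut_rank addresses. The bit-order convention below (bit j of the address drives input slot j) is an assumption; check it against the HDL export backend before relying on it:

luts, ids = model[2].get_luts_and_ids()       # first LogicDense in the Sequential
neuron = 0
rank = ids.shape[0]                           # lut_rank = 4
print("input ids:", ids[:, neuron].tolist())  # the 4 features feeding this neuron
for addr in range(2 ** rank):
    bits = [(addr >> j) & 1 for j in range(rank)]  # assumed slot ordering
    print(bits, "->", int(luts[neuron, addr]))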

Next steps