# Quick Start

This guide walks through building and training a LUT network on MNIST using
the top-level `bitlogic` API. A runnable version lives in
`examples/train_mnist.py` — see Examples for invocation details.
```mermaid
graph LR
    A["Input<br/>(B, 1, 28, 28)"]
    B["DistributiveThermometer<br/>(B, 8, 28, 28)"]
    C["Flatten<br/>(B, 6272)"]
    D["LogicDense<br/>(B, 2000)"]
    E["LogicDense<br/>(B, 2000)"]
    F["GroupSum<br/>(B, 10)"]
    A --> B --> C --> D --> E --> F
```
## 1. Install

Pick a hardware extra and sync:

```bash
git clone https://github.com/aplesner/bitlogic.git
cd bitlogic
uv sync --extra cpu   # or: --extra cu128
```
See the Installation Guide for CUDA / pip details.
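To confirm the environment resolves the package, a minimal import check works. This is just a sketch: `__version__` is a common packaging convention, not something this guide guarantees, so the snippet falls back gracefully if it is absent.

```python
# Post-install sanity check. __version__ is an assumption (a common
# convention); getattr falls back to a plain message if it doesn't exist.
import bitlogic

print(getattr(bitlogic, "__version__", "bitlogic imported OK"))
```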
## 2. Imports

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from bitlogic import DistributiveThermometer, LogicDense, GroupSum
```
Everything the user typically needs is re-exported from the top-level
`bitlogic` module. Factories live under `bitlogic.parametrizations` and
`bitlogic.connections` for advanced use.
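In other words, the guide-level names come straight from the top level, while advanced code reaches into the factory submodules. The imports below only illustrate that layout; the concrete factory symbols live on the Parametrizations and Connections pages.

```python
# Everything used in this guide is a top-level re-export:
from bitlogic import DistributiveThermometer, LogicDense, GroupSum

# Advanced use pulls factories from the submodules named above. These
# imports show the module layout only; see the linked pages for the
# actual factory names.
import bitlogic.parametrizations
import bitlogic.connections
```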
## 3. Build the model

BitLogic is just `torch.nn` — you assemble encoders, layers, and heads with
`nn.Sequential` (or any other container).

```python
model = nn.Sequential(
    DistributiveThermometer(num_bits=8),  # (B, 1, 28, 28) → (B, 8, 28, 28)
    nn.Flatten(),                         # → (B, 6272)
    LogicDense(
        in_dim=6272, out_dim=2000,
        parametrization="light", lut_rank=4, temperature=1.0,
        connections="learnable", num_candidates=8,
    ),
    LogicDense(
        in_dim=2000, out_dim=2000,
        parametrization="light", lut_rank=4, temperature=1.0,
        connections="learnable", num_candidates=8,
    ),
    GroupSum(k=10, tau=30.0),
)
```
Anatomy:

- Encoder — `DistributiveThermometer(num_bits=8)` turns each scalar input into
  8 bits using quantile thresholds fit on training data. The default
  `encode_axis=1` keeps NCHW layout: an `(N, 1, 28, 28)` image becomes
  `(N, 8, 28, 28)`.
- LogicDense — each output neuron reads `lut_rank=4` inputs (selected by the
  `learnable` connections module with 8 candidate inputs per slot) and
  evaluates a LUT through the `"light"` parametrization.
- GroupSum — partitions the final 2000 features into 10 groups of 200 and sums
  each group, giving class logits divided by `tau`; the sketch after this list
  spells out that arithmetic.
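As a sketch of the head's arithmetic only (assuming the 10 groups are consecutive blocks of 200 features — the description above doesn't pin down the grouping order):

```python
import torch

# What GroupSum(k=10, tau=30.0) computes, per the description above.
# Assumption: groups are consecutive blocks of 2000 // 10 = 200 features.
x = torch.randn(4, 2000)                       # stand-in for the last layer's output
logits = x.view(4, 10, 200).sum(dim=-1) / 30.0
print(logits.shape)                            # torch.Size([4, 10])
```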
## 4. Fit the encoder
Thermometer encoders need one pass over training data to compute thresholds. Do this once, before training.
```python
transform = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST("./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=128, shuffle=False)

# Collect enough training samples to estimate quantiles, then fit the encoder.
with torch.no_grad():
    sample_batches = []
    for imgs, _ in train_loader:
        sample_batches.append(imgs)
        if sum(b.shape[0] for b in sample_batches) >= 10000:
            break
    fit_sample = torch.cat(sample_batches, dim=0)
    model[0].fit(fit_sample)  # DistributiveThermometer is the first module
```
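A quick shape check confirms the fit took; the expected shapes come from the anatomy above.

```python
# After fitting, the encoder maps (B, 1, 28, 28) -> (B, 8, 28, 28).
with torch.no_grad():
    encoded = model[0](fit_sample[:4])
print(tuple(encoded.shape))  # expected: (4, 8, 28, 28)
```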
## 5. Train
Standard PyTorch — nothing BitLogic-specific:
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
opt = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(5):
    model.train()
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        opt.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, labels)
        loss.backward()
        opt.step()

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f"epoch {epoch}: test accuracy = {100 * correct / total:.2f}%")
```
## 6. Inspect the learned LUTs

Every `LogicDense` layer can emit discretized truth tables and input-id
routing — the inputs to an HDL export backend:
```python
for i, mod in enumerate(model):
    if isinstance(mod, LogicDense):
        luts, ids = mod.get_luts_and_ids()
        print(f"layer {i}: luts {tuple(luts.shape)} ids {tuple(ids.shape)}")
```
`luts` is `(out_dim, 2**lut_rank)` with entries in {0, 1}; `ids` is
`(lut_rank, out_dim)`, naming which input feeds each slot of each neuron.
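To make the export format concrete, here is a minimal sketch of evaluating one neuron from these tables in pure Python. The bit ordering of the LUT address is an assumption; the real HDL backend may use a different convention.

```python
def eval_neuron(luts, ids, x, n):
    """Evaluate neuron n from exported tables on a binary input vector x.

    Assumption: bit s of the LUT address comes from the s-th selected
    input, ids[s, n]; the actual backend's bit order may differ.
    """
    addr = 0
    for s in range(ids.shape[0]):   # lut_rank slots per neuron
        addr |= int(x[ids[s, n]]) << s
    return int(luts[n, addr])       # a {0, 1} truth-table entry
```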
## Next steps
- Parametrizations — trade-offs between the eight LUT shapes.
- Connections — fixed vs. learnable vs. top-K sparse.
- Examples — the runnable MNIST quickstart that matches this guide.
- Contributing — dev workflow and conventions.