Fast training & inference

BitLogic ships two complementary accelerators on top of the eager PyTorch forward:

Mode                       Entry point      Use case
torch.compile + AMP        compile_model    Training and continuous eval
Bit-packed GPU inference   PackedLogicNet   Deployment-time discrete forward

Both paths are pure Python — no CUDA extension, no build-time compilation.

Training: compile_model

compile_model wraps a model with torch.compile (CUDA-graph capture via mode="reduce-overhead") and, optionally, torch.autocast for mixed precision:

import torch
from bitlogic import LogicDense, GroupSum, compile_model

model = torch.nn.Sequential(
    LogicDense(784, 4000, parametrization="light", lut_rank=4),
    LogicDense(4000, 4000, parametrization="light", lut_rank=4),
    GroupSum(k=10, tau=150),
).cuda()

fast = compile_model(model, amp_dtype=torch.bfloat16)
logits = fast(x)  # CUDA-graph captured, bf16 autocast

Caveats:

  • reduce-overhead captures CUDA graphs for a fixed batch size; varying the batch triggers a recompile (see the sketch below).
  • Toggling model.train() / model.eval() picks up a new code path and triggers one recompile per mode.
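
If batch sizes do vary at runtime (for example a ragged final evaluation batch), there are two standard workarounds; a minimal sketch, where pad_to_batch is an illustrative helper and not part of bitlogic:

# Option 1: shape-polymorphic compilation (no CUDA-graph capture).
fast_dyn = compile_model(model, dynamic=True, amp_dtype=torch.bfloat16)

# Option 2: pad the last batch up to the captured size, then trim the output.
def pad_to_batch(xb: torch.Tensor, batch_size: int) -> tuple[torch.Tensor, int]:
    n = xb.shape[0]
    if n < batch_size:
        xb = torch.cat([xb, xb.new_zeros((batch_size - n, *xb.shape[1:]))], dim=0)
    return xb, n

xb_padded, n = pad_to_batch(x, 512)  # 512 = the fixed batch size used for capture
logits = fast(xb_padded)[:n]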

compile_model

compile_model(model: Module, *, mode: str | bool = 'reduce-overhead', dynamic: bool = False, fullgraph: bool = False, amp_dtype: dtype | None = None, device_type: str | None = None) -> Module

Wrap a bitlogic model for fast training / inference.

Applies (optionally) torch.autocast for mixed precision, then torch.compile with CUDA-graph capture. The original model is not mutated; the returned module delegates forward to the wrapped pipeline.

Parameters:

  • model (Module, required): Any nn.Module — typically a nn.Sequential of LogicDense layers plus a head.
  • mode (str | bool, default 'reduce-overhead'): torch.compile mode, one of "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs". Pass True as a shorthand for "reduce-overhead" and False to disable compilation (autocast still applies if amp_dtype is set).
  • dynamic (bool, default False): Forwarded to torch.compile; set to True if batch sizes vary at runtime (disables CUDA-graph capture).
  • fullgraph (bool, default False): Forwarded to torch.compile; when True, raise on graph breaks instead of falling back to eager.
  • amp_dtype (dtype | None, default None): Precision for torch.autocast. None disables autocast; use torch.bfloat16 on modern GPUs, torch.float16 only if bf16 is unavailable.
  • device_type (str | None, default None): Device type for torch.autocast. Inferred from the first parameter when None.

Returns:

  • Module: A nn.Module that runs the compiled forward.

Raises:

  • ValueError: If mode is not one of the recognized strings.

Source code in bitlogic/compile.py
def compile_model(
    model: nn.Module,
    *,
    mode: str | bool = "reduce-overhead",
    dynamic: bool = False,
    fullgraph: bool = False,
    amp_dtype: torch.dtype | None = None,
    device_type: str | None = None,
) -> nn.Module:
    """Wrap a bitlogic model for fast training / inference.

    Applies (optionally) :class:`torch.autocast` for mixed precision, then
    :func:`torch.compile` with CUDA-graph capture. The original model is not
    mutated; the returned module delegates forward to the wrapped pipeline.

    Args:
        model: Any ``nn.Module`` — typically a ``nn.Sequential`` of
            :class:`~bitlogic.layers.LogicDense` layers plus a head.
        mode: ``torch.compile`` mode — one of ``"default"``,
            ``"reduce-overhead"``, ``"max-autotune"``,
            ``"max-autotune-no-cudagraphs"``. Pass ``True`` as a shorthand
            for ``"reduce-overhead"`` and ``False`` to disable compilation
            (autocast still applies if ``amp_dtype`` is set).
        dynamic: Forwarded to ``torch.compile``; set to ``True`` if batch
            sizes vary at runtime (disables CUDA-graph capture).
        fullgraph: Forwarded to ``torch.compile``; when ``True`` raise on
            graph breaks instead of falling back to eager.
        amp_dtype: Precision for :class:`torch.autocast`. ``None`` disables
            autocast; use ``torch.bfloat16`` for modern GPUs, ``torch.float16``
            only if bf16 is unavailable.
        device_type: Device type for :class:`torch.autocast`. Inferred from
            the first parameter when ``None``.

    Returns:
        A ``nn.Module`` that runs the compiled forward.

    Raises:
        ValueError: If ``mode`` is not one of the recognized strings.
    """
    if isinstance(mode, bool):
        mode_str: str | None = "reduce-overhead" if mode else None
    else:
        if mode not in _VALID_MODES:
            raise ValueError(f"mode must be one of {_VALID_MODES}, got {mode!r}")
        mode_str = mode

    wrapped: nn.Module = model
    if amp_dtype is not None:
        if device_type is None:
            try:
                device_type = next(model.parameters()).device.type
            except StopIteration:
                device_type = "cuda" if torch.cuda.is_available() else "cpu"
        wrapped = _AutocastWrapper(model, device_type=device_type, dtype=amp_dtype)

    if mode_str is None:
        return wrapped

    return torch.compile(wrapped, mode=mode_str, dynamic=dynamic, fullgraph=fullgraph)
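
The True / False shorthands for mode make it easy to toggle compilation without touching the rest of the training loop; two short sketches based on the parameter descriptions above:

# Autocast only: skip torch.compile entirely (handy when debugging graph breaks).
eager_amp = compile_model(model, mode=False, amp_dtype=torch.bfloat16)

# Full autotuned compilation, still under bf16 autocast.
tuned = compile_model(model, mode="max-autotune", amp_dtype=torch.bfloat16)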

Inference: PackedLogicNet

PackedLogicNet converts a trained model into bit-packed int64 bitwise ops: 64 samples are evaluated per instruction. It's the GPU analogue of what torchlogix does for CPU deployment via C codegen — but stays in pure PyTorch.

from bitlogic import PackedLogicNet

model.eval()
packed = PackedLogicNet(model)
logits = packed(x)           # 10-100x faster than model.eval() forward

The wrapper supports Sequential models composed of:

  1. Pre-logic: Thermometer / DistributiveThermometer / torch.nn.Flatten. Must output {0, 1} values.
  2. Logic layers: LogicDense.
  3. Post-logic: GroupSum.

For parametrizations where the eval forward is already discrete (e.g. light), the packed output matches model.eval() bit-for-bit. For parametrizations whose eval forward is continuous (e.g. difflogic), PackedLogicNet returns the truly-discretized forward defined by the LUT — which is what you'd actually deploy on hardware.
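
For a light-parametrized model the equivalence is easy to verify directly; a minimal sanity check (torch.allclose is used here rather than strict equality, to stay conservative):

model.eval()
packed = PackedLogicNet(model)
with torch.no_grad():
    assert torch.allclose(model(x), packed(x))  # discrete eval vs. bit-packed forward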

PackedLogicNet

PackedLogicNet(model: Module)

Bases: Module

Bit-packed discrete inference wrapper for a trained bitlogic model.

Parameters:

  • model (Module, required): A trained torch.nn.Sequential of pre-modules (encoders / flatten), LogicDense layers, and post-modules (head). The wrapped model is put into eval mode but not mutated; construction calls LogicDense.get_luts_and_ids on each logic layer under torch.no_grad.

Raises:

  • TypeError: If the model's submodule layout is not recognized.

Example
model.eval()
packed = PackedLogicNet(model)
logits = packed(x)          # 10-100x faster than model.eval() forward
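
Submodules outside the supported pre / logic / post set are rejected at construction time; a small sketch of that failure mode (layer sizes chosen arbitrarily):

bad = torch.nn.Sequential(
    LogicDense(16, 16, parametrization="light", lut_rank=4),
    torch.nn.Dropout(),   # not an encoder, LogicDense, or GroupSum
)
try:
    PackedLogicNet(bad)
except TypeError as err:
    print(err)  # PackedLogicNet does not support submodule of type Dropout ...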
Source code in bitlogic/inference/packed_model.py
def __init__(self, model: nn.Module):
    super().__init__()
    model.eval()

    children = list(_iter_children(model))
    pre: list[nn.Module] = []
    logic: list[LogicDense] = []
    post: list[nn.Module] = []
    stage = 0  # 0=pre, 1=logic, 2=post
    for child in children:
        if stage == 0 and isinstance(child, _PRE_TYPES):
            pre.append(child)
        elif isinstance(child, LogicDense):
            stage = 1
            logic.append(child)
        elif isinstance(child, _POST_TYPES):
            stage = 2
            post.append(child)
        else:
            raise TypeError(
                f"PackedLogicNet does not support submodule of type "
                f"{type(child).__name__} at position {len(pre) + len(logic) + len(post)}. "
                f"Supported: encoders/Flatten, LogicDense, GroupSum."
            )
    if not logic:
        raise TypeError("PackedLogicNet requires at least one LogicDense layer")

    self.pre = nn.Sequential(*pre) if pre else nn.Identity()
    self.post = nn.Sequential(*post) if post else nn.Identity()

    # Extract discrete LUTs + routing ids for every logic layer. All
    # parametrizations return LUTs in MSB convention (input ``j`` at bit
    # ``rank-1-j`` of the address), matched by ``eval_packed_logic_dense``.
    luts_list: list[torch.Tensor] = []
    ids_list: list[torch.Tensor] = []
    ranks: list[int] = []
    with torch.no_grad():
        for layer in logic:
            rank = layer.parametrization.lut_rank
            luts, ids = layer.get_luts_and_ids()
            luts_list.append(luts.detach())
            ids_list.append(ids.detach().to(torch.int64))
            ranks.append(rank)
    self._num_layers = len(logic)
    self._ranks: tuple[int, ...] = tuple(ranks)
    for i, (luts, ids) in enumerate(zip(luts_list, ids_list, strict=True)):
        self.register_buffer(f"_lut_{i}", luts)
        self.register_buffer(f"_ids_{i}", ids)

forward

forward(x: Tensor) -> Tensor

Run discrete inference on a batch.

Parameters:

  • x (Tensor, required): Input tensor matching the eager model's input signature.

Returns:

  • Tensor: Head output — same shape and semantics as the eager model's .eval() forward.

Source code in bitlogic/inference/packed_model.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run discrete inference on a batch.

    Args:
        x: Input tensor matching the eager model's input signature.

    Returns:
        Head output — same shape and semantics as the eager model's
        ``.eval()`` forward.
    """
    # 1. Pre: encoders + flatten, kept in fp for correctness.
    feats = self.pre(x)
    if feats.ndim < 2:
        raise ValueError(
            f"pre-logic output must be at least 2-D (batch, features), got {feats.shape}"
        )
    # Collapse all trailing dims into one feature axis.
    batch = feats.shape[0]
    feats = feats.reshape(batch, -1)

    # 2. Pack along batch dim.
    packed = pack_bits(feats)  # (packs, in_dim) int64

    # 3. Per-layer bitwise LUT eval.
    for i, rank in enumerate(self._ranks):
        luts = self._lut(i)
        ids = self._ids(i)
        if rank == 2:
            packed = eval_packed_logic_dense_rank2(packed, luts, ids)
        else:
            packed = eval_packed_logic_dense(packed, luts, ids)

    # 4. Unpack to fp {0,1} in the original batch order.
    unpacked = unpack_bits(packed, batch, dtype=feats.dtype)

    # 5. Post: head (e.g. GroupSum) in fp.
    return self.post(unpacked)
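
For full-dataset evaluation the wrapper behaves like any other module; a minimal sketch, assuming test_loader is an ordinary DataLoader of (input, label) batches:

packed = PackedLogicNet(model).cuda()
correct = total = 0
with torch.no_grad():
    for xb, yb in test_loader:                      # assumed DataLoader
        pred = packed(xb.cuda()).argmax(dim=1)
        correct += (pred == yb.cuda()).sum().item()
        total += yb.numel()
print(f"discrete accuracy: {correct / total:.4f}")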

pack_bits

pack_bits(x: Tensor) -> Tensor

Pack 64 binary samples along the leading dim into an int64.

Parameters:

  • x (Tensor, required): Tensor of shape (N, *feat) with values in {0, 1}. N is padded to the next multiple of 64 with zeros.

Returns:

  • Tensor: Tensor of shape (ceil(N / 64), *feat) with dtype torch.int64. Bit i of the output at position (p, *f) holds x[p * 64 + i, *f].
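
To make the bit layout concrete, a small sketch of the documented convention (the import path follows the source location below):

from bitlogic.inference.pack import pack_bits

x = torch.randint(0, 2, (130, 8), dtype=torch.float32)  # 130 samples, 8 features
p = pack_bits(x)                 # 130 is zero-padded to 192, i.e. 3 packs of 64
assert p.shape == (3, 8) and p.dtype == torch.int64
# Bit i of pack p holds sample p * 64 + i; e.g. bit 5 of pack 0 is sample 5:
assert torch.equal(((p[0] >> 5) & 1).to(x.dtype), x[5])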

Source code in bitlogic/inference/pack.py
def pack_bits(x: torch.Tensor) -> torch.Tensor:
    """Pack 64 binary samples along the leading dim into an ``int64``.

    Args:
        x: Tensor of shape ``(N, *feat)`` with values in ``{0, 1}``. ``N`` is
            padded to the next multiple of 64 with zeros.

    Returns:
        Tensor of shape ``(ceil(N / 64), *feat)`` with dtype ``torch.int64``.
        Bit ``i`` of the output at position ``(p, *f)`` holds ``x[p * 64 + i, *f]``.
    """
    if x.ndim == 0:
        raise ValueError("pack_bits expects at least a 1-D input")
    n = x.shape[0]
    pad = (-n) % PACK_BITS
    if pad:
        x = torch.cat(
            [x, x.new_zeros((pad,) + tuple(x.shape[1:]))],
            dim=0,
        )
    x_int = x.to(torch.int64)
    packs = x_int.shape[0] // PACK_BITS
    x_grouped = x_int.reshape(packs, PACK_BITS, *x_int.shape[1:])
    shifts = torch.arange(PACK_BITS, device=x.device, dtype=torch.int64).view(
        1, PACK_BITS, *([1] * (x_int.ndim - 1))
    )
    # Non-overlapping bits — sum is equivalent to bitwise OR reduction.
    return (x_grouped << shifts).sum(dim=1)

unpack_bits

unpack_bits(x_packed: Tensor, n: int, *, dtype: dtype = float32) -> Tensor

Inverse of pack_bits. Returns the first n samples.

Parameters:

  • x_packed (Tensor, required): Output of pack_bits, shape (P, *feat), int64.
  • n (int, required): Original number of samples (for trimming the zero-padded tail).
  • dtype (dtype, default float32): Dtype of the returned tensor. Use a float dtype for downstream modules that expect floats.

Returns:

  • Tensor: Tensor of shape (n, *feat) with values in {0, 1}.
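
A pack / unpack round trip is lossless for the first n samples; a quick check under the same assumptions as the pack_bits sketch above:

from bitlogic.inference.pack import pack_bits, unpack_bits

x = torch.randint(0, 2, (100, 16), dtype=torch.float32)
assert torch.equal(unpack_bits(pack_bits(x), 100), x)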

Source code in bitlogic/inference/pack.py
def unpack_bits(
    x_packed: torch.Tensor, n: int, *, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Inverse of :func:`pack_bits`. Returns the first ``n`` samples.

    Args:
        x_packed: Output of :func:`pack_bits`, shape ``(P, *feat)`` int64.
        n: Original number of samples (for trimming the zero-padded tail).
        dtype: Dtype of the returned tensor. Use a float dtype for
            downstream modules that expect floats.

    Returns:
        Tensor of shape ``(n, *feat)`` with values in ``{0, 1}``.
    """
    shifts = torch.arange(PACK_BITS, device=x_packed.device, dtype=torch.int64).view(
        1, PACK_BITS, *([1] * (x_packed.ndim - 1))
    )
    expanded = (x_packed.unsqueeze(1) >> shifts) & 1
    flat = expanded.reshape(-1, *x_packed.shape[1:])
    return flat[:n].to(dtype)