Fast training & inference¶
BitLogic ships two complementary accelerators on top of the eager PyTorch forward:
| Mode | Entry point | Use case |
|---|---|---|
| `torch.compile` + AMP | `compile_model` | Training and continuous eval |
| Bit-packed GPU inference | `PackedLogicNet` | Deployment-time discrete forward |
Both paths are pure Python — no CUDA extension, no build-time compilation.
Training: compile_model¶
compile_model wraps a model with
torch.compile (CUDA-graph capture via mode="reduce-overhead") and,
optionally, torch.autocast for mixed precision:
```python
import torch
from bitlogic import LogicDense, GroupSum, compile_model

model = torch.nn.Sequential(
    LogicDense(784, 4000, parametrization="light", lut_rank=4),
    LogicDense(4000, 4000, parametrization="light", lut_rank=4),
    GroupSum(k=10, tau=150),
).cuda()

fast = compile_model(model, amp_dtype=torch.bfloat16)
logits = fast(x)  # CUDA-graph captured, bf16 autocast
```
Caveats:

- `reduce-overhead` captures CUDA graphs for a fixed batch size; varying the batch triggers a recompile.
- Toggling `model.train()`/`model.eval()` picks up a new code path and triggers one recompile per mode.
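To avoid the batch-size recompiles, keep every training batch the same size; with a standard PyTorch `DataLoader` this is just `drop_last=True` (a generic sketch, not bitlogic-specific):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 1000 samples do not divide evenly into batches of 128; dropping the
# ragged tail keeps every batch the same size, so a captured CUDA graph
# (or compiled kernel) is reused instead of triggering a recompile.
ds = TensorDataset(torch.randn(1000, 784))
loader = DataLoader(ds, batch_size=128, drop_last=True)

sizes = {xb.shape[0] for (xb,) in loader}
print(sizes)  # every batch is exactly 128
```

The cost is discarding up to `batch_size - 1` samples per epoch; reshuffling each epoch means different samples are dropped each time.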
compile_model ¶
```python
compile_model(
    model: Module,
    *,
    mode: str | bool = "reduce-overhead",
    dynamic: bool = False,
    fullgraph: bool = False,
    amp_dtype: dtype | None = None,
    device_type: str | None = None,
) -> Module
```
Wrap a bitlogic model for fast training / inference.
Applies (optionally) `torch.autocast` for mixed precision, then
`torch.compile` with CUDA-graph capture. The original model is not
mutated; the returned module delegates `forward` to the wrapped pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Any bitlogic model. | required |
| `mode` | `str \| bool` | Forwarded to `torch.compile`. | `'reduce-overhead'` |
| `dynamic` | `bool` | Forwarded to `torch.compile`. | `False` |
| `fullgraph` | `bool` | Forwarded to `torch.compile`. | `False` |
| `amp_dtype` | `dtype \| None` | Precision for `torch.autocast`; `None` disables autocast. | `None` |
| `device_type` | `str \| None` | Device type for `torch.autocast`. | `None` |
Returns:
| Type | Description |
|---|---|
| `Module` | A wrapper module that delegates `forward` to the compiled (and optionally autocast) pipeline; the original model is not mutated. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in bitlogic/compile.py
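The wiring described above can be sketched roughly as follows (a hypothetical re-implementation for illustration, not the library's actual source):

```python
import torch
from torch import nn

def compile_model_sketch(model: nn.Module, *, mode="reduce-overhead",
                         dynamic=False, fullgraph=False,
                         amp_dtype=None, device_type=None) -> nn.Module:
    """Sketch: torch.compile'd forward, optionally under torch.autocast."""
    if device_type is None:
        # Assumption for the sketch: infer the device from the parameters.
        device_type = "cuda" if next(model.parameters()).is_cuda else "cpu"

    # The original model is left unmutated; torch.compile wraps it.
    compiled = torch.compile(model, mode=mode, dynamic=dynamic, fullgraph=fullgraph)

    class _Fast(nn.Module):
        def forward(self, x):
            if amp_dtype is None:
                return compiled(x)
            with torch.autocast(device_type=device_type, dtype=amp_dtype):
                return compiled(x)

    return _Fast()

fast = compile_model_sketch(nn.Linear(4, 2))  # compilation happens lazily, at the first call
```

Note that `torch.compile` is lazy: the expensive graph capture only happens on the first forward pass, which is why the first batch of each (mode, shape) combination is slow.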
Inference: PackedLogicNet¶
PackedLogicNet converts a trained model's discrete forward into
bit-packed int64 bitwise ops: 64 samples are evaluated per instruction.
It's the GPU analogue of what torchlogix does for CPU deployment via C
codegen — but it stays in pure PyTorch.
```python
from bitlogic import PackedLogicNet

model.eval()
packed = PackedLogicNet(model)
logits = packed(x)  # 10-100x faster than model.eval() forward
```
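The speedup comes from word-level parallelism: one bitwise instruction evaluates a gate for 64 packed samples at once. The trick can be seen with plain Python integers (the bit layout here is illustrative):

```python
# 64 binary samples per input, sample i stored in bit i of a 64-bit word.
a_bits = [1, 0, 1, 1] + [0] * 60
b_bits = [1, 1, 0, 1] + [0] * 60

pack = lambda bits: sum(b << i for i, b in enumerate(bits))
a, b = pack(a_bits), pack(b_bits)

out = a & b  # a single AND evaluates the gate for all 64 samples
out_bits = [(out >> i) & 1 for i in range(64)]
print(out_bits[:4])  # [1, 0, 0, 1]
```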
The wrapper supports `Sequential` models composed of:

- Pre-logic: `Thermometer`/`DistributiveThermometer`/`torch.nn.Flatten`. Must output `{0, 1}` values.
- Logic layers: `LogicDense`.
- Post-logic: `GroupSum`.
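The pre-logic stage's job, turning floats into `{0, 1}` bits, can be sketched with a simple thermometer encoding (a hypothetical stand-in for the library's `Thermometer`, shown only to illustrate the binary-output contract):

```python
import torch

def thermometer_sketch(x: torch.Tensor, thresholds: torch.Tensor) -> torch.Tensor:
    # One {0, 1} output bit per threshold: bit t is 1 iff x > thresholds[t].
    return (x.unsqueeze(-1) > thresholds).to(torch.float32)

x = torch.tensor([0.1, 0.5, 0.9])
th = torch.tensor([0.25, 0.5, 0.75])
print(thermometer_sketch(x, th))
```

The comparison is strict, so 0.5 against threshold 0.5 yields 0; every output value is exactly 0.0 or 1.0, which is what the packed pipeline requires downstream.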
For parametrizations where the eval forward is already discrete (e.g.
light), the packed output matches model.eval() bit-for-bit. For
parametrizations whose eval forward is continuous (e.g. difflogic),
PackedLogicNet returns the truly-discretized forward defined by the
LUT — which is what you'd actually deploy on hardware.
PackedLogicNet ¶
Bases: Module
Bit-packed discrete inference wrapper for a trained bitlogic model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | A trained bitlogic `Sequential` model. | required |
Raises:
| Type | Description |
|---|---|
| `TypeError` | If the model's submodule layout is not recognized. |
Source code in bitlogic/inference/packed_model.py
forward ¶
Run discrete inference on a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor matching the eager model's input signature. | required |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Head output — same shape and semantics as the eager model's. |
Source code in bitlogic/inference/packed_model.py
pack_bits ¶
Pack 64 binary samples along the leading dim into an int64.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Binary tensor; samples along the leading dim. | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Packed tensor of shape `(ceil(n / 64), ...)` with dtype `int64`; one bit per sample, tail zero-padded. |
Source code in bitlogic/inference/pack.py
unpack_bits ¶
Inverse of `pack_bits`. Returns the first `n` samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x_packed` | `Tensor` | Output of `pack_bits`. | required |
| `n` | `int` | Original number of samples (for trimming the zero-padded tail). | required |
| `dtype` | `dtype` | Dtype of the returned tensor. Use a float dtype for downstream modules that expect floats. | `float32` |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Tensor of shape `(n, ...)` in the requested `dtype`. |
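Given the semantics above, the pair can be sketched as a round trip in pure PyTorch (the bit layout, sample `j` of each 64-sample group landing in bit `j`, is an assumption for illustration; the library's actual layout may differ):

```python
import torch

def pack_bits_sketch(x: torch.Tensor) -> torch.Tensor:
    # Pack the leading dim in groups of 64, zero-padding the tail;
    # sample j of each group lands in bit j of one int64 word.
    n = x.shape[0]
    pad = (-n) % 64
    if pad:
        x = torch.cat([x, x.new_zeros((pad, *x.shape[1:]))])
    x = x.to(torch.int64).reshape(-1, 64, *x.shape[1:])
    shifts = torch.arange(64, dtype=torch.int64).reshape(1, 64, *([1] * (x.dim() - 2)))
    return (x << shifts).sum(dim=1)

def unpack_bits_sketch(x_packed: torch.Tensor, n: int,
                       dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Extract each bit back out, then trim the zero-padded tail to n samples.
    shifts = torch.arange(64, dtype=torch.int64).reshape(1, 64, *([1] * (x_packed.dim() - 1)))
    bits = (x_packed.unsqueeze(1) >> shifts) & 1
    return bits.reshape(-1, *x_packed.shape[1:])[:n].to(dtype)

x = (torch.rand(100, 5) > 0.5).float()   # 100 binary samples, 5 features
packed = pack_bits_sketch(x)             # shape (2, 5), dtype int64
restored = unpack_bits_sketch(packed, 100)
```

For `n = 100` the packed tensor has `ceil(100 / 64) = 2` words per feature, and unpacking with the original `n` discards the 28 zero-padded tail samples, so the round trip is exact.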