
Heads API

Heads reduce layer outputs to task-specific shapes. Both heads partition the binary feature vector into k equally-sized groups; the variants differ in what they do inside and after each group.

Both heads follow the same training/eval split: training mode applies the float tau/beta scaling used to shape the softmax that typically follows; eval mode drops those scalars and runs the same integer arithmetic the HDL emitter generates, so model.eval()(x) is bit-exact with the deployed hardware's argmax input.
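
A minimal sketch of that contract, assuming binary {0, 1} inputs of shape (B, k * m) as in the examples below (the group size of 64 is hypothetical; GroupedDSP behaves the same way):

import torch
from bitlogic import GroupSum

head = GroupSum(k=10, tau=30.0)
bits = torch.randint(0, 2, (4, 10 * 64)).float()   # hypothetical 64 features per class

head.train()
float_logits = head(bits)    # group sums rescaled by the float tau/beta scalars

head.eval()
int_logits = head(bits)      # raw integer group sums, as the HDL argmax sees them

assert torch.equal(int_logits, int_logits.round())  # integer-valued in eval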

GroupSum

Partitions the last dim into k equally-sized groups and sums each group. Training divides each sum by tau and adds beta; eval returns the raw integer popcount, the same value the HDL's unsigned argmax consumes.

from bitlogic import GroupSum

head = GroupSum(k=10, tau=30.0)
logits = head(features)   # features: (B, k * m) → (B, k)
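
Functionally, the training-mode forward is equivalent to the following sketch. The reshape-and-sum is implied by the description above; the placement of beta relative to the tau divide is this sketch's assumption:

import torch

def groupsum_reference(features: torch.Tensor, k: int, tau: float, beta: float) -> torch.Tensor:
    # Partition the last dim into k equal groups and sum within each group.
    group_sums = features.reshape(*features.shape[:-1], k, -1).sum(dim=-1)  # (..., k)
    return group_sums / tau + beta   # training mode; eval returns group_sums as-is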

GroupedDSP

GroupSum followed by a learnable k×k matmul. Weights are tanh-bounded to [-1, 1] during training and snapped to a signed wbits-bit integer grid in eval mode; the eval forward runs the integer MAC group_sums @ W_int.T directly, matching the emitted SystemVerilog bit-exactly.

from bitlogic import GroupedDSP

head = GroupedDSP(k=10, tau=30.0, wbits=8)
logits = head(features)   # features: (B, k * m) → (B, k)
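
As with GroupSum, the eval output is the raw result of an integer MAC; a quick sanity sketch (the binary inputs and group size of 64 are hypothetical):

import torch

features = torch.randint(0, 2, (8, 10 * 64)).float()
head.eval()
int_logits = head(features)                         # (8, 10) integer MAC outputs
assert torch.equal(int_logits, int_logits.round())  # integer-valued, HDL-exact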

heads

Output heads for LUT networks.

GroupedDSP

GroupedDSP(k: int, tau: float = 1.0, beta: float = 0.0, wbits: int = 8)

Bases: Module

Class-wise group-sum plus a tanh-bounded k×k matmul.

Forward logic (training):

  1. Partition the last dim into k groups of group_size features.
  2. Sum each group (shape (..., k)).
  3. Multiply by tanh(W_raw), divide by tau, and add beta.

Forward logic (eval): the same, except the matmul uses the integer-snapped weights q ∈ [-S, S] directly (no / S dequantize, no tau, no beta), identical to what the HDL emitter produces.
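
Putting both paths together as a sketch (the symmetric bound S = 2**(wbits - 1) - 1 and the beta placement are this sketch's assumptions; only the tanh bound and reshape-sum are stated above):

import torch

def grouped_dsp_reference(x, W_raw, k, tau, beta, wbits, training):
    group_sums = x.reshape(*x.shape[:-1], k, -1).sum(dim=-1)   # steps 1-2
    if training:
        return group_sums @ torch.tanh(W_raw).T / tau + beta   # step 3
    S = 2 ** (wbits - 1) - 1                   # assumed signed-grid bound
    W_int = torch.round(torch.tanh(W_raw) * S)
    return group_sums @ W_int.T                # integer MAC: no / S, tau, or beta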

Parameters:

  k (int, required): Number of output classes / groups.
  tau (float, default 1.0): Inverse temperature, applied in training mode only. Eval mode skips the divide so the output matches the integer HDL MAC.
  beta (float, default 0.0): Additive bias, applied in training mode only, for the same reason.
  wbits (int, default 8): Bit-width of the signed integer grid used at eval time. Defaults to 8 to match the encoder input quantization grid.
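
For instance, wbits=8 gives q ∈ [-127, 127] under the usual symmetric convention S = 2**(wbits - 1) - 1 (the exact bound is an assumption, not stated by the source):

import torch

S = 2 ** (8 - 1) - 1                           # 127
w = torch.tensor([-1.0, -0.5, 0.0, 0.3, 1.0])  # tanh-bounded weights
q = torch.round(w * S)                         # [-127., -64., 0., 38., 127.]
dq = q / S                                     # the view effective_weight() reports in eval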

Raises:

  ValueError: If k <= 0 or wbits < 2.

Source code in bitlogic/heads/grouped_dsp.py
def __init__(
    self,
    k: int,
    tau: float = 1.0,
    beta: float = 0.0,
    wbits: int = 8,
) -> None:
    super().__init__()
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    if wbits < 2:
        raise ValueError(f"wbits must be >= 2, got {wbits}")
    self.k = int(k)
    self.tau = float(tau)
    self.beta = float(beta)
    self.wbits = int(wbits)
    # Xavier-like init: std = 1/sqrt(k) keeps tanh(W) near identity-scale.
    w_raw = torch.randn(self.k, self.k) / math.sqrt(self.k)
    self.W_raw = nn.Parameter(w_raw)

effective_weight

effective_weight() -> Tensor

Return the weight used by forward under the current mode.

Training: tanh(W_raw). Eval: the same, quantize-dequantized on the signed wbits-bit grid (q / S, for inspection only — the eval forward uses the integer weights directly).

Source code in bitlogic/heads/grouped_dsp.py
def effective_weight(self) -> torch.Tensor:
    """Return the weight used by ``forward`` under the current mode.

    Training: ``tanh(W_raw)``. Eval: the same, quantize-dequantized on
    the signed ``wbits``-bit grid (``q / S``, for inspection only — the
    eval forward uses the integer weights directly).
    """
    if self.training:
        return torch.tanh(self.W_raw)
    return weight_quantize_dequantize(self.W_raw, self.wbits)
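
A short usage sketch, comparing the two views across modes:

from bitlogic import GroupedDSP

head = GroupedDSP(k=10, tau=30.0, wbits=8)

head.train()
w_float = head.effective_weight()   # tanh(W_raw): smooth values in (-1, 1)

head.eval()
w_grid = head.effective_weight()    # q / S: snapped to the signed 8-bit grid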

GroupSum

GroupSum(k: int, tau: float = 1.0, beta: float = 0.0)

Bases: Module

Class-wise group-sum head.

Parameters:

  k (int, required): Number of output classes (groups).
  tau (float, default 1.0): Inverse temperature, applied in training mode only. Eval mode skips the divide so the output matches the integer HDL popcount.
  beta (float, default 0.0): Additive bias, applied in training mode only, for the same reason.

Source code in bitlogic/heads/groupsum.py
def __init__(self, k: int, tau: float = 1.0, beta: float = 0.0) -> None:
    super().__init__()
    self.k = int(k)
    self.tau = float(tau)
    self.beta = float(beta)
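
Given the description above, the two modes differ only by the fixed tau/beta affine map, which allows a quick consistency check (a sketch; beta being added after the tau divide is assumed):

import torch
from bitlogic import GroupSum

head = GroupSum(k=10, tau=30.0, beta=0.5)
bits = torch.randint(0, 2, (4, 10 * 64)).float()

train_out = head.train()(bits)   # sums / tau + beta (assumed order)
eval_out = head.eval()(bits)     # raw integer sums

torch.testing.assert_close((train_out - head.beta) * head.tau, eval_out)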