A convolutional framework for column-wise weight and partial-sum quantization.
The paper has been presented at DATE 2025.
SplitConv4Pim_group.py implements a convolution framework designed for compute-in-memory (CIM) accelerators. This framework supports:
- Weight quantization from layer-wise to column-wise level
- Partial-sum quantization from layer-wise to column-wise level
The primary functionality is quantization with LSQ (Learned Step Size Quantization) for both weights and partial sums, which gives control over the quantization granularity and allows optimization for CIM architectures.
- Weight Splitting and Quantization: The framework splits weights based on the number of bits per cell and applies quantization through LSQ.
- Partial-Sum Quantization: Supports quantization of partial sums across various granularities.
- Group Convolution Support: Convolutions are performed across groups of arrays, enabling faster operations.
- Row-wise Tiling: Provides flexibility to tile mapped weights row-wise for efficient implementation of partial-sum quantization.
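The weight-splitting step listed above can be illustrated with a minimal sketch. It assumes weights are already quantized to unsigned integer codes and that slices are ordered least-significant first; the helper names `split_weight` and `merge_slices` are hypothetical and are not taken from SplitConv4Pim_group.py.

```python
def split_weight(code, w_bit, split_bit):
    """Split an unsigned w_bit-bit integer code into split_bit-wide slices (LSB first)."""
    if not 0 <= code < (1 << w_bit):
        raise ValueError("code does not fit in w_bit bits")
    n_slices = -(-w_bit // split_bit)  # ceiling division
    mask = (1 << split_bit) - 1
    return [(code >> (i * split_bit)) & mask for i in range(n_slices)]


def merge_slices(slices, split_bit):
    """Recombine slices by shift-and-add, the inverse of split_weight."""
    return sum(s << (i * split_bit) for i, s in enumerate(slices))


# A 3-bit weight code stored in 1-bit cells occupies three cells.
slices = split_weight(5, w_bit=3, split_bit=1)  # [1, 0, 1]
restored = merge_slices(slices, split_bit=1)    # 5
```

Each slice would map to one memory cell, so a weight with w_bit=3 and split_bit=1 spans three columns under this assumption.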
This class implements the core convolution framework for CIM architectures.
- w_bit: Number of bits for weight quantization.
- split_bit: Number of bits per cell.
- w_mode: Mode for weight processing ('Array' or 'Layer').
- ps_bit: Number of bits for partial-sum quantization.
- num_sigma: Controls the clipping range for partial sums, restricting values to the range [mu - num_sigma*sigma, mu + num_sigma*sigma].
- psum_mode: Mode for partial-sum quantization ('Array' or 'Layer').
- in_planes: Number of input channels.
- planes: Number of output channels.
- kernel_size: Kernel size of the convolution.
- N: CIM array size.
- stride: Convolution stride.
- padding: Convolution padding.
- bias: Whether to use a bias term.
- isRow: Tile weights row-wise if True.
- w_per_ch: Enables weight quantization on an output-channel basis if True.
- ps_per_ch: Enables partial-sum quantization on an output-channel basis if True.
- psumOpt: Enables partial-sum quantization if True.
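To make the num_sigma clipping range concrete, here is a small sketch in plain Python. It assumes mu and sigma are the mean and the population standard deviation of the observed partial sums; whether the framework uses the population or sample estimate is not stated here, and the function names are hypothetical.

```python
from statistics import fmean, pstdev


def psum_clip_range(psums, num_sigma):
    """Return (low, high) = (mu - num_sigma*sigma, mu + num_sigma*sigma)."""
    mu = fmean(psums)
    sigma = pstdev(psums)  # assumption: population standard deviation
    return mu - num_sigma * sigma, mu + num_sigma * sigma


def clip_psum(value, low, high):
    """Restrict a single partial sum to the clipping range."""
    return max(low, min(high, value))


low, high = psum_clip_range([1.0, 3.0], num_sigma=2)
clipped = clip_psum(10.0, low, high)  # values outside the range saturate at the bound
```

A larger num_sigma widens the range and clips fewer outliers, at the cost of a coarser step size for a fixed ps_bit.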
forward: Executes the forward pass.
Handles layer-wise or channel-wise weight quantization.
Handles array-wise or column-wise weight quantization.
This repository provides three variants of custom convolution layers, each offering a different weight quantization and decomposition method.
Option 1: Vanilla
conv_module = Conv4Pim_group_split(...) or Conv4Pim_group_arr(...)
Option 2: Weight-decomposed version
conv_module = Conv4Pim_group_split_v2(...) or Conv4Pim_group_arr_v2(...)
- This version incorporates weight decomposition of positive and negative arrays to reflect practical settings of RRAM-based CIM architectures.
Option 3: Weight decomposition + LsqWeight_v3
conv_module = Conv4Pim_group_split_v3(...) or Conv4Pim_group_arr_v3(...)
- Building upon Option 2, this version integrates an improved LSQ scheme optimized for binary RRAM-based CIM hardware.
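The positive/negative weight decomposition mentioned for Option 2 can be sketched as follows. This is an illustration of the general idea only (signed weights stored as two non-negative arrays whose outputs are subtracted), not the code used by the _v2 or _v3 classes; the function names are hypothetical.

```python
def decompose(weights):
    """Split signed integer weights into non-negative positive and negative parts."""
    pos = [max(w, 0) for w in weights]
    neg = [max(-w, 0) for w in weights]
    return pos, neg


def decomposed_mac(inputs, weights):
    """Accumulate on the positive and negative arrays separately, then subtract."""
    pos, neg = decompose(weights)
    acc_pos = sum(x * w for x, w in zip(inputs, pos))
    acc_neg = sum(x * w for x, w in zip(inputs, neg))
    return acc_pos - acc_neg


weights = [3, -2, 0, -1]
inputs = [1, 2, 3, 4]
result = decomposed_mac(inputs, weights)  # equals the signed dot product, -5
```

Because each array holds only non-negative values, every cell can be realized with a conductance, which is why this form suits RRAM-based arrays.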
LSQ.py follows LSQ (Learned Step Size Quantization), proposed by Steven K. Esser et al. from IBM, for both weight and partial-sum quantization, and is based on the LSQ-Net repository. Our LSQ implementation is revised to support various granularities.
- Scale Factor Initialization: The scale factors for weight quantization are initialized using the mean absolute value of the tensor. In the case of partial-sum quantization, the scale factors are initialized to constant values.
- Learned Step Sizes: The quantization scale factors are learnable parameters and are updated during training, improving accuracy and adaptability to different network architectures.
- Support for Per-channel Quantization: We provide options to enable per-channel quantization, aligning with the flexibility of LSQ.
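The forward path of an LSQ-style quantizer can be sketched in plain Python as below. The initialization 2*mean(|v|)/sqrt(Qp) is the one given in the LSQ paper; the exact constant used in LSQ.py is an assumption here, as are the function names and the signed integer range. Training additionally treats the scale as a learnable parameter and passes gradients through the rounding with a straight-through estimator, which this sketch omits.

```python
import math


def lsq_init_scale(values, bits):
    """Initial step size from the mean absolute value (LSQ paper form, assumed here)."""
    qp = 2 ** (bits - 1) - 1
    mean_abs = sum(abs(v) for v in values) / len(values)
    return 2 * mean_abs / math.sqrt(qp)


def lsq_quantize(values, scale, bits):
    """Scale, clamp to the signed integer range, round, and rescale."""
    qn, qp = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    out = []
    for v in values:
        q = round(max(qn, min(qp, v / scale)))  # integer code in [qn, qp]
        out.append(q * scale)
    return out


scale = lsq_init_scale([1.0, -1.0, 3.0, -3.0], bits=3)
quantized = lsq_quantize([0.2, -0.7, 5.0, -5.0], scale=0.5, bits=3)
```

For per-channel or column-wise granularity, the same functions would be applied to each channel or column with its own scale.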
import torch
from SplitConv4Pim_group import SplitConv4Pim_group
# Example initialization: Column-wise weight and partial-sum quantization with 256x256 arrays
conv = SplitConv4Pim_group(
    w_bit=3,
    split_bit=1,
    w_mode='Array',
    ps_bit=3,
    num_sigma=6,
    psum_mode='Array',
    in_planes=64,
    planes=64,
    kernel_size=3,
    N=256,
    stride=1,
    padding=1,
    bias=False,
    isRow=True,
    w_per_ch=True,
    ps_per_ch=True,
    psumOpt=True
)
# Forward pass
input_tensor = torch.randn(1, 64, 32, 32) # Example input tensor
output = conv(input_tensor)
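As a rough guide to how many arrays a configuration like the one above occupies, the following sketch counts row and column tiles. It assumes an im2col-style mapping in which each output channel uses in_planes * kernel_size * kernel_size rows and w_bit / split_bit columns; the actual mapping in SplitConv4Pim_group.py may differ, and `array_tiling` is a hypothetical helper.

```python
import math


def array_tiling(in_planes, planes, kernel_size, N, w_bit, split_bit):
    """Return (row_tiles, col_tiles) for an N x N array under the assumed mapping."""
    rows = in_planes * kernel_size * kernel_size  # unrolled kernel length
    cols = planes * math.ceil(w_bit / split_bit)  # one column per weight slice
    return math.ceil(rows / N), math.ceil(cols / N)


# Same values as the initialization example: 576 rows and 192 columns.
row_tiles, col_tiles = array_tiling(
    in_planes=64, planes=64, kernel_size=3, N=256, w_bit=3, split_bit=1
)
```

Under this assumption each row tile produces its own partial sum per column, and those are the values that partial-sum quantization operates on.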