perf: Convolution Layout Optimization with Auto-Format Selection #159

@noahgift

Summary

Implement automatic data layout selection for convolutions (NCHW vs NHWC vs custom) based on the backend and operation sequence, with lazy layout conversion at boundaries.

Inspiration: JAX jax/_src/lax/convolution.py:38-180 - supports multiple ConvDimensionNumbers for flexible data layouts.

Problem

Different hardware prefers different data layouts:

  • Intel AVX: NCHW (channels first) for better vectorization
  • ARM NEON: NHWC (channels last) for sequential access
  • cuDNN: Depends on operation (some prefer NCHW, others NHWC)
  • Fused ops: May prefer specific layouts

Current implementation forces a single layout, causing:

  • Unnecessary transposes at operation boundaries
  • Suboptimal memory access patterns
  • Missed fusion opportunities

Proposed Solution

1. Layout Specification

/// Data layout for convolution tensors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvLayout {
    /// Batch dimension index
    pub batch: usize,
    /// Channel dimension index
    pub channel: usize,
    /// Spatial dimension indices (height, width, ...)
    pub spatial: [usize; 3],  // Max 3D convolution
    /// Number of spatial dimensions
    pub num_spatial: usize,
}

impl ConvLayout {
    /// NCHW layout (batch, channel, height, width)
    pub const NCHW: Self = Self {
        batch: 0,
        channel: 1,
        spatial: [2, 3, 0],
        num_spatial: 2,
    };
    
    /// NHWC layout (batch, height, width, channel)
    pub const NHWC: Self = Self {
        batch: 0,
        channel: 3,
        spatial: [1, 2, 0],
        num_spatial: 2,
    };
    
    /// NCDHW for 3D convolutions
    pub const NCDHW: Self = Self {
        batch: 0,
        channel: 1,
        spatial: [2, 3, 4],
        num_spatial: 3,
    };
    
    /// Check if this is a "channels first" layout
    pub fn is_channels_first(&self) -> bool {
        self.channel < self.spatial[0]
    }
    
    /// Get dimension permutation from another layout
    pub fn permutation_from(&self, other: &ConvLayout) -> Vec<usize> {
        // Rank is batch + channel + spatial dims; both layouts must share it
        let mut perm = vec![0; 2 + self.num_spatial];
        perm[self.batch] = other.batch;
        perm[self.channel] = other.channel;
        for i in 0..self.num_spatial {
            perm[self.spatial[i]] = other.spatial[i];
        }
        perm
    }
}

/// Full dimension specification including kernel layout
#[derive(Debug, Clone)]
pub struct ConvDimensionNumbers {
    /// Input tensor layout
    pub input_layout: ConvLayout,
    /// Kernel/filter layout (out_channels, in_channels, spatial...)
    pub kernel_layout: KernelLayout,
    /// Output tensor layout
    pub output_layout: ConvLayout,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KernelLayout {
    /// Output channels dimension
    pub out_channel: usize,
    /// Input channels dimension
    pub in_channel: usize,
    /// Spatial dimensions
    pub spatial: [usize; 3],
    pub num_spatial: usize,
}

impl KernelLayout {
    /// OIHW (out, in, height, width) - standard
    pub const OIHW: Self = Self {
        out_channel: 0,
        in_channel: 1,
        spatial: [2, 3, 0],
        num_spatial: 2,
    };
    
    /// HWIO (height, width, in, out) - TensorFlow style
    pub const HWIO: Self = Self {
        out_channel: 3,
        in_channel: 2,
        spatial: [0, 1, 0],
        num_spatial: 2,
    };
}
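
To make the permutation convention concrete, here is a minimal sketch restricted to the 2D case. `Layout2d`, `permutation`, and `permute_shape` are illustrative names, not part of the proposed API; the assumed convention is that axis `d` of the target layout is taken from axis `perm[d]` of the source layout.

```rust
/// Simplified 2D stand-in for the proposed `ConvLayout` (hypothetical type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Layout2d {
    pub batch: usize,
    pub channel: usize,
    pub spatial: [usize; 2],
}

pub const NCHW: Layout2d = Layout2d { batch: 0, channel: 1, spatial: [2, 3] };
pub const NHWC: Layout2d = Layout2d { batch: 0, channel: 3, spatial: [1, 2] };

/// Permutation `p` such that axis `d` of `target` is axis `p[d]` of `source`.
pub fn permutation(target: Layout2d, source: Layout2d) -> [usize; 4] {
    let mut perm = [0usize; 4];
    perm[target.batch] = source.batch;
    perm[target.channel] = source.channel;
    for i in 0..2 {
        perm[target.spatial[i]] = source.spatial[i];
    }
    perm
}

/// Reorder a shape according to a permutation.
pub fn permute_shape(shape: [usize; 4], perm: [usize; 4]) -> [usize; 4] {
    let mut out = [0usize; 4];
    for (d, &src) in perm.iter().enumerate() {
        out[d] = shape[src];
    }
    out
}
```

Under this convention, `permutation(NHWC, NCHW)` is `[0, 2, 3, 1]`, and applying it to an NCHW shape `[8, 3, 224, 224]` gives the NHWC shape `[8, 224, 224, 3]`.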

2. Layout-Aware Tensor

/// Tensor with explicit layout tracking
#[derive(Clone)]
pub struct LayoutTensor<T, const N: usize> {
    data: Tensor<T>,
    layout: TensorLayout<N>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorLayout<const N: usize> {
    /// Logical dimension order
    pub dim_order: [usize; N],
    /// Stride for each logical dimension
    pub strides: [usize; N],
}

impl<T: TensorElement, const N: usize> LayoutTensor<T, N> {
    /// Convert to different layout (lazy if possible)
    pub fn to_layout(&self, target: TensorLayout<N>) -> Self {
        if self.layout == target {
            return self.clone();
        }
        
        // Check if we can use a view (just change strides)
        if self.can_view_as(&target) {
            return Self {
                data: self.data.view_with_strides(target.strides),
                layout: target,
            };
        }
        
        // Need actual transpose
        let perm = target.permutation_from(&self.layout);
        Self {
            data: self.data.permute(&perm),
            layout: target,
        }
    }
    
    /// Check if layout conversion can be done as a view
    fn can_view_as(&self, target: &TensorLayout<N>) -> bool {
        // Can view if target strides are a permutation of current strides
        // and data is contiguous in memory
        self.data.is_contiguous() && 
            self.layout.is_permutation_of(target)
    }
}
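
The "view instead of copy" path can be illustrated with plain stride arithmetic. This is a sketch under the assumption of a row-major source buffer; `contiguous_strides`, `permuted_view`, and `offset` are hypothetical helpers, not functions of the existing `Tensor` type.

```rust
/// Row-major (contiguous) strides for a 4D shape, in elements.
pub fn contiguous_strides(shape: [usize; 4]) -> [usize; 4] {
    let mut strides = [1usize; 4];
    for d in (0..3).rev() {
        strides[d] = strides[d + 1] * shape[d + 1];
    }
    strides
}

/// Permute shape and strides together. The underlying buffer is untouched,
/// so this models a zero-copy layout change.
pub fn permuted_view(
    shape: [usize; 4],
    strides: [usize; 4],
    perm: [usize; 4],
) -> ([usize; 4], [usize; 4]) {
    let mut new_shape = [0usize; 4];
    let mut new_strides = [0usize; 4];
    for d in 0..4 {
        new_shape[d] = shape[perm[d]];
        new_strides[d] = strides[perm[d]];
    }
    (new_shape, new_strides)
}

/// Flat buffer offset of a logical index under the given strides.
pub fn offset(index: [usize; 4], strides: [usize; 4]) -> usize {
    index.iter().zip(strides.iter()).map(|(i, s)| i * s).sum()
}
```

For an NCHW shape `[2, 3, 4, 5]` the strides are `[60, 20, 5, 1]`; viewing it as NHWC with the permutation `[0, 2, 3, 1]` yields shape `[2, 4, 5, 3]` and strides `[60, 5, 1, 20]`, and both index orders resolve to the same buffer offset. The resulting view is not dense in NHWC order, so a kernel that needs contiguous channels-last data would still require a real transpose.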

3. Automatic Layout Selection

/// Layout preferences for different backends
pub struct LayoutPolicy {
    /// Preferred input layout
    pub preferred_input: ConvLayout,
    /// Preferred kernel layout
    pub preferred_kernel: KernelLayout,
    /// Preferred output layout
    pub preferred_output: ConvLayout,
    /// Cost of converting one tensor element (in equivalent conv FLOPs)
    pub conversion_cost_factor: f64,
}

impl LayoutPolicy {
    /// Policy for x86_64 with AVX2
    #[cfg(target_arch = "x86_64")]
    pub fn default_cpu() -> Self {
        Self {
            preferred_input: ConvLayout::NCHW,
            preferred_kernel: KernelLayout::OIHW,
            preferred_output: ConvLayout::NCHW,
            conversion_cost_factor: 0.1,  // Transpose is 10% of conv cost
        }
    }
    
    /// Policy for ARM with NEON
    #[cfg(target_arch = "aarch64")]
    pub fn default_cpu() -> Self {
        Self {
            preferred_input: ConvLayout::NHWC,
            preferred_kernel: KernelLayout::HWIO,
            preferred_output: ConvLayout::NHWC,
            conversion_cost_factor: 0.1,
        }
    }
    
    /// Policy for WebGPU
    pub fn webgpu() -> Self {
        Self {
            preferred_input: ConvLayout::NHWC,
            preferred_kernel: KernelLayout::HWIO,
            preferred_output: ConvLayout::NHWC,
            conversion_cost_factor: 0.05,  // GPU transposes are cheaper
        }
    }
}

/// Select optimal layout for a sequence of operations
pub fn select_layout_for_sequence(
    ops: &[ConvOp],
    policy: &LayoutPolicy,
) -> Vec<ConvDimensionNumbers> {
    // Dynamic programming to minimize total cost
    let n = ops.len();
    
    // State: (op_index, current_layout) -> min_cost
    let layouts = [ConvLayout::NCHW, ConvLayout::NHWC];
    let mut dp = vec![vec![f64::INFINITY; layouts.len()]; n + 1];
    let mut parent = vec![vec![0usize; layouts.len()]; n + 1];
    
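    // Base case: start in the policy's preferred input layout (cost 0);
    // any other starting layout requires a conversion
    for (j, &layout) in layouts.iter().enumerate() {
        dp[0][j] = if layout == policy.preferred_input {
            0.0
        } else {
            policy.conversion_cost_factor
        };
    }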
    
    for i in 0..n {
        for (j, &layout) in layouts.iter().enumerate() {
            if dp[i][j] == f64::INFINITY {
                continue;
            }
            
            // Cost of this op with this layout
            let op_cost = estimate_conv_cost(&ops[i], layout, policy);
            
            // Try each output layout
            for (k, &out_layout) in layouts.iter().enumerate() {
                let convert_cost = if layout == out_layout {
                    0.0
                } else {
                    policy.conversion_cost_factor * ops[i].output_size() as f64
                };
                
                let total = dp[i][j] + op_cost + convert_cost;
                if total < dp[i + 1][k] {
                    dp[i + 1][k] = total;
                    parent[i + 1][k] = j;
                }
            }
        }
    }
    
    // Backtrack to find optimal sequence
    let mut result = Vec::with_capacity(n);
    let mut current = if dp[n][0] < dp[n][1] { 0 } else { 1 };
    
    for i in (0..n).rev() {
        let in_layout = layouts[parent[i + 1][current]];
        let out_layout = layouts[current];
        
        result.push(ConvDimensionNumbers {
            input_layout: in_layout,
            kernel_layout: policy.preferred_kernel,
            output_layout: out_layout,
        });
        
        current = parent[i + 1][current];
    }
    
    result.reverse();
    result
}
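
The dynamic program above can be exercised in isolation with a plain cost table. The sketch below is a simplified variant, not the proposed function: `choose_layouts` is a hypothetical name, `op_cost[i][j]` is an assumed cost of running op `i` in layout `j`, and a single `switch_cost` stands in for the size-dependent conversion cost.

```rust
/// Choose one layout index per op so that op costs plus conversion costs
/// are minimal. `start` is the layout index the input arrives in.
/// Returns (total cost, chosen layout index for each op).
pub fn choose_layouts(
    op_cost: &[Vec<f64>],
    switch_cost: f64,
    start: usize,
) -> (f64, Vec<usize>) {
    let n = op_cost.len();
    if n == 0 {
        return (0.0, Vec::new());
    }
    let m = op_cost[0].len();

    // dp[i][j]: minimal cost of running ops 0..=i with op i in layout j
    let mut dp = vec![vec![f64::INFINITY; m]; n];
    let mut parent = vec![vec![0usize; m]; n];

    for j in 0..m {
        let convert = if j == start { 0.0 } else { switch_cost };
        dp[0][j] = convert + op_cost[0][j];
    }

    for i in 1..n {
        for j in 0..m {
            for k in 0..m {
                let convert = if j == k { 0.0 } else { switch_cost };
                let total = dp[i - 1][k] + convert + op_cost[i][j];
                if total < dp[i][j] {
                    dp[i][j] = total;
                    parent[i][j] = k;
                }
            }
        }
    }

    // Pick the cheapest final layout, then backtrack through `parent`
    let mut best = 0;
    for j in 1..m {
        if dp[n - 1][j] < dp[n - 1][best] {
            best = j;
        }
    }
    let total = dp[n - 1][best];
    let mut path = vec![0; n];
    let mut cur = best;
    for i in (0..n).rev() {
        path[i] = cur;
        cur = parent[i][cur];
    }
    (total, path)
}
```

With layouts indexed as 0 = NCHW and 1 = NHWC, made-up costs `[[5, 9], [9, 4], [9, 4]]`, a switch cost of 2, and an NCHW input, the cheapest plan runs the first op in NCHW and converts once before the second op, for a total of 15.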

4. Lazy Layout Conversion

/// Tensor that defers layout conversion until needed
pub struct LazyLayoutTensor<T> {
    data: Tensor<T>,
    current_layout: ConvLayout,
    /// Pending layout conversion (if any)
    pending_conversion: Option<ConvLayout>,
}

impl<T: TensorElement> LazyLayoutTensor<T> {
    /// Request a layout change (deferred)
    pub fn request_layout(&mut self, target: ConvLayout) {
        // Requesting the current layout clears any stale pending conversion
        self.pending_conversion = if self.current_layout != target {
            Some(target)
        } else {
            None
        };
    }
    
    /// Materialize any pending conversion
    pub fn materialize(&mut self) {
        if let Some(target) = self.pending_conversion.take() {
            let perm = target.permutation_from(&self.current_layout);
            self.data = self.data.permute(&perm);
            self.current_layout = target;
        }
    }
    
    /// Get data, materializing if needed
    pub fn get(&mut self) -> &Tensor<T> {
        self.materialize();
        &self.data
    }
    
    /// Try to cancel pending conversion if next op prefers current layout
    pub fn try_cancel_conversion(&mut self, preferred: ConvLayout) -> bool {
        if self.current_layout == preferred {
            self.pending_conversion = None;
            true
        } else {
            false
        }
    }
}
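
The benefit of deferring conversions is that a request followed by the opposite request costs nothing. The sketch below models only the bookkeeping: `Layout` and `LazyLayout` are hypothetical stand-ins that track layouts and count conversions instead of permuting real tensor data.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Layout {
    Nchw,
    Nhwc,
}

/// Tracks the current layout, a pending request, and how many real
/// conversions would have been executed.
pub struct LazyLayout {
    current: Layout,
    pending: Option<Layout>,
    pub conversions: usize,
}

impl LazyLayout {
    pub fn new(current: Layout) -> Self {
        Self { current, pending: None, conversions: 0 }
    }

    /// Record a requested layout; requesting the current layout clears
    /// any earlier pending request.
    pub fn request(&mut self, target: Layout) {
        self.pending = if target == self.current { None } else { Some(target) };
    }

    /// Apply the pending request, if any, and return the resulting layout.
    pub fn materialize(&mut self) -> Layout {
        if let Some(target) = self.pending.take() {
            // A real tensor would permute its data here
            self.current = target;
            self.conversions += 1;
        }
        self.current
    }
}
```

Requesting NHWC and then NCHW on an NCHW tensor materializes with zero conversions, while a single NHWC request followed by repeated `materialize` calls converts exactly once.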

5. Layout-Optimized Convolution

/// Convolution with automatic layout handling
pub fn conv2d_auto_layout<T: TensorElement>(
    input: &LayoutTensor<T, 4>,
    kernel: &LayoutTensor<T, 4>,
    stride: (usize, usize),
    padding: Padding,
) -> LayoutTensor<T, 4> {
    let policy = LayoutPolicy::default_cpu();
    
    // Check if conversion is needed
    let need_input_convert = input.layout() != policy.preferred_input;
    let need_kernel_convert = kernel.layout() != policy.preferred_kernel;
    
    // Estimate costs
    let conv_flops = estimate_conv_flops(input.shape(), kernel.shape(), stride);
    let convert_input_cost = if need_input_convert {
        input.len() as f64 * policy.conversion_cost_factor
    } else {
        0.0
    };
    let convert_kernel_cost = if need_kernel_convert {
        kernel.len() as f64 * policy.conversion_cost_factor
    } else {
        0.0
    };
    
    // Decide whether to convert or use slower non-preferred layout
    let use_preferred = conv_flops as f64 > convert_input_cost + convert_kernel_cost;
    
    if use_preferred {
        // Convert and use optimized kernel
        let input_converted = input.to_layout(policy.preferred_input);
        let kernel_converted = kernel.to_layout(policy.preferred_kernel);
        
        let output = conv2d_optimized(&input_converted, &kernel_converted, stride, padding);
        LayoutTensor::new(output, policy.preferred_output)
    } else {
        // Use current layout with generic kernel
        let output = conv2d_generic(input, kernel, stride, padding);
        LayoutTensor::new(output, input.layout())
    }
}
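
The convert-or-not decision depends on how `estimate_conv_flops` is defined, which this issue leaves open. As a rough sketch, assuming a dense convolution counted as 2 FLOPs per multiply-add, with `conv_out_dim`, `conv2d_flops`, and `should_convert` as hypothetical helpers:

```rust
/// Output extent of one spatial dimension (same formula as in section 6).
pub fn conv_out_dim(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input + 2 * pad - kernel) / stride + 1
}

/// FLOPs of a dense 2D convolution, 2 per multiply-add.
/// `input` is (N, C_in, H, W) and `kernel` is (C_out, C_in, K_h, K_w).
pub fn conv2d_flops(
    input: [usize; 4],
    kernel: [usize; 4],
    stride: (usize, usize),
    pad: (usize, usize),
) -> u64 {
    let out_h = conv_out_dim(input[2], kernel[2], stride.0, pad.0);
    let out_w = conv_out_dim(input[3], kernel[3], stride.1, pad.1);
    let macs = input[0] * kernel[0] * out_h * out_w * kernel[1] * kernel[2] * kernel[3];
    2 * macs as u64
}

/// Conversion pays off when the conv work exceeds the modeled transpose cost.
pub fn should_convert(flops: u64, elements_to_convert: usize, cost_factor: f64) -> bool {
    flops as f64 > elements_to_convert as f64 * cost_factor
}
```

For a 1x64x56x56 input and a 64x64x3x3 kernel with stride 1 and padding 1, this gives 231,211,008 FLOPs against 237,568 elements to convert, so with a factor of 0.1 the heuristic converts. With factors below 1 the comparison almost always favors conversion, which suggests the threshold may need to account for how much slower the generic kernel actually is.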

6. Im2col with Layout Support

/// Im2col that respects input layout
pub fn im2col_with_layout<T: TensorElement>(
    input: &LayoutTensor<T, 4>,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> Tensor<T> {
    let layout = input.layout();
    let shape = input.shape();
    
    let (batch, channels, height, width) = if layout.is_channels_first() {
        (shape[0], shape[1], shape[2], shape[3])
    } else {
        (shape[0], shape[3], shape[1], shape[2])
    };
    
    let out_h = (height + 2 * padding.0 - kernel_size.0) / stride.0 + 1;
    let out_w = (width + 2 * padding.1 - kernel_size.1) / stride.1 + 1;
    
    let col_size = channels * kernel_size.0 * kernel_size.1;
    let mut col = Tensor::zeros(&[batch, col_size, out_h * out_w]);
    
    // Layout-aware extraction
    if layout.is_channels_first() {
        im2col_nchw(input.data(), &mut col, kernel_size, stride, padding);
    } else {
        im2col_nhwc(input.data(), &mut col, kernel_size, stride, padding);
    }
    
    col
}
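
The `im2col_nchw` helper called above is not defined in this issue. A minimal single-image sketch over a flat `f32` slice, assuming zero padding and a `[C * K_h * K_w, out_h * out_w]` row-major output, could look like this (the signature here is illustrative and differs from the call above):

```rust
/// Im2col for one NCHW image stored as a flat slice of length c * h * w.
/// Returns (columns, rows, cols) where rows = c * k.0 * k.1 and
/// cols = out_h * out_w. Out-of-bounds (padded) positions stay zero.
pub fn im2col_nchw(
    input: &[f32],
    c: usize,
    h: usize,
    w: usize,
    k: (usize, usize),
    stride: (usize, usize),
    pad: (usize, usize),
) -> (Vec<f32>, usize, usize) {
    let out_h = (h + 2 * pad.0 - k.0) / stride.0 + 1;
    let out_w = (w + 2 * pad.1 - k.1) / stride.1 + 1;
    let rows = c * k.0 * k.1;
    let cols = out_h * out_w;
    let mut col = vec![0.0f32; rows * cols];

    for ch in 0..c {
        for ky in 0..k.0 {
            for kx in 0..k.1 {
                let row = (ch * k.0 + ky) * k.1 + kx;
                for oy in 0..out_h {
                    for ox in 0..out_w {
                        // Coordinates in the padded image
                        let py = oy * stride.0 + ky;
                        let px = ox * stride.1 + kx;
                        if py < pad.0 || px < pad.1 {
                            continue; // top or left padding
                        }
                        let (iy, ix) = (py - pad.0, px - pad.1);
                        if iy >= h || ix >= w {
                            continue; // bottom or right padding
                        }
                        col[row * cols + oy * out_w + ox] = input[(ch * h + iy) * w + ix];
                    }
                }
            }
        }
    }
    (col, rows, cols)
}
```

For a single-channel 3x3 input holding 1 through 9 with a 2x2 kernel, stride 1, and no padding, the result has 4 rows and 4 columns, and the first row is `[1, 2, 4, 5]` (the top-left element of each window). An NHWC variant would differ only in the input index, `(iy * w + ix) * c + ch`.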

Acceptance Criteria

  • ConvLayout with NCHW, NHWC, NCDHW variants
  • KernelLayout with OIHW, HWIO variants
  • ConvDimensionNumbers for full specification
  • LayoutTensor with explicit layout tracking
  • LayoutPolicy for different backends
  • Dynamic programming layout sequence optimizer
  • LazyLayoutTensor with deferred conversion
  • Im2col with layout support
  • Benchmarks for different layout scenarios

Expected Performance Impact

| Scenario | Fixed Layout | Auto Layout | Improvement |
|---|---|---|---|
| Single conv (NCHW native) | 10 ms | 10 ms | 0% |
| Single conv (NHWC on NCHW) | 15 ms | 12 ms | 20% |
| Conv sequence (5 ops) | 75 ms | 55 ms | 27% |
| Mixed precision (convert chain) | 100 ms | 60 ms | 40% |

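Expected improvement per the estimates above: 20-27% for non-native layouts and multi-op sequences, up to 40% for conversion chains, and no change for native-layout single convolutions.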

Labels

performance, convolution, layout, P1-high
