[Arm64] SIMD dot product #37169

@echesakov

class Dp
{
  // int32x2_t vdot_s32 (int32x2_t r, int8x8_t a, int8x8_t b)
  //   A32: VSDOT.S8 Dd, Dn, Dm
  //   A64: SDOT Vd.2S, Vn.8B, Vm.8B
  public static Vector64<int> DotProduct(Vector64<int> addend, Vector64<sbyte> left, Vector64<sbyte> right);

  // uint32x2_t vdot_u32 (uint32x2_t r, uint8x8_t a, uint8x8_t b)
  //   A32: VUDOT.U8 Dd, Dn, Dm
  //   A64: UDOT Vd.2S, Vn.8B, Vm.8B
  public static Vector64<uint> DotProduct(Vector64<uint> addend, Vector64<byte> left, Vector64<byte> right);

  // int32x4_t vdotq_s32 (int32x4_t r, int8x16_t a, int8x16_t b)
  //   A32: VSDOT.S8 Qd, Qn, Qm
  //   A64: SDOT Vd.4S, Vn.16B, Vm.16B
  public static Vector128<int> DotProduct(Vector128<int> addend, Vector128<sbyte> left, Vector128<sbyte> right);

  // uint32x4_t vdotq_u32 (uint32x4_t r, uint8x16_t a, uint8x16_t b)
  //   A32: VUDOT.U8 Qd, Qn, Qm
  //   A64: UDOT Vd.4S, Vn.16B, Vm.16B
  public static Vector128<uint> DotProduct(Vector128<uint> addend, Vector128<byte> left, Vector128<byte> right);

  // int32x2_t vdot_lane_s32 (int32x2_t r, int8x8_t a, int8x8_t b, const int lane)
  //   A32: VSDOT.S8 Dd, Dn, Dm[lane]
  //   A64: SDOT Vd.2S, Vn.8B, Vm.4B[lane]
  public static Vector64<int> DotProductBySelectedQuadruplet(Vector64<int> addend, Vector64<sbyte> left, Vector64<sbyte> right, byte rightScaledIndex);

  // int32x2_t vdot_laneq_s32 (int32x2_t r, int8x8_t a, int8x16_t b, const int lane)
  //   A32: VSDOT.S8 Dd, Dn, Dm[lane]
  //   A64: SDOT Vd.2S, Vn.8B, Vm.4B[lane]
  public static Vector64<int> DotProductBySelectedQuadruplet(Vector64<int> addend, Vector64<sbyte> left, Vector128<sbyte> right, byte rightScaledIndex);

  // uint32x2_t vdot_lane_u32 (uint32x2_t r, uint8x8_t a, uint8x8_t b, const int lane)
  //   A32: VUDOT.U8 Dd, Dn, Dm[lane]
  //   A64: UDOT Vd.2S, Vn.8B, Vm.4B[lane]
  public static Vector64<uint> DotProductBySelectedQuadruplet(Vector64<uint> addend, Vector64<byte> left, Vector64<byte> right, byte rightScaledIndex);

  // uint32x2_t vdot_laneq_u32 (uint32x2_t r, uint8x8_t a, uint8x16_t b, const int lane)
  //   A32: VUDOT.U8 Dd, Dn, Dm[lane]
  //   A64: UDOT Vd.2S, Vn.8B, Vm.4B[lane]
  public static Vector64<uint> DotProductBySelectedQuadruplet(Vector64<uint> addend, Vector64<byte> left, Vector128<byte> right, byte rightScaledIndex);

  // int32x4_t vdotq_laneq_s32 (int32x4_t r, int8x16_t a, int8x16_t b, const int lane)
  //   A32: VSDOT.S8 Qd, Qn, Dm[lane]
  //   A64: SDOT Vd.4S, Vn.16B, Vm.4B[lane]
  public static Vector128<int> DotProductBySelectedQuadruplet(Vector128<int> addend, Vector128<sbyte> left, Vector128<sbyte> right, byte rightScaledIndex);

  // int32x4_t vdotq_lane_s32 (int32x4_t r, int8x16_t a, int8x8_t b, const int lane)
  //   A32: VSDOT.S8 Qd, Qn, Dm[lane]
  //   A64: SDOT Vd.4S, Vn.16B, Vm.4B[lane]
  public static Vector128<int> DotProductBySelectedQuadruplet(Vector128<int> addend, Vector128<sbyte> left, Vector64<sbyte> right, byte rightScaledIndex);

  // uint32x4_t vdotq_laneq_u32 (uint32x4_t r, uint8x16_t a, uint8x16_t b, const int lane)
  //   A32: VUDOT.U8 Qd, Qn, Dm[lane]
  //   A64: UDOT Vd.4S, Vn.16B, Vm.4B[lane]
  public static Vector128<uint> DotProductBySelectedQuadruplet(Vector128<uint> addend, Vector128<byte> left, Vector128<byte> right, byte rightScaledIndex);

  // uint32x4_t vdotq_lane_u32 (uint32x4_t r, uint8x16_t a, uint8x8_t b, const int lane)
  //   A32: VUDOT.U8 Qd, Qn, Dm[lane]
  //   A64: UDOT Vd.4S, Vn.16B, Vm.4B[lane]
  public static Vector128<uint> DotProductBySelectedQuadruplet(Vector128<uint> addend, Vector128<byte> left, Vector64<byte> right, byte rightScaledIndex);
}
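
For readers less familiar with SDOT, here is a minimal scalar sketch of what the non-indexed `DotProduct` overloads are expected to compute, as I read the instruction. `DotProductModel` is a hypothetical helper for illustration only, not part of the proposal, and it uses plain arrays instead of `Vector64`/`Vector128` so it runs on any hardware.

```csharp
using System;

// Hypothetical scalar reference model; not part of the proposed API.
public static class DotProductModel
{
    // Models SDOT Vd.2S/4S, Vn.8B/16B, Vm.8B/16B:
    // result[i] = addend[i] + sum over j in [0, 3] of left[4*i + j] * right[4*i + j]
    public static int[] DotProduct(int[] addend, sbyte[] left, sbyte[] right)
    {
        if (left.Length != 4 * addend.Length || right.Length != left.Length)
            throw new ArgumentException("left and right must hold four bytes per 32-bit lane.");

        var result = new int[addend.Length];
        for (int i = 0; i < addend.Length; i++)
        {
            int sum = addend[i];
            for (int j = 0; j < 4; j++)
                sum += left[4 * i + j] * right[4 * i + j]; // sbyte operands are promoted to int
            result[i] = sum;
        }
        return result;
    }
}
```

The unsigned overloads would follow the same shape with `byte` inputs and `uint` accumulators.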

This covers only ARMv8.2-DotProd, not the "Mixed sign integer dot product" instructions (USDOT and SUDOT). My understanding is that the latter should be under a different ISA class (e.g. MatrixMultiplication). Is my assumption correct?

@TamarChristinaArm @tannergooding @CarolEidt PTAL

The indexed forms of these are unusual. Here is the description from the Arm Architecture Reference Manual:

This instruction performs the dot product of the four 8-bit elements in each 32-bit element of the first source register with the four 8-bit elements of an indexed 32-bit element in the second source register, accumulating the result into the corresponding 32-bit element of the destination register.
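
To make the quoted description concrete, here is a scalar sketch of the indexed form, assuming the scaled-index convention proposed in this issue (the index selects a group of four consecutive bytes in `right`). `IndexedDotProductModel` is a hypothetical name used only for illustration, and arrays stand in for the vector types.

```csharp
using System;

// Hypothetical scalar reference model; not part of the proposed API.
public static class IndexedDotProductModel
{
    // Every 32-bit lane of the result uses the SAME four bytes of right,
    // namely right[4*rightScaledIndex .. 4*rightScaledIndex + 3].
    public static int[] DotProductBySelectedQuadruplet(
        int[] addend, sbyte[] left, sbyte[] right, byte rightScaledIndex)
    {
        if (left.Length != 4 * addend.Length)
            throw new ArgumentException("left must hold four bytes per 32-bit lane.");

        int baseByte = 4 * rightScaledIndex;
        if (baseByte + 4 > right.Length)
            throw new ArgumentOutOfRangeException(nameof(rightScaledIndex));

        var result = new int[addend.Length];
        for (int i = 0; i < addend.Length; i++)
        {
            int sum = addend[i];
            for (int j = 0; j < 4; j++)
                sum += left[4 * i + j] * right[baseByte + j];
            result[i] = sum;
        }
        return result;
    }
}
```

For example, with `left = {1,2,3,4,5,6,7,8}`, `right = {1,0,0,0, 0,0,0,1}` and a zero addend, index 0 picks out the first byte of each quadruplet of `left`, and index 1 picks out the last.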

Another option for expressing this in the language is to use a vector of ints for the second source register. Then rightIndex does not have to be a scaled index.

Vector128<uint> DotProductBySelectedQuadruplet(Vector128<uint> addend, Vector128<byte> left, Vector64<int> right, byte rightIndex)

But then, which base type should the second operand use for signed vs. unsigned dot product: int or uint?

I didn't like the second variant, so I decided to make the index scaled, i.e. the JIT will multiply the value by 4, and we expect the range of values for rightScaledIndex to be [0, 1] for (Vector64<sbyte>/Vector64<byte> right) and [0, 3] for (Vector128<sbyte>/Vector128<byte> right).
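
As a sketch of the scaling rule above: the valid range and the byte offset both follow from the size of the right operand (8 bytes for Vector64, 16 bytes for Vector128). `ScaledIndex` and its members are hypothetical names for illustration; the actual JIT validation may look quite different.

```csharp
using System;

// Hypothetical helper illustrating the proposed scaled-index convention.
public static class ScaledIndex
{
    // Largest valid rightScaledIndex: 1 for an 8-byte vector, 3 for a 16-byte vector.
    public static int MaxIndex(int rightVectorBytes) => rightVectorBytes / 4 - 1;

    // Byte offset of the selected quadruplet within the right operand.
    public static int ToByteOffset(byte rightScaledIndex, int rightVectorBytes)
    {
        if (rightVectorBytes != 8 && rightVectorBytes != 16)
            throw new ArgumentException("Expected an 8-byte or 16-byte vector.", nameof(rightVectorBytes));
        if (rightScaledIndex > MaxIndex(rightVectorBytes))
            throw new ArgumentOutOfRangeException(nameof(rightScaledIndex));
        return 4 * rightScaledIndex;
    }
}
```

So `rightScaledIndex = 3` on a Vector128 operand selects bytes 12 through 15, while any index above 1 is rejected for a Vector64 operand.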

Perhaps someone can suggest better names for the methods or the index operand?
