From 5265304cdf9b086bd7cd707fa30b642e913d7732 Mon Sep 17 00:00:00 2001 From: "Plyakhin, Yury" Date: Tue, 25 Feb 2025 00:53:03 +0000 Subject: [PATCH] Joint Matrix: fix Fill Checked and add MAD big shapes - Implemented Fill Checked built-in for N=64 (e.g. 32x64x16 shape) - Implemented MAD built-ins for big shapes for half type --- .../OpenCL/PreRelease/IBiF_matrix.cl | 316 +++++++++--------- 1 file changed, 165 insertions(+), 151 deletions(-) diff --git a/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl b/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl index 4f25973063a4..d2806c5f8df5 100644 --- a/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl +++ b/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl @@ -1440,7 +1440,9 @@ DEFINE_GET_COORD(Accumulator, , 32, 32, 8, 8, 1) /* experimental large slice support: */ -#define DEFINE_MAD_16x16x16_IMPL(a_type, b_type, a_suffix, b_suffix) \ +// MAD: + +#define DEFINE_MAD_LARGE_SLICE_16bit_AB_IMPL(a_type, b_type, a_suffix, b_suffix) \ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) { \ short16 a = *(short16 *)a_ptr; \ int8 b = *(int8 *)b_ptr; \ @@ -1463,156 +1465,157 @@ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type## __private int16 *dst = (__private int16 *)result; \ *dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7, \ res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7); \ +} \ +\ +INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x32x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \ + int8 a0 = *(int8 *)a_ptr; \ + int8 a1 = *(int8 *) (a_ptr + 1 * 16 * (sizeof (short))); \ + int8 a2 = *(int8 *) (a_ptr + 2 * 16 * (sizeof (short))); \ + int8 a3 = *(int8 *) (a_ptr + 3 * 16 * (sizeof (short))); \ +\ + 
int8 b0 = *(int8 *)b_ptr; \ + int8 b1 = *(int8 *) (b_ptr + 1 * 16 * (sizeof (short))); \ + int8 b2 = *(int8 *) (b_ptr + 2 * 16 * (sizeof (short))); \ + int8 b3 = *(int8 *) (b_ptr + 3 * 16 * (sizeof (short))); \ +\ + float8 c0 = *(float8 *) (c_ptr + 0 * 8 * (sizeof (int))); \ + float8 c1 = *(float8 *) (c_ptr + 4 * 8 * (sizeof (int))); \ + float8 c2 = *(float8 *) (c_ptr + 8 * 8 * (sizeof (int))); \ + float8 c3 = *(float8 *) (c_ptr + 12 * 8 * (sizeof (int))); \ + float8 c4 = *(float8 *) (c_ptr + 1 * 8 * (sizeof (int))); \ + float8 c5 = *(float8 *) (c_ptr + 5 * 8 * (sizeof (int))); \ + float8 c6 = *(float8 *) (c_ptr + 9 * 8 * (sizeof (int))); \ + float8 c7 = *(float8 *) (c_ptr + 13 * 8 * (sizeof (int))); \ + float8 c8 = *(float8 *) (c_ptr + 2 * 8 * (sizeof (int))); \ + float8 c9 = *(float8 *) (c_ptr + 6 * 8 * (sizeof (int))); \ + float8 c10 = *(float8 *) (c_ptr + 10 * 8 * (sizeof (int))); \ + float8 c11 = *(float8 *) (c_ptr + 14 * 8 * (sizeof (int))); \ + float8 c12 = *(float8 *) (c_ptr + 3 * 8 * (sizeof (int))); \ + float8 c13 = *(float8 *) (c_ptr + 7 * 8 * (sizeof (int))); \ + float8 c14 = *(float8 *) (c_ptr + 11 * 8 * (sizeof (int))); \ + float8 c15 = *(float8 *) (c_ptr + 15 * 8 * (sizeof (int))); \ +\ + *(float8 *) (d_ptr + 0 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c0, a0, b0); \ + *(float8 *) (d_ptr + 4 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c1, a0, b1); \ + *(float8 *) (d_ptr + 8 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c2, a0, b2); \ + *(float8 *) (d_ptr + 12 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c3, a0, b3); \ + *(float8 *) (d_ptr + 1 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c4, a1, b0); \ + *(float8 *) (d_ptr + 5 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c5, a1, b1); \ + *(float8 *) (d_ptr + 9 * 
8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c6, a1, b2); \ + *(float8 *) (d_ptr + 13 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c7, a1, b3); \ + *(float8 *) (d_ptr + 2 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c8, a2, b0); \ + *(float8 *) (d_ptr + 6 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c9, a2, b1); \ + *(float8 *) (d_ptr + 10 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c10, a2, b2); \ + *(float8 *) (d_ptr + 14 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c11, a2, b3); \ + *(float8 *) (d_ptr + 3 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c12, a3, b0); \ + *(float8 *) (d_ptr + 7 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c13, a3, b1); \ + *(float8 *) (d_ptr + 11 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c14, a3, b2); \ + *(float8 *) (d_ptr + 15 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c15, a3, b3); \ +} \ +\ +INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \ + short a = *(short *) a_ptr; \ +\ + int8 b0 = *(int8 *) b_ptr; \ + int8 b1 = *(int8 *)(b_ptr + 1 * 16 * (sizeof (short))); \ + int8 b2 = *(int8 *)(b_ptr + 2 * 16 * (sizeof (short))); \ + int8 b3 = *(int8 *)(b_ptr + 3 * 16 * (sizeof (short))); \ +\ + float c0 = *(float *) c_ptr; \ + float c1 = *(float *) (c_ptr + 1 * (sizeof (int))); \ + float c2 = *(float *) (c_ptr + 2 * (sizeof (int))); \ + float c3 = *(float *) (c_ptr + 3 * (sizeof (int))); \ +\ + float d0 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c0, a, b0); \ + float d1 = 
__builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c1, a, b1); \ + float d2 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c2, a, b2); \ + float d3 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c3, a, b3); \ +\ + __private int4 *dst = (__private int4 *)d_ptr; \ + *dst = (int4)(as_int(d0), as_int(d1), as_int(d2), as_int(d3)); \ +} \ +\ +INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \ + __private char *a0 = a_ptr; \ + __private char *a1 = a_ptr + 16 * (sizeof (short)); \ +\ + __private char *b0 = b_ptr; \ + __private char *b1 = b_ptr + 1 * 16 * (sizeof (short)); \ + __private char *b2 = b_ptr + 2 * 16 * (sizeof (short)); \ + __private char *b3 = b_ptr + 3 * 16 * (sizeof (short)); \ +\ + __private char *c0 = c_ptr + 0 * 16 * (sizeof (int)); \ + __private char *c1 = c_ptr + 2 * 16 * (sizeof (int)); \ + __private char *c2 = c_ptr + 4 * 16 * (sizeof (int)); \ + __private char *c3 = c_ptr + 6 * 16 * (sizeof (int)); \ + __private char *c4 = c_ptr + 1 * 16 * (sizeof (int)); \ + __private char *c5 = c_ptr + 3 * 16 * (sizeof (int)); \ + __private char *c6 = c_ptr + 5 * 16 * (sizeof (int)); \ + __private char *c7 = c_ptr + 7 * 16 * (sizeof (int)); \ +\ + __private char *d0 = d_ptr + 0 * 16 * (sizeof (int)); \ + __private char *d1 = d_ptr + 2 * 16 * (sizeof (int)); \ + __private char *d2 = d_ptr + 4 * 16 * (sizeof (int)); \ + __private char *d3 = d_ptr + 6 * 16 * (sizeof (int)); \ + __private char *d4 = d_ptr + 1 * 16 * (sizeof (int)); \ + __private char *d5 = d_ptr + 3 * 16 * (sizeof (int)); \ + __private char *d6 = d_ptr + 5 * 16 * (sizeof (int)); \ + __private char *d7 = d_ptr + 7 * 16 * (sizeof (int)); \ +\ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b0, c0, d0); \ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b1, c1, 
d1); \ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b2, c2, d2); \ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b3, c3, d3); \ +\ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b0, c4, d4); \ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b1, c5, d5); \ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b2, c6, d6); \ + __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b3, c7, d7); \ +} \ +\ +INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x32_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \ + short8 a[8]; \ + int8 b[8]; \ + for (int i = 0; i < 8; i++) { \ + a[i] = *(short8 *)(a_ptr + i * 8 * (sizeof (short))); \ + b[i] = *(int8 *)(b_ptr + i * 8 * (sizeof (int))); \ + } \ +\ + float8 c[16]; \ + for (int i = 0; i < 16; i++) \ + c[i] = *(float8 *)(c_ptr + i * 8 * (sizeof (int))); \ +\ +_Pragma("unroll") /* TODO: investigate, why not unrolling the loop causes wrong code generated*/ \ + for (int i = 0; i < 4; i++) { \ + for (int j = 0; j < 4; j++) { \ + float8 d = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(c[i + 4*j], a[i], b[2*j]); \ + *(float8 *)(d_ptr + (i + 4*j) * 8 * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(d, a[i + 4], b[2*j + 1]); \ + } \ + } \ +} \ +\ +INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x32_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \ + short a0 = *(short *)(a_ptr + 0 * (sizeof (short))); \ + short a1 = *(short *)(a_ptr + 1 * (sizeof (short))); \ +\ + int8 b[8]; \ + for (int i = 0; i < 8; i++) \ + b[i] = *(int8 *)(b_ptr + i * 8 * (sizeof (int))); \ +\ + float c[4]; \ + for (int i = 0; i < 4; i++) \ + c[i] = *(float *)(c_ptr + i * 
(sizeof (int))); \ +\ + for (int i = 0; i < 4; i++) { \ + float d = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c[i], a0, b[2 * i]); \ + *(float *)(d_ptr + i * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(d, a1, b[2 * i + 1]); \ + } \ } -DEFINE_MAD_16x16x16_IMPL(bf16, bf16, bf, bf) -DEFINE_MAD_16x16x16_IMPL(fp16, fp16, hf, hf) - -// Splitting a 32x32x16 MAD operation into sixteen of 8x8x8 MAD operations -INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x32x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { - int8 a0 = *(int8 *)a_ptr; - int8 a1 = *(int8 *) (a_ptr + 1 * 16 * (sizeof (short))); - int8 a2 = *(int8 *) (a_ptr + 2 * 16 * (sizeof (short))); - int8 a3 = *(int8 *) (a_ptr + 3 * 16 * (sizeof (short))); - - int8 b0 = *(int8 *)b_ptr; - int8 b1 = *(int8 *) (b_ptr + 1 * 16 * (sizeof (short))); - int8 b2 = *(int8 *) (b_ptr + 2 * 16 * (sizeof (short))); - int8 b3 = *(int8 *) (b_ptr + 3 * 16 * (sizeof (short))); - - float8 c0 = *(float8 *) (c_ptr + 0 * 8 * (sizeof (int))); - float8 c1 = *(float8 *) (c_ptr + 4 * 8 * (sizeof (int))); - float8 c2 = *(float8 *) (c_ptr + 8 * 8 * (sizeof (int))); - float8 c3 = *(float8 *) (c_ptr + 12 * 8 * (sizeof (int))); - float8 c4 = *(float8 *) (c_ptr + 1 * 8 * (sizeof (int))); - float8 c5 = *(float8 *) (c_ptr + 5 * 8 * (sizeof (int))); - float8 c6 = *(float8 *) (c_ptr + 9 * 8 * (sizeof (int))); - float8 c7 = *(float8 *) (c_ptr + 13 * 8 * (sizeof (int))); - float8 c8 = *(float8 *) (c_ptr + 2 * 8 * (sizeof (int))); - float8 c9 = *(float8 *) (c_ptr + 6 * 8 * (sizeof (int))); - float8 c10 = *(float8 *) (c_ptr + 10 * 8 * (sizeof (int))); - float8 c11 = *(float8 *) (c_ptr + 14 * 8 * (sizeof (int))); - float8 c12 = *(float8 *) (c_ptr + 3 * 8 * (sizeof (int))); - float8 c13 = *(float8 *) (c_ptr + 7 * 8 * (sizeof (int))); - float8 c14 = *(float8 *) (c_ptr + 11 * 8 * (sizeof (int))); - float8 c15 = *(float8 *) (c_ptr + 15 
* 8 * (sizeof (int))); - - *(float8 *) (d_ptr + 0 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c0, a0, b0); - *(float8 *) (d_ptr + 4 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c1, a0, b1); - *(float8 *) (d_ptr + 8 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c2, a0, b2); - *(float8 *) (d_ptr + 12 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c3, a0, b3); - *(float8 *) (d_ptr + 1 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c4, a1, b0); - *(float8 *) (d_ptr + 5 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c5, a1, b1); - *(float8 *) (d_ptr + 9 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c6, a1, b2); - *(float8 *) (d_ptr + 13 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c7, a1, b3); - *(float8 *) (d_ptr + 2 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c8, a2, b0); - *(float8 *) (d_ptr + 6 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c9, a2, b1); - *(float8 *) (d_ptr + 10 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c10, a2, b2); - *(float8 *) (d_ptr + 14 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c11, a2, b3); - *(float8 *) (d_ptr + 3 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c12, a3, b0); - *(float8 *) (d_ptr + 7 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c13, a3, b1); - *(float8 *) (d_ptr + 11 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c14, a3, b2); - *(float8 *) (d_ptr + 15 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_bf_bf_8_8(c15, a3, b3); -} - -INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { - short a = *(short *) a_ptr; - - int8 b0 = *(int8 *) b_ptr; - int8 b1 = *(int8 *)(b_ptr + 1 * 16 * (sizeof (short))); - int8 b2 = *(int8 
*)(b_ptr + 2 * 16 * (sizeof (short))); - int8 b3 = *(int8 *)(b_ptr + 3 * 16 * (sizeof (short))); - - float c0 = *(float *) c_ptr; - float c1 = *(float *) (c_ptr + 1 * (sizeof (int))); - float c2 = *(float *) (c_ptr + 2 * (sizeof (int))); - float c3 = *(float *) (c_ptr + 3 * (sizeof (int))); - - float d0 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c0, a, b0); - float d1 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c1, a, b1); - float d2 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c2, a, b2); - float d3 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c3, a, b3); - - __private int4 *dst = (__private int4 *)d_ptr; - *dst = (int4)(as_int(d0), as_int(d1), as_int(d2), as_int(d3)); -} - -INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { - __private char *a0 = a_ptr; - __private char *a1 = a_ptr + 16 * (sizeof (short)); - - __private char *b0 = b_ptr; - __private char *b1 = b_ptr + 1 * 16 * (sizeof (short)); - __private char *b2 = b_ptr + 2 * 16 * (sizeof (short)); - __private char *b3 = b_ptr + 3 * 16 * (sizeof (short)); - - __private char *c0 = c_ptr + 0 * 16 * (sizeof (int)); - __private char *c1 = c_ptr + 2 * 16 * (sizeof (int)); - __private char *c2 = c_ptr + 4 * 16 * (sizeof (int)); - __private char *c3 = c_ptr + 6 * 16 * (sizeof (int)); - __private char *c4 = c_ptr + 1 * 16 * (sizeof (int)); - __private char *c5 = c_ptr + 3 * 16 * (sizeof (int)); - __private char *c6 = c_ptr + 5 * 16 * (sizeof (int)); - __private char *c7 = c_ptr + 7 * 16 * (sizeof (int)); - - __private char *d0 = d_ptr + 0 * 16 * (sizeof (int)); - __private char *d1 = d_ptr + 2 * 16 * (sizeof (int)); - __private char *d2 = d_ptr + 4 * 16 * (sizeof (int)); - __private char *d3 = d_ptr + 6 * 16 * (sizeof (int)); - __private char *d4 = d_ptr + 1 * 16 * (sizeof (int)); - __private char *d5 = d_ptr + 3 * 16 * (sizeof (int)); - __private char *d6 = d_ptr + 5 * 16 * (sizeof (int)); - 
__private char *d7 = d_ptr + 7 * 16 * (sizeof (int)); - - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a0, b0, c0, d0); - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a0, b1, c1, d1); - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a0, b2, c2, d2); - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a0, b3, c3, d3); - - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b0, c4, d4); - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b1, c5, d5); - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b2, c6, d6); - __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b3, c7, d7); -} - -INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x32_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { - short8 a[8]; - int8 b[8]; - for (int i = 0; i < 8; i++) { - a[i] = *(short8 *)(a_ptr + i * 8 * (sizeof (short))); - b[i] = *(int8 *)(b_ptr + i * 8 * (sizeof (int))); - } - - float8 c[16]; - for (int i = 0; i < 16; i++) - c[i] = *(float8 *)(c_ptr + i * 8 * (sizeof (int))); - -#pragma unroll // TODO: investigate, why not unrolling the loop causes wrong code generated - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - float8 d = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c[i + 4*j], a[i], b[2*j]); - *(float8 *)(d_ptr + (i + 4*j) * 8 * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(d, a[i + 4], b[2*j + 1]); - } - } -} - -INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x32_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { - short a0 = *(short *)(a_ptr + 0 * (sizeof (short))); - short a1 = *(short *)(a_ptr + 1 * (sizeof (short))); +DEFINE_MAD_LARGE_SLICE_16bit_AB_IMPL(bf16, bf16, bf, bf) +DEFINE_MAD_LARGE_SLICE_16bit_AB_IMPL(fp16, fp16, hf, hf) - int8 b[8]; - for (int i = 0; i < 8; i++) - b[i] = *(int8 
*)(b_ptr + i * 8 * (sizeof (int))); - - float c[4]; - for (int i = 0; i < 4; i++) - c[i] = *(float *)(c_ptr + i * (sizeof (int))); - - for (int i = 0; i < 4; i++) { - float d = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c[i], a0, b[2 * i]); - *(float *)(d_ptr + i * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(d, a1, b[2 * i + 1]); - } -} +// LOADS: /* PackedA load i16 for big shapes */ DEFINE_LOAD(PackedA_RowMajor, , short, int, 32, 16, ROW_MAJOR, , 32) @@ -2310,10 +2313,20 @@ DEFINE_STORE_CHECKED_LARGE_1(Accumulator_RowMajor, 1, 64) int sg_size = get_sub_group_size(); \ int pack_factor = contrib_bitwidth / elem_bitwidth; \ int col_sg_ratio = (sg_size * pack_factor) / K; \ + int M = (WI_rows * sg_size * pack_factor) / K; \ __private element_type *wi_contrib = (__private element_type *) dst; \ for (int i = 0; i < WI_rows; i++) { \ - element_type fill_value = slid % K < width - x && i * col_sg_ratio < height - y ? value : 0; \ - wi_contrib[i] = fill_value; \ + int row, col; \ + if (col_sg_ratio != 0) { \ + /* sg_size * pack_factor >= matrix width */ \ + row = slid / K + i * col_sg_ratio; \ + col = slid % K; \ + } else { \ + /* sg_size * pack_factor < matrix width */ \ + row = i % M; \ + col = (i / M) * sg_size + slid; \ + } \ + wi_contrib[i] = col < width - x && row < height - y ? value : 0; \ } \ } @@ -2337,10 +2350,11 @@ DEFINE_STORE_CHECKED_LARGE_1(Accumulator_RowMajor, 1, 64) DEFINE_FILLCHECKED_K(element_type, contrib_type, 8) \ DEFINE_FILLCHECKED_K(element_type, contrib_type, 16) \ DEFINE_FILLCHECKED_K(element_type, contrib_type, 32) \ + DEFINE_FILLCHECKED_K(element_type, contrib_type, 64) #define DEFINE_FILLCHECKED_GROUP(element_type) \ DEFINE_FILLCHECKED_CONTRIB(element_type, short) \ - DEFINE_FILLCHECKED_CONTRIB(element_type, int) \ + DEFINE_FILLCHECKED_CONTRIB(element_type, int) DEFINE_FILLCHECKED_GROUP(char) DEFINE_FILLCHECKED_GROUP(short)