Skip to content

Commit

Permalink
Joint Matrix: fix Fill Checked and add MAD big shapes
Browse files Browse the repository at this point in the history
- Implemented Fill Checked built-in for N=64 (e.g. 32x64x16 shape)
- Implemented MAD built-ins for big shapes for half type
  • Loading branch information
YuriPlyakhin authored and igcbot committed Feb 25, 2025
1 parent 6773ce3 commit 5265304
Showing 1 changed file with 165 additions and 151 deletions.
316 changes: 165 additions & 151 deletions IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,9 @@ DEFINE_GET_COORD(Accumulator, , 32, 32, 8, 8, 1)

/* experimental large slice support: */

#define DEFINE_MAD_16x16x16_IMPL(a_type, b_type, a_suffix, b_suffix) \
// MAD:

#define DEFINE_MAD_LARGE_SLICE_16bit_AB_IMPL(a_type, b_type, a_suffix, b_suffix) \
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) { \
short16 a = *(short16 *)a_ptr; \
int8 b = *(int8 *)b_ptr; \
Expand All @@ -1463,156 +1465,157 @@ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##
__private int16 *dst = (__private int16 *)result; \
*dst = (int16)(res0.s0, res0.s1, res0.s2, res0.s3, res0.s4, res0.s5, res0.s6, res0.s7, \
res1.s0, res1.s1, res1.s2, res1.s3, res1.s4, res1.s5, res1.s6, res1.s7); \
} \
\
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x32x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \
int8 a0 = *(int8 *)a_ptr; \
int8 a1 = *(int8 *) (a_ptr + 1 * 16 * (sizeof (short))); \
int8 a2 = *(int8 *) (a_ptr + 2 * 16 * (sizeof (short))); \
int8 a3 = *(int8 *) (a_ptr + 3 * 16 * (sizeof (short))); \
\
int8 b0 = *(int8 *)b_ptr; \
int8 b1 = *(int8 *) (b_ptr + 1 * 16 * (sizeof (short))); \
int8 b2 = *(int8 *) (b_ptr + 2 * 16 * (sizeof (short))); \
int8 b3 = *(int8 *) (b_ptr + 3 * 16 * (sizeof (short))); \
\
float8 c0 = *(float8 *) (c_ptr + 0 * 8 * (sizeof (int))); \
float8 c1 = *(float8 *) (c_ptr + 4 * 8 * (sizeof (int))); \
float8 c2 = *(float8 *) (c_ptr + 8 * 8 * (sizeof (int))); \
float8 c3 = *(float8 *) (c_ptr + 12 * 8 * (sizeof (int))); \
float8 c4 = *(float8 *) (c_ptr + 1 * 8 * (sizeof (int))); \
float8 c5 = *(float8 *) (c_ptr + 5 * 8 * (sizeof (int))); \
float8 c6 = *(float8 *) (c_ptr + 9 * 8 * (sizeof (int))); \
float8 c7 = *(float8 *) (c_ptr + 13 * 8 * (sizeof (int))); \
float8 c8 = *(float8 *) (c_ptr + 2 * 8 * (sizeof (int))); \
float8 c9 = *(float8 *) (c_ptr + 6 * 8 * (sizeof (int))); \
float8 c10 = *(float8 *) (c_ptr + 10 * 8 * (sizeof (int))); \
float8 c11 = *(float8 *) (c_ptr + 14 * 8 * (sizeof (int))); \
float8 c12 = *(float8 *) (c_ptr + 3 * 8 * (sizeof (int))); \
float8 c13 = *(float8 *) (c_ptr + 7 * 8 * (sizeof (int))); \
float8 c14 = *(float8 *) (c_ptr + 11 * 8 * (sizeof (int))); \
float8 c15 = *(float8 *) (c_ptr + 15 * 8 * (sizeof (int))); \
\
*(float8 *) (d_ptr + 0 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c0, a0, b0); \
*(float8 *) (d_ptr + 4 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c1, a0, b1); \
*(float8 *) (d_ptr + 8 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c2, a0, b2); \
*(float8 *) (d_ptr + 12 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c3, a0, b3); \
*(float8 *) (d_ptr + 1 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c4, a1, b0); \
*(float8 *) (d_ptr + 5 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c5, a1, b1); \
*(float8 *) (d_ptr + 9 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c6, a1, b2); \
*(float8 *) (d_ptr + 13 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c7, a1, b3); \
*(float8 *) (d_ptr + 2 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c8, a2, b0); \
*(float8 *) (d_ptr + 6 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c9, a2, b1); \
*(float8 *) (d_ptr + 10 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c10, a2, b2); \
*(float8 *) (d_ptr + 14 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c11, a2, b3); \
*(float8 *) (d_ptr + 3 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c12, a3, b0); \
*(float8 *) (d_ptr + 7 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c13, a3, b1); \
*(float8 *) (d_ptr + 11 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c14, a3, b2); \
*(float8 *) (d_ptr + 15 * 8 * (sizeof (float))) = __builtin_IB_sub_group_fdpas_##a_suffix##_##b_suffix##_8_8(c15, a3, b3); \
} \
\
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \
short a = *(short *) a_ptr; \
\
int8 b0 = *(int8 *) b_ptr; \
int8 b1 = *(int8 *)(b_ptr + 1 * 16 * (sizeof (short))); \
int8 b2 = *(int8 *)(b_ptr + 2 * 16 * (sizeof (short))); \
int8 b3 = *(int8 *)(b_ptr + 3 * 16 * (sizeof (short))); \
\
float c0 = *(float *) c_ptr; \
float c1 = *(float *) (c_ptr + 1 * (sizeof (int))); \
float c2 = *(float *) (c_ptr + 2 * (sizeof (int))); \
float c3 = *(float *) (c_ptr + 3 * (sizeof (int))); \
\
float d0 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c0, a, b0); \
float d1 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c1, a, b1); \
float d2 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c2, a, b2); \
float d3 = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c3, a, b3); \
\
__private int4 *dst = (__private int4 *)d_ptr; \
*dst = (int4)(as_int(d0), as_int(d1), as_int(d2), as_int(d3)); \
} \
\
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \
__private char *a0 = a_ptr; \
__private char *a1 = a_ptr + 16 * (sizeof (short)); \
\
__private char *b0 = b_ptr; \
__private char *b1 = b_ptr + 1 * 16 * (sizeof (short)); \
__private char *b2 = b_ptr + 2 * 16 * (sizeof (short)); \
__private char *b3 = b_ptr + 3 * 16 * (sizeof (short)); \
\
__private char *c0 = c_ptr + 0 * 16 * (sizeof (int)); \
__private char *c1 = c_ptr + 2 * 16 * (sizeof (int)); \
__private char *c2 = c_ptr + 4 * 16 * (sizeof (int)); \
__private char *c3 = c_ptr + 6 * 16 * (sizeof (int)); \
__private char *c4 = c_ptr + 1 * 16 * (sizeof (int)); \
__private char *c5 = c_ptr + 3 * 16 * (sizeof (int)); \
__private char *c6 = c_ptr + 5 * 16 * (sizeof (int)); \
__private char *c7 = c_ptr + 7 * 16 * (sizeof (int)); \
\
__private char *d0 = d_ptr + 0 * 16 * (sizeof (int)); \
__private char *d1 = d_ptr + 2 * 16 * (sizeof (int)); \
__private char *d2 = d_ptr + 4 * 16 * (sizeof (int)); \
__private char *d3 = d_ptr + 6 * 16 * (sizeof (int)); \
__private char *d4 = d_ptr + 1 * 16 * (sizeof (int)); \
__private char *d5 = d_ptr + 3 * 16 * (sizeof (int)); \
__private char *d6 = d_ptr + 5 * 16 * (sizeof (int)); \
__private char *d7 = d_ptr + 7 * 16 * (sizeof (int)); \
\
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b0, c0, d0); \
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b1, c1, d1); \
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b2, c2, d2); \
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a0, b3, c3, d3); \
\
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b0, c4, d4); \
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b1, c5, d5); \
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b2, c6, d6); \
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_##a_type##_##b_type##_fp32(a1, b3, c7, d7); \
} \
\
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x32_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \
short8 a[8]; \
int8 b[8]; \
for (int i = 0; i < 8; i++) { \
a[i] = *(short8 *)(a_ptr + i * 8 * (sizeof (short))); \
b[i] = *(int8 *)(b_ptr + i * 8 * (sizeof (int))); \
} \
\
float8 c[16]; \
for (int i = 0; i < 16; i++) \
c[i] = *(float8 *)(c_ptr + i * 8 * (sizeof (int))); \
\
_Pragma("unroll") /* TODO: investigate, why not unrolling the loop causes wrong code generated*/ \
for (int i = 0; i < 4; i++) { \
for (int j = 0; j < 4; j++) { \
float8 d = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(c[i + 4*j], a[i], b[2*j]); \
*(float8 *)(d_ptr + (i + 4*j) * 8 * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_8(d, a[i + 4], b[2*j + 1]); \
} \
} \
} \
\
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x32_##a_type##_##b_type##_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) { \
short a0 = *(short *)(a_ptr + 0 * (sizeof (short))); \
short a1 = *(short *)(a_ptr + 1 * (sizeof (short))); \
\
int8 b[8]; \
for (int i = 0; i < 8; i++) \
b[i] = *(int8 *)(b_ptr + i * 8 * (sizeof (int))); \
\
float c[4]; \
for (int i = 0; i < 4; i++) \
c[i] = *(float *)(c_ptr + i * (sizeof (int))); \
\
for (int i = 0; i < 4; i++) { \
float d = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(c[i], a0, b[2 * i]); \
*(float *)(d_ptr + i * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_##a_suffix##_##b_suffix##_8_1(d, a1, b[2 * i + 1]); \
} \
}

DEFINE_MAD_16x16x16_IMPL(bf16, bf16, bf, bf)
DEFINE_MAD_16x16x16_IMPL(fp16, fp16, hf, hf)

// 32x32x16 bf16 x bf16 -> fp32 MAD, expressed as sixteen calls to
// __builtin_IB_sub_group_fdpas_bf_bf_8_8, one per float8 accumulator tile.
//
// a_ptr and b_ptr are each read as four consecutive int8 vectors, c_ptr as
// sixteen consecutive float8 vectors; sixteen float8 results are written to
// d_ptr. The tile combining A vector r with B vector n sits at float8 index
// (r + 4 * n) in both c_ptr and d_ptr.
// Every accumulator tile is read before the first result is written.
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x32x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
    // Typed views of the raw slices: one int8 covers 16 shorts (32 bytes),
    // one float8 covers 8 ints (32 bytes).
    int8 *a_vec = (int8 *)a_ptr;
    int8 *b_vec = (int8 *)b_ptr;
    float8 *acc = (float8 *)c_ptr;
    float8 *out = (float8 *)d_ptr;

    const int8 a_r0 = a_vec[0];
    const int8 a_r1 = a_vec[1];
    const int8 a_r2 = a_vec[2];
    const int8 a_r3 = a_vec[3];

    const int8 b_n0 = b_vec[0];
    const int8 b_n1 = b_vec[1];
    const int8 b_n2 = b_vec[2];
    const int8 b_n3 = b_vec[3];

    // Accumulator tiles, named by (A vector index, B vector index).
    const float8 acc_r0n0 = acc[0];
    const float8 acc_r0n1 = acc[4];
    const float8 acc_r0n2 = acc[8];
    const float8 acc_r0n3 = acc[12];
    const float8 acc_r1n0 = acc[1];
    const float8 acc_r1n1 = acc[5];
    const float8 acc_r1n2 = acc[9];
    const float8 acc_r1n3 = acc[13];
    const float8 acc_r2n0 = acc[2];
    const float8 acc_r2n1 = acc[6];
    const float8 acc_r2n2 = acc[10];
    const float8 acc_r2n3 = acc[14];
    const float8 acc_r3n0 = acc[3];
    const float8 acc_r3n1 = acc[7];
    const float8 acc_r3n2 = acc[11];
    const float8 acc_r3n3 = acc[15];

    // A vector 0 against every B vector.
    out[0]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r0n0, a_r0, b_n0);
    out[4]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r0n1, a_r0, b_n1);
    out[8]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r0n2, a_r0, b_n2);
    out[12] = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r0n3, a_r0, b_n3);

    // A vector 1 against every B vector.
    out[1]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r1n0, a_r1, b_n0);
    out[5]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r1n1, a_r1, b_n1);
    out[9]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r1n2, a_r1, b_n2);
    out[13] = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r1n3, a_r1, b_n3);

    // A vector 2 against every B vector.
    out[2]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r2n0, a_r2, b_n0);
    out[6]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r2n1, a_r2, b_n1);
    out[10] = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r2n2, a_r2, b_n2);
    out[14] = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r2n3, a_r2, b_n3);

    // A vector 3 against every B vector.
    out[3]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r3n0, a_r3, b_n0);
    out[7]  = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r3n1, a_r3, b_n1);
    out[11] = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r3n2, a_r3, b_n2);
    out[15] = __builtin_IB_sub_group_fdpas_bf_bf_8_8(acc_r3n3, a_r3, b_n3);
}

// 1x64x16 bf16 x bf16 -> fp32 MAD built from four calls to
// __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1.
//
// Reads one short from a_ptr, four consecutive int8 vectors from b_ptr and
// four consecutive floats from c_ptr. The four float results are written to
// d_ptr with a single int4 store that preserves their bit patterns.
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
    const short a_elem = *(short *)a_ptr;

    // Typed views: one int8 covers 16 shorts (32 bytes) of B,
    // one float covers one int (4 bytes) of the accumulator.
    int8 *b_vec = (int8 *)b_ptr;
    float *acc = (float *)c_ptr;

    const int8 b_n0 = b_vec[0];
    const int8 b_n1 = b_vec[1];
    const int8 b_n2 = b_vec[2];
    const int8 b_n3 = b_vec[3];

    const float acc0 = acc[0];
    const float acc1 = acc[1];
    const float acc2 = acc[2];
    const float acc3 = acc[3];

    // The same A element is combined with each of the four B vectors.
    const float res0 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(acc0, a_elem, b_n0);
    const float res1 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(acc1, a_elem, b_n1);
    const float res2 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(acc2, a_elem, b_n2);
    const float res3 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(acc3, a_elem, b_n3);

    // Results are stored as raw 32-bit patterns in one vector write.
    *(__private int4 *)d_ptr = (int4)(as_int(res0), as_int(res1), as_int(res2), as_int(res3));
}

// 32x64x16 bf16 x bf16 -> fp32 MAD, delegated to eight calls of the 16x16x16
// bf16 MAD built-in.
//
// The A slice is addressed as two chunks of 16 shorts and the B slice as four
// chunks of 16 shorts. The accumulator and destination slices are addressed as
// eight chunks of 16 ints. The call pairing A chunk r with B chunk n uses
// accumulator/destination chunk (2 * n + r).
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
    __private char *a_chunk0 = a_ptr;
    __private char *a_chunk1 = a_ptr + 16 * sizeof(short);

    // A chunk 0 (r = 0) against B chunks 0..3.
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk0,
        b_ptr,
        c_ptr + (2 * 0 + 0) * 16 * sizeof(int),
        d_ptr + (2 * 0 + 0) * 16 * sizeof(int));
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk0,
        b_ptr + 1 * 16 * sizeof(short),
        c_ptr + (2 * 1 + 0) * 16 * sizeof(int),
        d_ptr + (2 * 1 + 0) * 16 * sizeof(int));
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk0,
        b_ptr + 2 * 16 * sizeof(short),
        c_ptr + (2 * 2 + 0) * 16 * sizeof(int),
        d_ptr + (2 * 2 + 0) * 16 * sizeof(int));
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk0,
        b_ptr + 3 * 16 * sizeof(short),
        c_ptr + (2 * 3 + 0) * 16 * sizeof(int),
        d_ptr + (2 * 3 + 0) * 16 * sizeof(int));

    // A chunk 1 (r = 1) against B chunks 0..3.
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk1,
        b_ptr,
        c_ptr + (2 * 0 + 1) * 16 * sizeof(int),
        d_ptr + (2 * 0 + 1) * 16 * sizeof(int));
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk1,
        b_ptr + 1 * 16 * sizeof(short),
        c_ptr + (2 * 1 + 1) * 16 * sizeof(int),
        d_ptr + (2 * 1 + 1) * 16 * sizeof(int));
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk1,
        b_ptr + 2 * 16 * sizeof(short),
        c_ptr + (2 * 2 + 1) * 16 * sizeof(int),
        d_ptr + (2 * 2 + 1) * 16 * sizeof(int));
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(
        a_chunk1,
        b_ptr + 3 * 16 * sizeof(short),
        c_ptr + (2 * 3 + 1) * 16 * sizeof(int),
        d_ptr + (2 * 3 + 1) * 16 * sizeof(int));
}

// 32x64x32 bf16 x bf16 -> fp32 MAD.
//
// Reads eight short8 vectors from a_ptr, eight int8 vectors from b_ptr and
// sixteen float8 accumulator vectors from c_ptr, then writes sixteen float8
// results to d_ptr. Each result tile is produced by two chained calls to
// __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8: the first accumulates onto the
// input tile, the second accumulates onto the first call's result.
// NOTE(review): the two chained calls presumably cover the two halves of the
// K = 32 dimension -- confirm against the slice layout used by the callers.
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x32_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
// Copy the A and B slices into local arrays: a[i] starts 8 shorts (16 bytes)
// after a[i - 1], b[i] starts 8 ints (32 bytes) after b[i - 1].
short8 a[8];
int8 b[8];
for (int i = 0; i < 8; i++) {
a[i] = *(short8 *)(a_ptr + i * 8 * (sizeof (short)));
b[i] = *(int8 *)(b_ptr + i * 8 * (sizeof (int)));
}

// Copy all sixteen accumulator tiles before any result is written.
float8 c[16];
for (int i = 0; i < 16; i++)
c[i] = *(float8 *)(c_ptr + i * 8 * (sizeof (int)));

#pragma unroll // TODO: investigate why not unrolling this loop causes wrong code to be generated
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Tile index is (i + 4*j). The first call pairs a[i] with b[2*j]; the second
// call pairs a[i + 4] with b[2*j + 1] and accumulates onto the first result.
float8 d = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c[i + 4*j], a[i], b[2*j]);
*(float8 *)(d_ptr + (i + 4*j) * 8 * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(d, a[i + 4], b[2*j + 1]);
}
}
}

INLINE void __builtin_spriv_OpJointMatrixMadINTEL_1x64x32_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
short a0 = *(short *)(a_ptr + 0 * (sizeof (short)));
short a1 = *(short *)(a_ptr + 1 * (sizeof (short)));
DEFINE_MAD_LARGE_SLICE_16bit_AB_IMPL(bf16, bf16, bf, bf)
DEFINE_MAD_LARGE_SLICE_16bit_AB_IMPL(fp16, fp16, hf, hf)

int8 b[8];
for (int i = 0; i < 8; i++)
b[i] = *(int8 *)(b_ptr + i * 8 * (sizeof (int)));

float c[4];
for (int i = 0; i < 4; i++)
c[i] = *(float *)(c_ptr + i * (sizeof (int)));

for (int i = 0; i < 4; i++) {
float d = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(c[i], a0, b[2 * i]);
*(float *)(d_ptr + i * (sizeof (float))) = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_1(d, a1, b[2 * i + 1]);
}
}
// LOADS:

/* PackedA load i16 for big shapes */
DEFINE_LOAD(PackedA_RowMajor, , short, int, 32, 16, ROW_MAJOR, , 32)
Expand Down Expand Up @@ -2310,10 +2313,20 @@ DEFINE_STORE_CHECKED_LARGE_1(Accumulator_RowMajor, 1, 64)
int sg_size = get_sub_group_size(); \
int pack_factor = contrib_bitwidth / elem_bitwidth; \
int col_sg_ratio = (sg_size * pack_factor) / K; \
int M = (WI_rows * sg_size * pack_factor) / K; \
__private element_type *wi_contrib = (__private element_type *) dst; \
for (int i = 0; i < WI_rows; i++) { \
element_type fill_value = slid % K < width - x && i * col_sg_ratio < height - y ? value : 0; \
wi_contrib[i] = fill_value; \
int row, col; \
if (col_sg_ratio != 0) { \
/* sg_size * pack_factor >= matrix width */ \
row = slid / K + i * col_sg_ratio; \
col = slid % K; \
} else { \
/* sg_size * pack_factor < matrix width */ \
row = i % M; \
col = (i / M) * sg_size + slid; \
} \
wi_contrib[i] = col < width - x && row < height - y ? value : 0; \
} \
}

Expand All @@ -2337,10 +2350,11 @@ DEFINE_STORE_CHECKED_LARGE_1(Accumulator_RowMajor, 1, 64)
DEFINE_FILLCHECKED_K(element_type, contrib_type, 8) \
DEFINE_FILLCHECKED_K(element_type, contrib_type, 16) \
DEFINE_FILLCHECKED_K(element_type, contrib_type, 32) \
DEFINE_FILLCHECKED_K(element_type, contrib_type, 64)

#define DEFINE_FILLCHECKED_GROUP(element_type) \
DEFINE_FILLCHECKED_CONTRIB(element_type, short) \
DEFINE_FILLCHECKED_CONTRIB(element_type, int) \
DEFINE_FILLCHECKED_CONTRIB(element_type, int)

DEFINE_FILLCHECKED_GROUP(char)
DEFINE_FILLCHECKED_GROUP(short)
Expand Down

0 comments on commit 5265304

Please sign in to comment.