Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arm64/Sve: Implement Math's DotProduct* APIs #102218

Merged
merged 8 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions src/coreclr/jit/codegenarm64test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4998,8 +4998,9 @@ void CodeGen::genArm64EmitterUnitTestsSve()
INS_SCALABLE_OPTS_PREDICATE_MERGE); /* MOV <Pd>.B, <Pg>/M, <Pn>.B */

// IF_SVE_CZ_4A_L
theEmitter->emitIns_R_R(INS_sve_mov, EA_SCALABLE, REG_P0, REG_P15, INS_OPTS_SCALABLE_B); /* MOV <Pd>.B, <Pn>.B
*/
theEmitter->emitIns_Mov(INS_sve_mov, EA_SCALABLE, REG_P0, REG_P15, /* canSkip */ false,
INS_OPTS_SCALABLE_B); /* MOV <Pd>.B, <Pn>.B
*/

// IF_SVE_DA_4A
theEmitter->emitIns_R_R_R_R(INS_sve_brkpa, EA_SCALABLE, REG_P0, REG_P1, REG_P10, REG_P15,
Expand Down Expand Up @@ -6994,23 +6995,23 @@ void CodeGen::genArm64EmitterUnitTestsSve()

// IF_SVE_EY_3A
theEmitter->emitIns_R_R_R_I(INS_sve_sdot, EA_SCALABLE, REG_V9, REG_V10, REG_V4, 0,
INS_OPTS_SCALABLE_B); // SDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]
INS_OPTS_SCALABLE_S); // SDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_sdot, EA_SCALABLE, REG_V11, REG_V12, REG_V5, 1,
INS_OPTS_SCALABLE_B); // SDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]
INS_OPTS_SCALABLE_S); // SDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_udot, EA_SCALABLE, REG_V13, REG_V14, REG_V6, 2,
INS_OPTS_SCALABLE_B); // UDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]
INS_OPTS_SCALABLE_S); // UDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_udot, EA_SCALABLE, REG_V15, REG_V16, REG_V7, 3,
INS_OPTS_SCALABLE_B); // UDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]
INS_OPTS_SCALABLE_S); // UDOT <Zda>.S, <Zn>.B, <Zm>.B[<imm>]

// IF_SVE_EY_3B
theEmitter->emitIns_R_R_R_I(INS_sve_sdot, EA_SCALABLE, REG_V0, REG_V1, REG_V0,
0); // SDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_sdot, EA_SCALABLE, REG_V2, REG_V3, REG_V5,
1); // SDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_udot, EA_SCALABLE, REG_V4, REG_V5, REG_V10,
0); // UDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_udot, EA_SCALABLE, REG_V6, REG_V7, REG_V15,
1); // UDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_sdot, EA_SCALABLE, REG_V0, REG_V1, REG_V0, 0,
INS_OPTS_SCALABLE_D); // SDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_sdot, EA_SCALABLE, REG_V2, REG_V3, REG_V5, 1,
INS_OPTS_SCALABLE_D); // SDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_udot, EA_SCALABLE, REG_V4, REG_V5, REG_V10, 0,
INS_OPTS_SCALABLE_D); // UDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]
theEmitter->emitIns_R_R_R_I(INS_sve_udot, EA_SCALABLE, REG_V6, REG_V7, REG_V15, 1,
INS_OPTS_SCALABLE_D); // UDOT <Zda>.D, <Zn>.H, <Zm>.H[<imm>]

// IF_SVE_EZ_3A
theEmitter->emitIns_R_R_R_I(INS_sve_sudot, EA_SCALABLE, REG_V17, REG_V18, REG_V0, 0,
Expand Down Expand Up @@ -8892,15 +8893,15 @@ void CodeGen::genArm64EmitterUnitTestsSve()
// DUP <Zd>.<T>, <R><n|SP>
theEmitter->emitIns_R_R(INS_sve_dup, EA_8BYTE, REG_V4, REG_SP, INS_OPTS_SCALABLE_D);
// MOV <Zd>.<T>, <R><n|SP>
theEmitter->emitIns_R_R(INS_sve_mov, EA_4BYTE, REG_V4, REG_R2, INS_OPTS_SCALABLE_B);
theEmitter->emitIns_Mov(INS_sve_mov, EA_4BYTE, REG_V4, REG_R2, /* canSkip */ false, INS_OPTS_SCALABLE_B);
// MOV <Zd>.<T>, <R><n|SP>
theEmitter->emitIns_R_R(INS_sve_mov, EA_4BYTE, REG_V4, REG_R2, INS_OPTS_SCALABLE_H);
theEmitter->emitIns_Mov(INS_sve_mov, EA_4BYTE, REG_V4, REG_R2, /* canSkip */ false, INS_OPTS_SCALABLE_H);
// MOV <Zd>.<T>, <R><n|SP>
theEmitter->emitIns_R_R(INS_sve_mov, EA_4BYTE, REG_V1, REG_R3, INS_OPTS_SCALABLE_S);
theEmitter->emitIns_Mov(INS_sve_mov, EA_4BYTE, REG_V1, REG_R3, /* canSkip */ false, INS_OPTS_SCALABLE_S);
// MOV <Zd>.<T>, <R><n|SP>
theEmitter->emitIns_R_R(INS_sve_mov, EA_8BYTE, REG_V5, REG_SP, INS_OPTS_SCALABLE_D);
theEmitter->emitIns_Mov(INS_sve_mov, EA_8BYTE, REG_V5, REG_SP, /* canSkip */ false, INS_OPTS_SCALABLE_D);
// MOV <Zd>.<T>, <R><n|SP>
theEmitter->emitIns_R_R(INS_sve_mov, EA_8BYTE, REG_V2, REG_R9, INS_OPTS_SCALABLE_D);
theEmitter->emitIns_Mov(INS_sve_mov, EA_8BYTE, REG_V2, REG_R9, /* canSkip */ false, INS_OPTS_SCALABLE_D);

// IF_SVE_BJ_2A
// FEXPA <Zd>.<T>, <Zn>.<T>
Expand Down
10 changes: 10 additions & 0 deletions src/coreclr/jit/emitarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4275,6 +4275,16 @@ void emitter::emitIns_Mov(
}
fmt = IF_SVE_AU_3A;
}
else if (isVectorRegister(dstReg) && isGeneralRegisterOrSP(srcReg))
{
assert(insOptsScalable(opt));
if (IsRedundantMov(ins, size, dstReg, srcReg, canSkip))
{
return;
}
srcReg = encodingSPtoZR(srcReg);
fmt = IF_SVE_CB_2A;
}
else
{
unreached();
Expand Down
13 changes: 6 additions & 7 deletions src/coreclr/jit/emitarm64sve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4494,21 +4494,20 @@ void emitter::emitInsSve_R_R_R_I(instruction ins,
assert(isVectorRegister(reg2)); // nnnnn
assert(isLowVectorRegister(reg3)); // mmmm

if (opt == INS_OPTS_SCALABLE_B)
if (opt == INS_OPTS_SCALABLE_H)
{
assert((REG_V0 <= reg3) && (reg3 <= REG_V7)); // mmm
assert(isValidUimm<2>(imm)); // ii
fmt = IF_SVE_EY_3A;
assert(isValidUimm<2>(imm)); // ii
fmt = IF_SVE_EG_3A;
}
else if (opt == INS_OPTS_SCALABLE_H)
else if (opt == INS_OPTS_SCALABLE_S)
{
assert((REG_V0 <= reg3) && (reg3 <= REG_V7)); // mmm
assert(isValidUimm<2>(imm)); // ii
fmt = IF_SVE_EG_3A;
fmt = IF_SVE_EY_3A;
}
else
{
assert(insOptsNone(opt));
assert(opt == INS_OPTS_SCALABLE_D);
assert(isValidUimm<1>(imm)); // i
opt = INS_OPTS_SCALABLE_H;
fmt = IF_SVE_EY_3B;
Expand Down
11 changes: 11 additions & 0 deletions src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic is a FusedMultiplyAdd intrinsic
HW_Flag_FmaIntrinsic = 0x20000000,

#if defined(TARGET_ARM64)
// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low vector register.
HW_Flag_LowVectorOperation = 0x4000000,
#endif

HW_Flag_CanBenefitFromConstantProp = 0x80000000,
};

Expand Down Expand Up @@ -891,6 +896,12 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_LowMaskedOperation) != 0;
}

static bool IsLowVectorOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_LowVectorOperation) != 0;
}

static bool IsOptionalEmbeddedMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
Expand Down
11 changes: 11 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,17 @@ void Compiler::getHWIntrinsicImmTypes(NamedIntrinsic intrinsic,
// The second source operand of sdot, udot instructions is an indexed 32-bit element.
indexedElementBaseType = simdBaseType;
}

if (intrinsic == NI_Sve_DotProductBySelectedScalar)
{
assert(((simdBaseType == TYP_INT) && (indexedElementBaseType == TYP_BYTE)) ||
((simdBaseType == TYP_UINT) && (indexedElementBaseType == TYP_UBYTE)) ||
((simdBaseType == TYP_LONG) && (indexedElementBaseType == TYP_SHORT)) ||
((simdBaseType == TYP_ULONG) && (indexedElementBaseType == TYP_USHORT)));

// The second source operand of sdot, udot instructions is an indexed 32-bit element.
indexedElementBaseType = simdBaseType;
}
}

assert(indexedElementBaseType == simdBaseType);
Expand Down
8 changes: 5 additions & 3 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask32Bit,
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask64Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask8Bit, -1, 2, false, {INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, Divide, -1, 2, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdiv, INS_sve_udiv, INS_sve_sdiv, INS_sve_udiv, INS_sve_fdiv, INS_sve_fdiv}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, DotProduct, -1, 3, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdot, INS_sve_udot, INS_sve_sdot, INS_sve_udot, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, DotProductBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdot, INS_sve_udot, INS_sve_sdot, INS_sve_udot, INS_invalid, INS_invalid}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplyAdd, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplyAddBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic)
HARDWARE_INTRINSIC(Sve, FusedMultiplyAddBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplyAddNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmla, INS_sve_fnmla}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtract, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmls, INS_sve_fnmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, LoadVector, -1, 2, true, {INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1h, INS_sve_ld1h, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, LoadVectorByteZeroExtendToInt16, -1, 2, false, {INS_invalid, INS_invalid, INS_sve_ld1b, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down Expand Up @@ -101,7 +103,7 @@ HARDWARE_INTRINSIC(Sve, MinNumber,
HARDWARE_INTRINSIC(Sve, MinNumberAcross, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fminnmv, INS_sve_fminnmv}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation)
HARDWARE_INTRINSIC(Sve, Multiply, -1, 2, true, {INS_sve_mul, INS_sve_mul, INS_sve_mul, INS_sve_mul, INS_sve_mul, INS_sve_mul, INS_sve_mul, INS_sve_mul, INS_sve_fmul, INS_sve_fmul}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_OptionalEmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, MultiplyAdd, -1, -1, false, {INS_sve_mla, INS_sve_mla, INS_sve_mla, INS_sve_mla, INS_sve_mla, INS_sve_mla, INS_sve_mla, INS_sve_mla, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, MultiplyBySelectedScalar, -1, 3, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmul, INS_sve_fmul}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand)
HARDWARE_INTRINSIC(Sve, MultiplyBySelectedScalar, -1, 3, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmul, INS_sve_fmul}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, MultiplyExtended, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmulx, INS_sve_fmulx}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, MultiplySubtract, -1, -1, false, {INS_sve_mls, INS_sve_mls, INS_sve_mls, INS_sve_mls, INS_sve_mls, INS_sve_mls, INS_sve_mls, INS_sve_mls, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, Negate, -1, -1, false, {INS_sve_neg, INS_invalid, INS_sve_neg, INS_invalid, INS_sve_neg, INS_invalid, INS_sve_neg, INS_invalid, INS_sve_fneg, INS_sve_fneg}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/lsra.h
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,7 @@ class LinearScan : public LinearScanInterface
return nextConsecutiveRefPositionMap;
}
FORCEINLINE RefPosition* getNextConsecutiveRefPosition(RefPosition* refPosition);
void getLowVectorOperandAndCandidates(HWIntrinsic intrin, size_t* operandNum, regMaskTP* candidates);
#endif

#ifdef DEBUG
Expand Down
Loading
Loading