Neo: Float16 packed math (#4354)

This commit is contained in:
Marcin Mikołajczyk 2026-05-05 00:21:20 +02:00 committed by GitHub
parent 26eaa3e3af
commit 34b35b526e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 288 additions and 0 deletions

View File

@ -768,6 +768,134 @@ T Translator::GetSrc64(const InstOperand& operand) {
// Explicit instantiations of GetSrc64 for the IR::U64 and IR::F64 result types.
template IR::U64 Translator::GetSrc64<IR::U64>(const InstOperand&);
template IR::F64 Translator::GetSrc64<IR::F64>(const InstOperand&);
/// Reads a packed-math source operand and returns its two lanes as a (low, high) pair.
///
/// Register-like operands (SGPR, VGPR, VCC_LO) are unpacked with Unpack2x16; op_sel picks the
/// component that feeds the low lane and op_sel_hi the component that feeds the high lane.
/// Inline and literal constants are reinterpreted bit-for-bit as one 32-bit immediate; a lane
/// whose selector is set receives zero instead of that immediate.
/// Afterwards the neg / neg_hi input modifiers negate the low / high lane respectively.
///
/// NOTE(review): in the IR::U32 instantiation the extracted components are still wrapped as
/// IR::F32 before the BitCast, although the unpack format is Uint - confirm this is valid.
template <typename T>
pk_type<T> Translator::GetSrcPk(const InstOperand& operand) {
    constexpr bool is_float = std::is_same_v<T, IR::F32>;
    // Host type that carries the raw bits of an immediate for this instantiation.
    using RawImm = std::conditional_t<is_float, float, u32>;
    constexpr auto unpack_format =
        is_float ? AmdGpu::NumberFormat::Float : AmdGpu::NumberFormat::Uint;

    const auto& selectors = operand.op_sel;
    const auto& modifiers = operand.input_modifier;

    // Builds both lanes from an immediate; a set selector replaces the lane with zero.
    const auto make_imm = [&](auto raw) -> pk_type<T> {
        const auto imm = ir.Imm32(std::bit_cast<RawImm>(raw));
        const auto lo = selectors.op_sel ? ir.Imm32(RawImm{}) : imm;
        const auto hi = selectors.op_sel_hi ? ir.Imm32(RawImm{}) : imm;
        return {lo, hi};
    };

    // Converts an extracted component to the requested lane type.
    const auto to_lane = [&](const auto& component) -> T {
        if constexpr (is_float) {
            return component;
        } else {
            return ir.BitCast<IR::U32>(component);
        }
    };

    // Unpacks a 32-bit register value and picks one component per lane.
    const auto unpack = [&](auto packed) -> pk_type<T> {
        const auto components = ir.Unpack2x16(unpack_format, packed);
        // Low lane is fully built before the high lane so instructions keep their order.
        const T lo = to_lane(IR::F32{ir.CompositeExtract(components, selectors.op_sel)});
        const T hi = to_lane(IR::F32{ir.CompositeExtract(components, selectors.op_sel_hi)});
        return {lo, hi};
    };

    pk_type<T> lanes{};
    switch (operand.field) {
    // Register-like sources.
    case OperandField::ScalarGPR:
        lanes = unpack(ir.GetScalarReg<IR::U32>(IR::ScalarReg(operand.code)));
        break;
    case OperandField::VectorGPR:
        lanes = unpack(ir.GetVectorReg<IR::U32>(IR::VectorReg(operand.code)));
        break;
    case OperandField::VccLo:
        lanes = unpack(ir.GetVccLo());
        break;

    // Integer inline constants and literals.
    case OperandField::ConstZero:
        lanes = make_imm(0U);
        break;
    case OperandField::SignedConstIntPos:
        lanes = make_imm(operand.code - SignedConstIntPosMin + 1);
        break;
    case OperandField::SignedConstIntNeg:
        lanes = make_imm(-s32(operand.code) + SignedConstIntNegMin - 1);
        break;
    case OperandField::LiteralConst:
        lanes = make_imm(operand.code);
        break;

    // Floating-point inline constants.
    case OperandField::ConstFloatPos_1_0:
        lanes = make_imm(1.f);
        break;
    case OperandField::ConstFloatPos_0_5:
        lanes = make_imm(0.5f);
        break;
    case OperandField::ConstFloatPos_2_0:
        lanes = make_imm(2.0f);
        break;
    case OperandField::ConstFloatPos_4_0:
        lanes = make_imm(4.0f);
        break;
    case OperandField::ConstFloatNeg_0_5:
        lanes = make_imm(-0.5f);
        break;
    case OperandField::ConstFloatNeg_1_0:
        lanes = make_imm(-1.0f);
        break;
    case OperandField::ConstFloatNeg_2_0:
        lanes = make_imm(-2.0f);
        break;
    case OperandField::ConstFloatNeg_4_0:
        lanes = make_imm(-4.0f);
        break;
    case OperandField::Inv2Pi:
        lanes = make_imm(1.0f / (2.0f * std::numbers::pi_v<float>));
        break;

    default:
        UNREACHABLE_MSG("unexpected operand: {}", std::to_underlying(operand.field));
    }

    // Per-lane negation: floating-point negate for IR::F32, integer negate otherwise.
    if (modifiers.neg) {
        if constexpr (is_float) {
            lanes.first = ir.FPNeg(lanes.first);
        } else {
            lanes.first = ir.INeg(lanes.first);
        }
    }
    if (modifiers.neg_hi) {
        if constexpr (is_float) {
            lanes.second = ir.FPNeg(lanes.second);
        } else {
            lanes.second = ir.INeg(lanes.second);
        }
    }
    return lanes;
}
// Explicit instantiations of GetSrcPk for the integer (IR::U32) and float (IR::F32) lane types.
template pk_type<IR::U32> Translator::GetSrcPk<IR::U32>(const InstOperand&);
template pk_type<IR::F32> Translator::GetSrcPk<IR::F32>(const InstOperand&);
void Translator::SetDst1(const InstOperand& operand, const IR::U1& value) {
switch (operand.field) {
case OperandField::VccLo:

View File

@ -54,6 +54,9 @@ enum class NegateMode : u32 {
Result,
};
// Pair of two values of the same type T; GetSrcPk returns the two lanes of a packed operand
// in this form.
template <typename T>
using pk_type = std::pair<T, T>;
static constexpr size_t MaxInterpVgpr = 16;
class Translator {
@ -286,6 +289,11 @@ public:
void V_OR3_B32(const GcnInst& inst);
// VOP3P
// Packed F16 arithmetic handlers: fused multiply-add, add, multiply, min and max.
void V_PK_FMA_F16(const GcnInst& inst);
void V_PK_ADD_F16(const GcnInst& inst);
void V_PK_MUL_F16(const GcnInst& inst);
void V_PK_MIN_F16(const GcnInst& inst);
void V_PK_MAX_F16(const GcnInst& inst);
void V_MAD_MIX_F32(const GcnInst& inst);
void V_MAD_MIXLO_F16(const GcnInst& inst);
void V_MAD_MIXHI_F16(const GcnInst& inst);
@ -334,6 +342,8 @@ private:
template <typename T = IR::U64>
[[nodiscard]] T GetSrc64(const InstOperand& operand);
[[nodiscard]] IR::F32 GetSrcMix(const InstOperand& operand);
// Returns the operand's two packed lanes as a pk_type<T> pair; T defaults to IR::U32.
template <typename T = IR::U32>
[[nodiscard]] pk_type<T> GetSrcPk(const InstOperand& operand);
void SetDst1(const InstOperand& operand, const IR::U1& value);
void SetDst(const InstOperand& operand, const IR::U32F32& value);
template <bool is_signed = false>

View File

@ -478,6 +478,16 @@ void Translator::EmitVectorAlu(const GcnInst& inst) {
return;
// VOP3P
case Opcode::V_PK_FMA_F16:
return V_PK_FMA_F16(inst);
case Opcode::V_PK_ADD_F16:
return V_PK_ADD_F16(inst);
case Opcode::V_PK_MUL_F16:
return V_PK_MUL_F16(inst);
case Opcode::V_PK_MIN_F16:
return V_PK_MIN_F16(inst);
case Opcode::V_PK_MAX_F16:
return V_PK_MAX_F16(inst);
case Opcode::V_MAD_MIX_F32:
return V_MAD_MIX_F32(inst);
case Opcode::V_MAD_MIXLO_F16:
@ -1634,6 +1644,62 @@ void Translator::V_ADD3_U32(const GcnInst& inst) {
SetDst(inst.dst[0], ir.IAdd(src0, ir.IAdd(src1, src2)));
}
/// V_PK_FMA_F16: computes FPFma(src0, src1, src2) separately for the low and the high lane
/// and stores both results packed with Pack2x16 (Float format) into the destination.
void Translator::V_PK_FMA_F16(const GcnInst& inst) {
    const auto [a_lo, a_hi] = GetSrcPk<IR::F32>(inst.src[0]);
    const auto [b_lo, b_hi] = GetSrcPk<IR::F32>(inst.src[1]);
    const auto [c_lo, c_hi] = GetSrcPk<IR::F32>(inst.src[2]);

    const auto fma_lo = ir.FPFma(a_lo, b_lo, c_lo);
    const auto fma_hi = ir.FPFma(a_hi, b_hi, c_hi);

    const auto lanes = ir.CompositeConstruct(fma_lo, fma_hi);
    SetDst(inst.dst[0], ir.Pack2x16(AmdGpu::NumberFormat::Float, lanes));
}
/// V_PK_ADD_F16: adds src0 and src1 separately for the low and the high lane and stores
/// both sums packed with Pack2x16 (Float format) into the destination.
void Translator::V_PK_ADD_F16(const GcnInst& inst) {
    const auto [lhs_lo, lhs_hi] = GetSrcPk<IR::F32>(inst.src[0]);
    const auto [rhs_lo, rhs_hi] = GetSrcPk<IR::F32>(inst.src[1]);

    const auto sum_lo = ir.FPAdd(lhs_lo, rhs_lo);
    const auto sum_hi = ir.FPAdd(lhs_hi, rhs_hi);

    const auto lanes = ir.CompositeConstruct(sum_lo, sum_hi);
    SetDst(inst.dst[0], ir.Pack2x16(AmdGpu::NumberFormat::Float, lanes));
}
/// V_PK_MUL_F16: multiplies src0 and src1 separately for the low and the high lane and
/// stores both products packed with Pack2x16 (Float format) into the destination.
void Translator::V_PK_MUL_F16(const GcnInst& inst) {
    const auto [lhs_lo, lhs_hi] = GetSrcPk<IR::F32>(inst.src[0]);
    const auto [rhs_lo, rhs_hi] = GetSrcPk<IR::F32>(inst.src[1]);

    const auto product_lo = ir.FPMul(lhs_lo, rhs_lo);
    const auto product_hi = ir.FPMul(lhs_hi, rhs_hi);

    const auto lanes = ir.CompositeConstruct(product_lo, product_hi);
    SetDst(inst.dst[0], ir.Pack2x16(AmdGpu::NumberFormat::Float, lanes));
}
/// V_PK_MIN_F16: takes FPMin(src0, src1) separately for the low and the high lane and
/// stores both results packed with Pack2x16 (Float format) into the destination.
void Translator::V_PK_MIN_F16(const GcnInst& inst) {
    const auto [lhs_lo, lhs_hi] = GetSrcPk<IR::F32>(inst.src[0]);
    const auto [rhs_lo, rhs_hi] = GetSrcPk<IR::F32>(inst.src[1]);

    const auto min_lo = ir.FPMin(lhs_lo, rhs_lo);
    const auto min_hi = ir.FPMin(lhs_hi, rhs_hi);

    const auto lanes = ir.CompositeConstruct(min_lo, min_hi);
    SetDst(inst.dst[0], ir.Pack2x16(AmdGpu::NumberFormat::Float, lanes));
}
/// V_PK_MAX_F16: takes FPMax(src0, src1) separately for the low and the high lane and
/// stores both results packed with Pack2x16 (Float format) into the destination.
void Translator::V_PK_MAX_F16(const GcnInst& inst) {
    const auto [lhs_lo, lhs_hi] = GetSrcPk<IR::F32>(inst.src[0]);
    const auto [rhs_lo, rhs_hi] = GetSrcPk<IR::F32>(inst.src[1]);

    const auto max_lo = ir.FPMax(lhs_lo, rhs_lo);
    const auto max_hi = ir.FPMax(lhs_hi, rhs_hi);

    const auto lanes = ir.CompositeConstruct(max_lo, max_hi);
    SetDst(inst.dst[0], ir.Pack2x16(AmdGpu::NumberFormat::Float, lanes));
}
void Translator::V_LSHL_OR_B32(const GcnInst& inst) {
const auto src0 = GetSrc<IR::U32>(inst.src[0]);
const auto src1 = GetSrc<IR::U32>(inst.src[1]);

View File

@ -424,3 +424,87 @@ TEST_F(GcnTest, ashrrev_i16_1) {
EXPECT_TRUE(result.has_value());
EXPECT_EQ(*result, 0x1234F111);
}
// Register operands: (1, 2) + (3, 4) is expected to give (4, 6) lane by lane.
TEST_F(GcnTest, pk_add_f16_1) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).Get());
    auto result = runner->run<F16x2>(
        spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, (F16x2{half(4.0f), half(6.0f)}));
}
// Inline constants 0 + 1/(2*pi) with default selectors: expects 0x3118 (about 0.159 as a
// half float) in the low 16 bits and zero in the high 16 bits.
TEST_F(GcnTest, pk_add_f16_2) {
    auto runner = gcn_test::Runner::instance().value();
    auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0,
                      SOperand9::ConstInv2Pi)
                    .Get();
    auto spirv = TranslateToSpirv(inst);
    auto result = runner->run<u32>(spirv, 0U);
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, 0x00003118);
}
// Inline constants 0 + 1/(2*pi) with SetOpSel({0,1,1}): the whole packed result is
// expected to be zero.
TEST_F(GcnTest, pk_add_f16_3) {
    auto runner = gcn_test::Runner::instance().value();
    auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0,
                      SOperand9::ConstInv2Pi)
                    .SetOpSel({0, 1, 1})
                    .Get();
    auto spirv = TranslateToSpirv(inst);
    auto result = runner->run<u32>(spirv, 0U);
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, 0);
}
// Inline constants 0.5 + 0.5 with default selectors: expects 0x3C00 (1.0 as a half float)
// in the low 16 bits and zero in the high 16 bits.
TEST_F(GcnTest, pk_add_f16_4) {
    auto runner = gcn_test::Runner::instance().value();
    auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0p5,
                      SOperand9::Const0p5)
                    .Get();
    auto spirv = TranslateToSpirv(inst);
    auto result = runner->run<u32>(spirv, 0U);
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, 0x3C00);
}
// Inline constants 0 + 1/(2*pi) with SetOpSelHi({0,0,0}): the constant is expected in both
// halves of the result (0x3118 twice).
TEST_F(GcnTest, pk_add_f16_5) {
    auto runner = gcn_test::Runner::instance().value();
    auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0,
                      SOperand9::ConstInv2Pi)
                    .SetOpSelHi({0, 0, 0})
                    .Get();
    auto spirv = TranslateToSpirv(inst);
    auto result = runner->run<u32>(spirv, 0U);
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, 0x31183118);
}
// SetNeg({1,1,0}) on both sources: the low lane is expected to be (-1) + (-3) = -4 while
// the high lane stays 2 + 4 = 6.
TEST_F(GcnTest, pk_add_f16_neg_lo) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1)
            .SetNeg({1, 1, 0})
            .Get());
    auto result = runner->run<F16x2>(
        spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, (F16x2{half(-4.0f), half(6.0f)}));
}
// SetNegHi({1,1,0}) on both sources: the high lane is expected to be (-2) + (-4) = -6
// while the low lane stays 1 + 3 = 4.
TEST_F(GcnTest, pk_add_f16_neg_hi) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1)
            .SetNegHi({1, 1, 0})
            .Get());
    auto result = runner->run<F16x2>(
        spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, (F16x2{half(4.0f), half(-6.0f)}));
}
// SetOpSel({1,1,1}) with SetOpSelHi({0,0,0}) swaps the halves that feed each lane: the low
// lane is expected to be 2 + 4 = 6 and the high lane 1 + 3 = 4.
TEST_F(GcnTest, pk_add_f16_op_sel_reversed) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1)
            .SetOpSel({1, 1, 1})
            .SetOpSelHi({0, 0, 0})
            .Get());
    auto result = runner->run<F16x2>(
        spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
    // Fatal assertion: dereferencing a result without a value below would be undefined.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, (F16x2{half(6.0f), half(4.0f)}));
}