Neo: V_MAD_MIX opcodes (#4338)

This commit is contained in:
Marcin Mikołajczyk 2026-05-01 01:56:27 +02:00 committed by GitHub
parent f90b89ce61
commit a3e25efad5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 220 additions and 0 deletions

View File

@ -540,6 +540,121 @@ template IR::U32 Translator::GetSrc16<IR::U32, false>(const InstOperand&);
template IR::U32 Translator::GetSrc16<IR::U32, true>(const InstOperand&);
template IR::F32 Translator::GetSrc16<IR::F32, false>(const InstOperand&);
/// Fetches one source operand of a V_MAD_MIX* instruction as a 32-bit float value.
///
/// Register and VCC_LO operands honour the per-operand selection bits:
///   - op_sel_hi == 0: the full 32-bit value is used as an f32.
///   - op_sel_hi != 0: the 32-bit value is unpacked as two f16 values and op_sel picks
///     element 1 (op_sel != 0) or element 0 (op_sel == 0) of the unpacked pair.
/// Inline constants and literals are emitted as f32 immediates regardless of the
/// selection bits. After the fetch, input_modifier.neg_hi applies an absolute value and
/// input_modifier.neg applies a negation, in that order.
///
/// @param operand Decoded source operand (field, code, op_sel and input modifiers).
/// @return The operand as an IR::F32 value with modifiers applied.
IR::F32 Translator::GetSrcMix(const InstOperand& operand) {
    // Reinterprets the raw 32-bit pattern of `bits` as a float immediate.
    const auto get_imm = [&](auto bits) -> IR::F32 {
        return ir.Imm32(std::bit_cast<float>(bits));
    };

    // Unpacks a 32-bit word as two f16 values and returns the element chosen by op_sel.
    const auto unpack_half = [&](const IR::U32& packed) -> IR::F32 {
        const auto halves = ir.Unpack2x16(AmdGpu::NumberFormat::Float, packed);
        if (operand.op_sel.op_sel) {
            return IR::F32{ir.CompositeExtract(halves, 1)};
        }
        return IR::F32{ir.CompositeExtract(halves, 0)};
    };

    // Reads a scalar or vector register. The packed U32 read is only emitted on the f16
    // path, so the f32 path no longer leaves an unused register read in the IR.
    const auto read_reg = [&](auto reg) -> IR::F32 {
        constexpr bool is_scalar = std::same_as<decltype(reg), IR::ScalarReg>;
        if (!operand.op_sel.op_sel_hi) {
            if constexpr (is_scalar) {
                return ir.GetScalarReg<IR::F32>(reg);
            } else {
                return ir.GetVectorReg<IR::F32>(reg);
            }
        }
        if constexpr (is_scalar) {
            return unpack_half(ir.GetScalarReg<IR::U32>(reg));
        } else {
            return unpack_half(ir.GetVectorReg<IR::U32>(reg));
        }
    };

    IR::F32 value{};
    switch (operand.field) {
    case OperandField::ScalarGPR:
        value = read_reg(IR::ScalarReg(operand.code));
        break;
    case OperandField::VectorGPR:
        value = read_reg(IR::VectorReg(operand.code));
        break;
    case OperandField::ConstZero:
        value = get_imm(0U);
        break;
    case OperandField::SignedConstIntPos:
        // Integer inline constants are passed through as raw bit patterns.
        value = get_imm(operand.code - SignedConstIntPosMin + 1);
        break;
    case OperandField::SignedConstIntNeg:
        value = get_imm(-s32(operand.code) + SignedConstIntNegMin - 1);
        break;
    case OperandField::LiteralConst:
        value = get_imm(operand.code);
        break;
    case OperandField::ConstFloatPos_1_0:
        value = get_imm(1.f);
        break;
    case OperandField::ConstFloatPos_0_5:
        value = get_imm(0.5f);
        break;
    case OperandField::ConstFloatPos_2_0:
        value = get_imm(2.0f);
        break;
    case OperandField::ConstFloatPos_4_0:
        value = get_imm(4.0f);
        break;
    case OperandField::ConstFloatNeg_0_5:
        value = get_imm(-0.5f);
        break;
    case OperandField::ConstFloatNeg_1_0:
        value = get_imm(-1.0f);
        break;
    case OperandField::ConstFloatNeg_2_0:
        value = get_imm(-2.0f);
        break;
    case OperandField::ConstFloatNeg_4_0:
        value = get_imm(-4.0f);
        break;
    case OperandField::VccLo:
        // VCC_LO follows the same f32 / f16-half selection as a register operand.
        if (!operand.op_sel.op_sel_hi) {
            value = ir.BitCast<IR::F32>(ir.GetVccLo());
        } else {
            value = unpack_half(ir.GetVccLo());
        }
        break;
    case OperandField::VccHi:
        UNREACHABLE();
        break;
    case OperandField::M0:
        UNREACHABLE();
        break;
    case OperandField::Scc:
        UNREACHABLE();
        break;
    case OperandField::Inv2Pi:
        value = get_imm(static_cast<float>(1.0f / (2.0f * std::numbers::pi)));
        break;
    case OperandField::Sdwa:
        UNREACHABLE_MSG("unhandled SDWA");
        break;
    case OperandField::Dpp:
        UNREACHABLE_MSG("unhandled DPP");
        break;
    default:
        UNREACHABLE_MSG("unexpected operand: {}", std::to_underlying(operand.field));
    }

    // For this operand kind the neg_hi bit is applied as an absolute-value modifier
    // (NOTE(review): matches the MIX encoding as understood here - confirm against the ISA
    // manual). Abs must run before neg so that both bits together yield -|x|.
    if (operand.input_modifier.neg_hi) {
        value = ir.FPAbs(value);
    }
    if (operand.input_modifier.neg) {
        value = ir.FPNeg(value);
    }
    return value;
}
template <typename T>
T Translator::GetSrc64(const InstOperand& operand) {
constexpr bool is_float = std::is_same_v<T, IR::F64>;

View File

@ -282,6 +282,11 @@ public:
void V_AND_OR_B32(const GcnInst& inst);
void V_OR3_B32(const GcnInst& inst);
// VOP3P
// Fused multiply-add over three mixed-precision sources (each fetched via GetSrcMix);
// the f32 result is written to dst[0].
void V_MAD_MIX_F32(const GcnInst& inst);
// Same fused multiply-add, converted to f16 and inserted into the 16 bits starting at
// bit 0 of dst[0]; the previous contents of dst[0] supply the remaining bits.
void V_MAD_MIXLO_F16(const GcnInst& inst);
// Same fused multiply-add, converted to f16 and inserted into the 16 bits starting at
// bit 16 of dst[0]; the previous contents of dst[0] supply the remaining bits.
void V_MAD_MIXHI_F16(const GcnInst& inst);
// Vector interpolation
// VINTRP
void V_INTERP_P1_F32(const GcnInst& inst);
@ -325,6 +330,7 @@ private:
[[nodiscard]] T GetSrc16(const InstOperand& operand);
template <typename T = IR::U64>
[[nodiscard]] T GetSrc64(const InstOperand& operand);
// Reads a V_MAD_MIX* source operand as f32: a full 32-bit float, or one f16 half of a
// packed register selected by the operand's op_sel / op_sel_hi bits, with the operand's
// input modifiers applied.
[[nodiscard]] IR::F32 GetSrcMix(const InstOperand& operand);
void SetDst1(const InstOperand& operand, const IR::U1& value);
void SetDst(const InstOperand& operand, const IR::U32F32& value);
template <bool is_signed = false>

View File

@ -470,6 +470,15 @@ void Translator::EmitVectorAlu(const GcnInst& inst) {
return V_OR3_B32(inst);
case Opcode::V_NOP:
return;
// VOP3P
case Opcode::V_MAD_MIX_F32:
return V_MAD_MIX_F32(inst);
case Opcode::V_MAD_MIXLO_F16:
return V_MAD_MIXLO_F16(inst);
case Opcode::V_MAD_MIXHI_F16:
return V_MAD_MIXHI_F16(inst);
default:
LogMissingOpcode(inst);
}
@ -1624,6 +1633,46 @@ void Translator::V_OR3_B32(const GcnInst& inst) {
SetDst(inst.dst[0], result);
}
void Translator::V_MAD_MIX_F32(const GcnInst& inst) {
const auto src0 = GetSrcMix(inst.src[0]);
const auto src1 = GetSrcMix(inst.src[1]);
const auto src2 = GetSrcMix(inst.src[2]);
const IR::F32 result = ir.FPFma(src0, src1, src2);
SetDst(inst.dst[0], result);
}
void Translator::V_MAD_MIXLO_F16(const GcnInst& inst) {
const auto src0 = GetSrcMix(inst.src[0]);
const auto src1 = GetSrcMix(inst.src[1]);
const auto src2 = GetSrcMix(inst.src[2]);
const IR::F32 result = ir.FPFma(src0, src1, src2);
const IR::F16 result_f16 = ir.FPConvert(16, result);
const IR::U16 result_f16_u16 = ir.BitCast<IR::U16, IR::F16>(result_f16);
const IR::U32 old_value{GetSrc(inst.dst[0])};
const IR::U32 new_value{
ir.BitFieldInsert(old_value, ir.UConvert(32, result_f16_u16), ir.Imm32(0U), ir.Imm32(16U))};
SetDst(inst.dst[0], new_value);
}
void Translator::V_MAD_MIXHI_F16(const GcnInst& inst) {
const auto src0 = GetSrcMix(inst.src[0]);
const auto src1 = GetSrcMix(inst.src[1]);
const auto src2 = GetSrcMix(inst.src[2]);
const IR::F32 result = ir.FPFma(src0, src1, src2);
const IR::F16 result_f16 = ir.FPConvert(16, result);
const IR::U16 result_f16_u16 = ir.BitCast<IR::U16, IR::F16>(result_f16);
const IR::U32 old_value{GetSrc(inst.dst[0])};
const IR::U32 new_value{ir.BitFieldInsert(old_value, ir.UConvert(32, result_f16_u16),
ir.Imm32(16U), ir.Imm32(16U))};
SetDst(inst.dst[0], new_value);
}
IR::U32 Translator::GetCarryIn(const GcnInst& inst) {
IR::U1 carry;
if (inst.src_count == 3) { // VOP3

View File

@ -324,3 +324,53 @@ TEST_F(GcnTest, and_or_b32_8) {
EXPECT_TRUE(result.has_value());
EXPECT_EQ(*result, 0x11111111);
}
// V_MAD_MIX_F32 with op_sel_hi set to {0} and plain float inputs: 2 * 3 + 4 == 10.
TEST_F(GcnTest, mad_mix_f32_1) {
    auto gpu = gcn_test::Runner::instance().value();

    auto encoded = VOP3P(OpcodeVOP3P::V_MAD_MIX_F32, VOperand8::V0, SOperand9::V0,
                         SOperand9::V1, SOperand9::V2)
                       .SetOpSelHi({0})
                       .Get();
    auto shader = TranslateToSpirv(encoded);

    auto outcome = gpu->run<float>(shader, std::array{2.0f, 3.0f, 4.0f});
    EXPECT_TRUE(outcome.has_value());
    EXPECT_EQ(*outcome, 10.0f);
}
// V_MAD_MIX_F32 with the first two sources taken from packed f16 pairs and the third
// from a full f32. The expected result is 26.
TEST_F(GcnTest, mad_mix_f32_2) {
    auto gpu = gcn_test::Runner::instance().value();

    auto encoded = VOP3P(OpcodeVOP3P::V_MAD_MIX_F32, VOperand8::V0, SOperand9::V0,
                         SOperand9::V1, SOperand9::V2)
                       .SetOpSelHi({1, 1, 0})
                       .SetOpSel({1, 0, 0})
                       .Get();
    auto shader = TranslateToSpirv(encoded);

    // The first two inputs carry the same packed pair of half floats.
    const u32 packed_halves = std::bit_cast<u32>(F16x2{half(44.0f), half(0.5f)});
    const u32 addend_bits = std::bit_cast<u32>(4.0f);

    auto outcome =
        gpu->run<float>(shader, std::array<u32, 3>{packed_halves, packed_halves, addend_bits});
    EXPECT_TRUE(outcome.has_value());
    EXPECT_EQ(*outcome, 26.0f);
}
// V_MAD_MIXLO_F16 with the same inputs as mad_mix_f32_2. The expected output pair is
// {26, 0.5}: the first element carries the f16 result, the second is left at 0.5.
TEST_F(GcnTest, mad_mixlo_f16_1) {
    auto gpu = gcn_test::Runner::instance().value();

    auto encoded = VOP3P(OpcodeVOP3P::V_MAD_MIXLO_F16, VOperand8::V0, SOperand9::V0,
                         SOperand9::V1, SOperand9::V2)
                       .SetOpSelHi({1, 1, 0})
                       .SetOpSel({1, 0, 0})
                       .Get();
    auto shader = TranslateToSpirv(encoded);

    // The first two inputs carry the same packed pair of half floats.
    const u32 packed_halves = std::bit_cast<u32>(F16x2{half(44.0f), half(0.5f)});
    const u32 addend_bits = std::bit_cast<u32>(4.0f);

    auto outcome =
        gpu->run<F16x2>(shader, std::array<u32, 3>{packed_halves, packed_halves, addend_bits});
    EXPECT_TRUE(outcome.has_value());
    EXPECT_EQ(*outcome, (F16x2{half(26.0f), half(0.5f)}));
}
// V_MAD_MIXHI_F16 with the same inputs as mad_mix_f32_2. The expected output pair is
// {44, 26}: the second element carries the f16 result, the first is left at 44.
TEST_F(GcnTest, mad_mixhi_f16_1) {
    auto gpu = gcn_test::Runner::instance().value();

    auto encoded = VOP3P(OpcodeVOP3P::V_MAD_MIXHI_F16, VOperand8::V0, SOperand9::V0,
                         SOperand9::V1, SOperand9::V2)
                       .SetOpSelHi({1, 1, 0})
                       .SetOpSel({1, 0, 0})
                       .Get();
    auto shader = TranslateToSpirv(encoded);

    // The first two inputs carry the same packed pair of half floats.
    const u32 packed_halves = std::bit_cast<u32>(F16x2{half(44.0f), half(0.5f)});
    const u32 addend_bits = std::bit_cast<u32>(4.0f);

    auto outcome =
        gpu->run<F16x2>(shader, std::array<u32, 3>{packed_halves, packed_halves, addend_bits});
    EXPECT_TRUE(outcome.has_value());
    EXPECT_EQ(*outcome, (F16x2{half(44.0f), half(26.0f)}));
}