diff --git a/src/shader_recompiler/frontend/translate/translate.cpp b/src/shader_recompiler/frontend/translate/translate.cpp index e663bf369..cca9fff42 100644 --- a/src/shader_recompiler/frontend/translate/translate.cpp +++ b/src/shader_recompiler/frontend/translate/translate.cpp @@ -540,6 +540,121 @@ template IR::U32 Translator::GetSrc16(const InstOperand&); template IR::U32 Translator::GetSrc16(const InstOperand&); template IR::F32 Translator::GetSrc16(const InstOperand&); +IR::F32 Translator::GetSrcMix(const InstOperand& operand) { + const auto get_imm = [&](auto value) -> IR::F32 { + return ir.Imm32(std::bit_cast(value)); + }; + + const auto extract = [&](auto value) -> IR::F32 { + const auto getter_u = [&]() { + if constexpr (std::same_as) { + return ir.GetScalarReg(value); + } else { + return ir.GetVectorReg(value); + } + }(); + if (!operand.op_sel.op_sel_hi) { + if constexpr (std::same_as) { + return ir.GetScalarReg(value); + } else { + return ir.GetVectorReg(value); + } + } else if (operand.op_sel.op_sel) { + return IR::F32{ + ir.CompositeExtract(ir.Unpack2x16(AmdGpu::NumberFormat::Float, getter_u), 1)}; + } else { + return IR::F32{ + ir.CompositeExtract(ir.Unpack2x16(AmdGpu::NumberFormat::Float, getter_u), 0)}; + } + }; + + IR::F32 value{}; + switch (operand.field) { + case OperandField::ScalarGPR: + value = extract(IR::ScalarReg(operand.code)); + break; + case OperandField::VectorGPR: + value = extract(IR::VectorReg(operand.code)); + break; + case OperandField::ConstZero: + value = get_imm(0U); + break; + case OperandField::SignedConstIntPos: + value = get_imm(operand.code - SignedConstIntPosMin + 1); + break; + case OperandField::SignedConstIntNeg: + value = get_imm(-s32(operand.code) + SignedConstIntNegMin - 1); + break; + case OperandField::LiteralConst: + value = get_imm(operand.code); + break; + case OperandField::ConstFloatPos_1_0: + value = get_imm(1.f); + break; + case OperandField::ConstFloatPos_0_5: + value = get_imm(0.5f); + break; + case OperandField::ConstFloatPos_2_0: + value = get_imm(2.0f); + break; + case OperandField::ConstFloatPos_4_0: + value = get_imm(4.0f); + break; + case OperandField::ConstFloatNeg_0_5: + value = get_imm(-0.5f); + break; + case OperandField::ConstFloatNeg_1_0: + value = get_imm(-1.0f); + break; + case OperandField::ConstFloatNeg_2_0: + value = get_imm(-2.0f); + break; + case OperandField::ConstFloatNeg_4_0: + value = get_imm(-4.0f); + break; + case OperandField::VccLo: { + if (!operand.op_sel.op_sel_hi) { + value = ir.BitCast(ir.GetVccLo()); + } else if (operand.op_sel.op_sel) { + value = IR::F32{ + ir.CompositeExtract(ir.Unpack2x16(AmdGpu::NumberFormat::Float, ir.GetVccLo()), 1)}; + } else { + value = IR::F32{ + ir.CompositeExtract(ir.Unpack2x16(AmdGpu::NumberFormat::Float, ir.GetVccLo()), 0)}; + } + break; + } + case OperandField::VccHi: + UNREACHABLE(); + break; + case OperandField::M0: + UNREACHABLE(); + break; + case OperandField::Scc: + UNREACHABLE(); + break; + case OperandField::Inv2Pi: + value = get_imm(static_cast(1.0f / (2.0f * std::numbers::pi))); + break; + case OperandField::Sdwa: + UNREACHABLE_MSG("unhandled SDWA"); + break; + case OperandField::Dpp: + UNREACHABLE_MSG("unhandled DPP"); + break; + default: + UNREACHABLE_MSG("unexpected operand: {}", std::to_underlying(operand.field)); + } + + if (operand.input_modifier.neg_hi) { + value = ir.FPAbs(value); + } + if (operand.input_modifier.neg) { + value = ir.FPNeg(value); + } + return value; +} + template T Translator::GetSrc64(const InstOperand& operand) { constexpr bool is_float = std::is_same_v; diff --git a/src/shader_recompiler/frontend/translate/translate.h b/src/shader_recompiler/frontend/translate/translate.h index aa6ee9601..b3d54ddd0 100644 --- a/src/shader_recompiler/frontend/translate/translate.h +++ b/src/shader_recompiler/frontend/translate/translate.h @@ -282,6 +282,11 @@ public: void V_AND_OR_B32(const GcnInst& inst); void V_OR3_B32(const GcnInst& inst); + // VOP3P + void V_MAD_MIX_F32(const GcnInst& inst); + void V_MAD_MIXLO_F16(const GcnInst& inst); + void V_MAD_MIXHI_F16(const GcnInst& inst); + // Vector interpolation // VINTRP void V_INTERP_P1_F32(const GcnInst& inst); @@ -325,6 +330,7 @@ private: [[nodiscard]] T GetSrc16(const InstOperand& operand); template [[nodiscard]] T GetSrc64(const InstOperand& operand); + [[nodiscard]] IR::F32 GetSrcMix(const InstOperand& operand); void SetDst1(const InstOperand& operand, const IR::U1& value); void SetDst(const InstOperand& operand, const IR::U32F32& value); template diff --git a/src/shader_recompiler/frontend/translate/vector_alu.cpp b/src/shader_recompiler/frontend/translate/vector_alu.cpp index f9f7bee06..ab5e87fae 100644 --- a/src/shader_recompiler/frontend/translate/vector_alu.cpp +++ b/src/shader_recompiler/frontend/translate/vector_alu.cpp @@ -470,6 +470,15 @@ void Translator::EmitVectorAlu(const GcnInst& inst) { return V_OR3_B32(inst); case Opcode::V_NOP: return; + + // VOP3P + case Opcode::V_MAD_MIX_F32: + return V_MAD_MIX_F32(inst); + case Opcode::V_MAD_MIXLO_F16: + return V_MAD_MIXLO_F16(inst); + case Opcode::V_MAD_MIXHI_F16: + return V_MAD_MIXHI_F16(inst); + default: LogMissingOpcode(inst); } @@ -1624,6 +1633,46 @@ void Translator::V_OR3_B32(const GcnInst& inst) { SetDst(inst.dst[0], result); } +void Translator::V_MAD_MIX_F32(const GcnInst& inst) { + const auto src0 = GetSrcMix(inst.src[0]); + const auto src1 = GetSrcMix(inst.src[1]); + const auto src2 = GetSrcMix(inst.src[2]); + + const IR::F32 result = ir.FPFma(src0, src1, src2); + + SetDst(inst.dst[0], result); +} + +void Translator::V_MAD_MIXLO_F16(const GcnInst& inst) { + const auto src0 = GetSrcMix(inst.src[0]); + const auto src1 = GetSrcMix(inst.src[1]); + const auto src2 = GetSrcMix(inst.src[2]); + + const IR::F32 result = ir.FPFma(src0, src1, src2); + const IR::F16 result_f16 = ir.FPConvert(16, result); + const IR::U16 result_f16_u16 = ir.BitCast(result_f16); + + const IR::U32 old_value{GetSrc(inst.dst[0])}; + const IR::U32 new_value{ + ir.BitFieldInsert(old_value, ir.UConvert(32, result_f16_u16), ir.Imm32(0U), ir.Imm32(16U))}; + SetDst(inst.dst[0], new_value); +} + +void Translator::V_MAD_MIXHI_F16(const GcnInst& inst) { + const auto src0 = GetSrcMix(inst.src[0]); + const auto src1 = GetSrcMix(inst.src[1]); + const auto src2 = GetSrcMix(inst.src[2]); + + const IR::F32 result = ir.FPFma(src0, src1, src2); + const IR::F16 result_f16 = ir.FPConvert(16, result); + const IR::U16 result_f16_u16 = ir.BitCast(result_f16); + + const IR::U32 old_value{GetSrc(inst.dst[0])}; + const IR::U32 new_value{ir.BitFieldInsert(old_value, ir.UConvert(32, result_f16_u16), + ir.Imm32(16U), ir.Imm32(16U))}; + SetDst(inst.dst[0], new_value); +} + IR::U32 Translator::GetCarryIn(const GcnInst& inst) { IR::U1 carry; if (inst.src_count == 3) { // VOP3 diff --git a/tests/gcn/test_gcn_instructions.cpp b/tests/gcn/test_gcn_instructions.cpp index 59fb95c14..6928c6b20 100644 --- a/tests/gcn/test_gcn_instructions.cpp +++ b/tests/gcn/test_gcn_instructions.cpp @@ -324,3 +324,53 @@ TEST_F(GcnTest, and_or_b32_8) { EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, 0x11111111); } + +TEST_F(GcnTest, mad_mix_f32_1) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_MAD_MIX_F32, VOperand8::V0, SOperand9::V0, SOperand9::V1, SOperand9::V2).SetOpSelHi({0}).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, std::array{2.0f, 3.0f, 4.0f}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 10.0f); +} + +TEST_F(GcnTest, mad_mix_f32_2) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_MAD_MIX_F32, VOperand8::V0, SOperand9::V0, SOperand9::V1, SOperand9::V2).SetOpSelHi({1,1,0}).SetOpSel({1,0,0}).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, std::array{ + std::bit_cast(F16x2{half(44.0f), half(0.5f)}), std::bit_cast(F16x2{half(44.0f), half(0.5f)}), std::bit_cast(4.0f)} + ); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 26.0f); +} + +TEST_F(GcnTest, mad_mixlo_f16_1) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_MAD_MIXLO_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1, SOperand9::V2).SetOpSelHi({1,1,0}).SetOpSel({1,0,0}).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, std::array{ + std::bit_cast(F16x2{half(44.0f), half(0.5f)}), std::bit_cast(F16x2{half(44.0f), half(0.5f)}), std::bit_cast(4.0f)} + ); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, (F16x2{half(26.0f), half(0.5f)})); +} + +TEST_F(GcnTest, mad_mixhi_f16_1) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_MAD_MIXHI_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1, SOperand9::V2).SetOpSelHi({1,1,0}).SetOpSel({1,0,0}).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, std::array{ + std::bit_cast(F16x2{half(44.0f), half(0.5f)}), std::bit_cast(F16x2{half(44.0f), half(0.5f)}), std::bit_cast(4.0f)} + ); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, (F16x2{half(44.0f), half(26.0f)})); +}