diff --git a/src/shader_recompiler/frontend/translate/translate.cpp b/src/shader_recompiler/frontend/translate/translate.cpp index cca9fff42..346553263 100644 --- a/src/shader_recompiler/frontend/translate/translate.cpp +++ b/src/shader_recompiler/frontend/translate/translate.cpp @@ -768,6 +768,134 @@ T Translator::GetSrc64(const InstOperand& operand) { template IR::U64 Translator::GetSrc64(const InstOperand&); template IR::F64 Translator::GetSrc64(const InstOperand&); +template +pk_type Translator::GetSrcPk(const InstOperand& operand) { + constexpr bool is_float = std::is_same_v; + + const auto get_imm = [&](auto value) -> pk_type { + if constexpr (is_float) { + auto imm = ir.Imm32(std::bit_cast(value)); + return {operand.op_sel.op_sel ? ir.Imm32(0.f) : imm, + operand.op_sel.op_sel_hi ? ir.Imm32(0.f) : imm}; + } else { + auto imm = ir.Imm32(std::bit_cast(value)); + return {operand.op_sel.op_sel ? ir.Imm32(0U) : imm, + operand.op_sel.op_sel_hi ? ir.Imm32(0U) : imm}; + } + }; + + constexpr auto number_format = [&]() { + if constexpr (is_float) { + return AmdGpu::NumberFormat::Float; + } else { + return AmdGpu::NumberFormat::Uint; + } + }(); + + const auto cast = [&](auto value) -> T { + if constexpr (is_float) { + return value; + } else { + return ir.BitCast(value); + } + }; + + const auto extract = [&](auto value) -> pk_type { + auto v_unpacked = ir.Unpack2x16(number_format, value); + return {cast(IR::F32{ir.CompositeExtract(v_unpacked, operand.op_sel.op_sel)}), + cast(IR::F32{ir.CompositeExtract(v_unpacked, operand.op_sel.op_sel_hi)})}; + }; + + pk_type value{}; + switch (operand.field) { + case OperandField::ScalarGPR: { + value = extract(ir.GetScalarReg(IR::ScalarReg(operand.code))); + break; + } + case OperandField::VectorGPR: { + value = extract(ir.GetVectorReg(IR::VectorReg(operand.code))); + break; + } + case OperandField::ConstZero: { + value = get_imm(0U); + break; + } + case OperandField::SignedConstIntPos: { + value = get_imm(operand.code - SignedConstIntPosMin + 1); + break; + } + case OperandField::SignedConstIntNeg: { + value = get_imm(-s32(operand.code) + SignedConstIntNegMin - 1); + break; + } + case OperandField::LiteralConst: { + value = get_imm(operand.code); + break; + } + case OperandField::ConstFloatPos_1_0: { + value = get_imm(1.f); + break; + } + case OperandField::ConstFloatPos_0_5: { + value = get_imm(0.5f); + break; + } + case OperandField::ConstFloatPos_2_0: { + value = get_imm(2.0f); + break; + } + case OperandField::ConstFloatPos_4_0: { + value = get_imm(4.0f); + break; + } + case OperandField::ConstFloatNeg_0_5: { + value = get_imm(-0.5f); + break; + } + case OperandField::ConstFloatNeg_1_0: { + value = get_imm(-1.0f); + break; + } + case OperandField::ConstFloatNeg_2_0: { + value = get_imm(-2.0f); + break; + } + case OperandField::ConstFloatNeg_4_0: { + value = get_imm(-4.0f); + break; + } + case OperandField::Inv2Pi: { + value = get_imm(1.0f / (2.0f * std::numbers::pi_v)); + break; + } + case OperandField::VccLo: + value = extract(ir.GetVccLo()); + break; + default: + UNREACHABLE_MSG("unexpected operand: {}", std::to_underlying(operand.field)); + } + + if constexpr (is_float) { + if (operand.input_modifier.neg) { + value.first = ir.FPNeg(value.first); + } + if (operand.input_modifier.neg_hi) { + value.second = ir.FPNeg(value.second); + } + } else { + if (operand.input_modifier.neg) { + value.first = ir.INeg(value.first); + } + if (operand.input_modifier.neg_hi) { + value.second = ir.INeg(value.second); + } + } + return value; +} + +template pk_type Translator::GetSrcPk(const InstOperand&); +template pk_type Translator::GetSrcPk(const InstOperand&); + void Translator::SetDst1(const InstOperand& operand, const IR::U1& value) { switch (operand.field) { case OperandField::VccLo: diff --git a/src/shader_recompiler/frontend/translate/translate.h b/src/shader_recompiler/frontend/translate/translate.h index 2e49e9c6c..e78e4bb90 100644 --- a/src/shader_recompiler/frontend/translate/translate.h +++ b/src/shader_recompiler/frontend/translate/translate.h @@ -54,6 +54,9 @@ enum class NegateMode : u32 { Result, }; +template +using pk_type = std::pair; + static constexpr size_t MaxInterpVgpr = 16; class Translator { @@ -286,6 +289,11 @@ public: void V_OR3_B32(const GcnInst& inst); // VOP3P + void V_PK_FMA_F16(const GcnInst& inst); + void V_PK_ADD_F16(const GcnInst& inst); + void V_PK_MUL_F16(const GcnInst& inst); + void V_PK_MIN_F16(const GcnInst& inst); + void V_PK_MAX_F16(const GcnInst& inst); void V_MAD_MIX_F32(const GcnInst& inst); void V_MAD_MIXLO_F16(const GcnInst& inst); void V_MAD_MIXHI_F16(const GcnInst& inst); @@ -334,6 +342,8 @@ private: template [[nodiscard]] T GetSrc64(const InstOperand& operand); [[nodiscard]] IR::F32 GetSrcMix(const InstOperand& operand); + template + [[nodiscard]] pk_type GetSrcPk(const InstOperand& operand); void SetDst1(const InstOperand& operand, const IR::U1& value); void SetDst(const InstOperand& operand, const IR::U32F32& value); template diff --git a/src/shader_recompiler/frontend/translate/vector_alu.cpp b/src/shader_recompiler/frontend/translate/vector_alu.cpp index a5406f88f..d362d314f 100644 --- a/src/shader_recompiler/frontend/translate/vector_alu.cpp +++ b/src/shader_recompiler/frontend/translate/vector_alu.cpp @@ -478,6 +478,16 @@ void Translator::EmitVectorAlu(const GcnInst& inst) { return; // VOP3P + case Opcode::V_PK_FMA_F16: + return V_PK_FMA_F16(inst); + case Opcode::V_PK_ADD_F16: + return V_PK_ADD_F16(inst); + case Opcode::V_PK_MUL_F16: + return V_PK_MUL_F16(inst); + case Opcode::V_PK_MIN_F16: + return V_PK_MIN_F16(inst); + case Opcode::V_PK_MAX_F16: + return V_PK_MAX_F16(inst); case Opcode::V_MAD_MIX_F32: return V_MAD_MIX_F32(inst); case Opcode::V_MAD_MIXLO_F16: @@ -1634,6 +1644,62 @@ void Translator::V_ADD3_U32(const GcnInst& inst) { SetDst(inst.dst[0], ir.IAdd(src0, ir.IAdd(src1, src2))); } +void Translator::V_PK_FMA_F16(const GcnInst& inst) { + const auto src0 = GetSrcPk(inst.src[0]); + const auto src1 = GetSrcPk(inst.src[1]); + const auto src2 = GetSrcPk(inst.src[2]); + + const auto result_lo = ir.FPFma(src0.first, src1.first, src2.first); + const auto result_hi = ir.FPFma(src0.second, src1.second, src2.second); + + SetDst(inst.dst[0], + ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi))); +} + +void Translator::V_PK_ADD_F16(const GcnInst& inst) { + const auto src0 = GetSrcPk(inst.src[0]); + const auto src1 = GetSrcPk(inst.src[1]); + + const auto result_lo = ir.FPAdd(src0.first, src1.first); + const auto result_hi = ir.FPAdd(src0.second, src1.second); + + SetDst(inst.dst[0], + ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi))); +} + +void Translator::V_PK_MUL_F16(const GcnInst& inst) { + const auto src0 = GetSrcPk(inst.src[0]); + const auto src1 = GetSrcPk(inst.src[1]); + + const auto result_lo = ir.FPMul(src0.first, src1.first); + const auto result_hi = ir.FPMul(src0.second, src1.second); + + SetDst(inst.dst[0], + ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi))); +} + +void Translator::V_PK_MIN_F16(const GcnInst& inst) { + const auto src0 = GetSrcPk(inst.src[0]); + const auto src1 = GetSrcPk(inst.src[1]); + + const auto result_lo = ir.FPMin(src0.first, src1.first); + const auto result_hi = ir.FPMin(src0.second, src1.second); + + SetDst(inst.dst[0], + ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi))); +} + +void Translator::V_PK_MAX_F16(const GcnInst& inst) { + const auto src0 = GetSrcPk(inst.src[0]); + const auto src1 = GetSrcPk(inst.src[1]); + + const auto result_lo = ir.FPMax(src0.first, src1.first); + const auto result_hi = ir.FPMax(src0.second, src1.second); + + SetDst(inst.dst[0], + ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi))); +} + void Translator::V_LSHL_OR_B32(const GcnInst& inst) { const auto src0 = GetSrc(inst.src[0]); const auto src1 = GetSrc(inst.src[1]); diff --git a/tests/gcn/test_gcn_instructions.cpp b/tests/gcn/test_gcn_instructions.cpp index 911f6b538..03fa8a588 100644 --- a/tests/gcn/test_gcn_instructions.cpp +++ b/tests/gcn/test_gcn_instructions.cpp @@ -424,3 +424,87 @@ TEST_F(GcnTest, ashrrev_i16_1) { EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, 0x1234F111); } + +TEST_F(GcnTest, pk_add_f16_1) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, (F16x2{half(4.0f), half(6.0f)})); +} + +TEST_F(GcnTest, pk_add_f16_2) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0, SOperand9::ConstInv2Pi).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, 0U); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0x00003118); +} + +TEST_F(GcnTest, pk_add_f16_3) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0, SOperand9::ConstInv2Pi).SetOpSel({0,1,1}).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, 0U); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0); +} + +TEST_F(GcnTest, pk_add_f16_4) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0p5, SOperand9::Const0p5).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, 0U); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0x3C00); +} + +TEST_F(GcnTest, pk_add_f16_5) { + auto runner = gcn_test::Runner::instance().value(); + + auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0, SOperand9::ConstInv2Pi).SetOpSelHi({0,0,0}).Get(); + auto spirv = TranslateToSpirv(inst); + auto result = runner->run(spirv, 0U); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0x31183118); +} + +TEST_F(GcnTest, pk_add_f16_neg_lo) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetNeg({1,1,0}).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, (F16x2{half(-4.0f), half(6.0f)})); +} + +TEST_F(GcnTest, pk_add_f16_neg_hi) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetNegHi({1,1,0}).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, (F16x2{half(4.0f), half(-6.0f)})); +} + +TEST_F(GcnTest, pk_add_f16_op_sel_reversed) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetOpSel({1,1,1}).SetOpSelHi({0,0,0}).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, (F16x2{half(6.0f), half(4.0f)})); +}