From 99f2480e213fc8474c4c83786ac19f69a1a7175f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Miko=C5=82ajczyk?= Date: Sat, 25 Apr 2026 12:51:02 +0200 Subject: [PATCH] Neo: V_*_F16 arithmetic ops (#4311) --- src/shader_recompiler/frontend/format.cpp | 36 ++- .../frontend/translate/translate.cpp | 205 ++++++++++++++++++ .../frontend/translate/translate.h | 9 + .../frontend/translate/vector_alu.cpp | 55 +++++ tests/CMakeLists.txt | 1 + tests/gcn/instructions.hpp | 164 ++++++++++++++ tests/gcn/test_gcn_instructions.cpp | 63 ++++++ 7 files changed, 524 insertions(+), 9 deletions(-) diff --git a/src/shader_recompiler/frontend/format.cpp b/src/shader_recompiler/frontend/format.cpp index 66786de53..457b89367 100644 --- a/src/shader_recompiler/frontend/format.cpp +++ b/src/shader_recompiler/frontend/format.cpp @@ -1464,16 +1464,30 @@ constexpr std::array InstructionFormatVOP3 = { {InstClass::VectorConv, InstCategory::VectorALU, 2, 1, ScalarType::Uint32, ScalarType::Uint32}, // 305 = V_CVT_PK_I16_I32 {InstClass::VectorConv, InstCategory::VectorALU, 2, 1, ScalarType::Sint32, ScalarType::Sint32}, + // 306 = V_ADD_F16 + {InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16, + ScalarType::Float16}, + // 307 = V_SUB_F16 + {InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16, + ScalarType::Float16}, + // 308 = V_SUBREV_F16 + {InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16, + ScalarType::Float16}, + // 309 = V_MUL_F16 + {InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16, + ScalarType::Float16}, {}, {}, {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, + // 313 = V_MAX_F16 + {InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16, + ScalarType::Float16}, + // 314 = V_MIN_F16 + {InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16, + ScalarType::Float16}, + // 315 = V_LDEXP_F16 + {InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16, + ScalarType::Float16}, {}, {}, {}, @@ -1854,8 +1868,12 @@ constexpr std::array InstructionFormatVOP3 = { {}, {}, {}, - {}, - {}, + // 84 = V_RCP_F16 + {InstClass::VectorFpTran16, InstCategory::VectorALU, 1, 1, ScalarType::Float16, + ScalarType::Float16}, + // 85 = V_SQRT_F16 + {InstClass::VectorFpTran16, InstCategory::VectorALU, 1, 1, ScalarType::Float16, + ScalarType::Float16}, {}, {}, {}, diff --git a/src/shader_recompiler/frontend/translate/translate.cpp b/src/shader_recompiler/frontend/translate/translate.cpp index bd73d04be..3bf7ccb65 100644 --- a/src/shader_recompiler/frontend/translate/translate.cpp +++ b/src/shader_recompiler/frontend/translate/translate.cpp @@ -396,6 +396,150 @@ T Translator::GetSrc(const InstOperand& operand) { template IR::U32 Translator::GetSrc(const InstOperand&); template IR::F32 Translator::GetSrc(const InstOperand&); +template +T Translator::GetSrc16(const InstOperand& operand) { + constexpr bool is_float = std::is_same_v; + + const auto get_imm = [&](auto value) -> T { + if constexpr (is_float) { + return ir.Imm32(std::bit_cast(value)); + } else { + return ir.Imm32(std::bit_cast(value)); + } + }; + + const auto number_format = []() -> AmdGpu::NumberFormat { + if constexpr (is_float) { + return AmdGpu::NumberFormat::Float; + } else { + return AmdGpu::NumberFormat::Uint; + } + }(); + + const auto bitcast_to_u = [&](auto value) -> IR::U32 { + if constexpr (is_float) { + return ir.BitCast(value); + } else { + return value; + } + }; + + const auto cast = [&](auto value) -> T { + if constexpr (is_float) { + return value; + } else { + return ir.BitFieldExtract(ir.BitCast(value), ir.Imm32(0), ir.Imm32(16), + is_signed); + } + }; + + const auto op_sel = operand.op_sel.op_sel; + + T value{}; + switch (operand.field) { + case OperandField::ScalarGPR: { + const auto f = ir.GetScalarReg(IR::ScalarReg(operand.code)); + value = cast(IR::F32{ + ir.CompositeExtract(ir.Unpack2x16(number_format, bitcast_to_u(f)), op_sel ? 1 : 0)}); + break; + } + case OperandField::VectorGPR: { + const auto v = ir.GetVectorReg(IR::VectorReg(operand.code)); + value = cast(IR::F32{ + ir.CompositeExtract(ir.Unpack2x16(number_format, bitcast_to_u(v)), op_sel ? 1 : 0)}); + break; + } + case OperandField::ConstZero: + value = get_imm(0U); + break; + case OperandField::SignedConstIntPos: + value = get_imm(operand.code - SignedConstIntPosMin + 1); + break; + case OperandField::SignedConstIntNeg: + value = get_imm(-s32(operand.code) + SignedConstIntNegMin - 1); + break; + case OperandField::LiteralConst: + value = get_imm(operand.code); + break; + case OperandField::ConstFloatPos_1_0: + value = get_imm(1.f); + break; + case OperandField::ConstFloatPos_0_5: + value = get_imm(0.5f); + break; + case OperandField::ConstFloatPos_2_0: + value = get_imm(2.0f); + break; + case OperandField::ConstFloatPos_4_0: + value = get_imm(4.0f); + break; + case OperandField::ConstFloatNeg_0_5: + value = get_imm(-0.5f); + break; + case OperandField::ConstFloatNeg_1_0: + value = get_imm(-1.0f); + break; + case OperandField::ConstFloatNeg_2_0: + value = get_imm(-2.0f); + break; + case OperandField::ConstFloatNeg_4_0: + value = get_imm(-4.0f); + break; + case OperandField::Inv2Pi: + value = get_imm(static_cast(1.0f / (2.0f * std::numbers::pi))); + break; + case OperandField::Sdwa: + LOG_ERROR(Render_Recompiler, "unhandled SDWA"); + value = get_imm(0U); + break; + case OperandField::Dpp: + LOG_ERROR(Render_Recompiler, "unhandled DPP"); + value = get_imm(0U); + break; + case OperandField::VccLo: + if constexpr (is_float) { + value = IR::F32{ + ir.CompositeExtract(ir.Unpack2x16(number_format, ir.GetVccLo()), op_sel ? 1 : 0)}; + } else { + value = cast(IR::F32{ir.CompositeExtract( + ir.Unpack2x16(number_format, bitcast_to_u(ir.GetVccLo())), op_sel ? 1 : 0)}); + } + break; + case OperandField::VccHi: + UNREACHABLE(); + break; + case OperandField::M0: + UNREACHABLE(); + break; + case OperandField::Scc: + UNREACHABLE(); + break; + default: + UNREACHABLE_MSG("unexpected operand: {}", std::to_underlying(operand.field)); + } + + if constexpr (is_float) { + if (operand.input_modifier.abs) { + value = ir.FPAbs(value); + } + if (operand.input_modifier.neg) { + value = ir.FPNeg(value); + } + } else { + if (operand.input_modifier.abs) { + value = ir.BitwiseAnd(value, ir.Imm32(0x7FFFFFFFu)); + } + if (operand.input_modifier.neg) { + value = ir.BitwiseXor(value, ir.Imm32(0x80000000u)); + } + } + return value; +} + +template IR::U32 Translator::GetSrc16(const InstOperand&); +template IR::U32 Translator::GetSrc16(const InstOperand&); +template IR::F32 Translator::GetSrc16(const InstOperand&); + template T Translator::GetSrc64(const InstOperand& operand) { constexpr bool is_float = std::is_same_v; @@ -552,6 +696,67 @@ void Translator::SetDst(const InstOperand& operand, const IR::U32F32& value) { } } +template +void Translator::SetDst16(const InstOperand& operand, const IR::U32F32& value) { + IR::U32F32 result = value; + if (value.Type() == IR::Type::F32) { + if (operand.output_modifier.multiplier != 0.f) { + result = ir.FPMul(result, ir.Imm32(operand.output_modifier.multiplier)); + } + if (operand.output_modifier.clamp) { + result = ir.FPSaturate(result); + } + } else { + if (operand.output_modifier.clamp) { + if constexpr (is_signed) { + result = ir.SClamp(result, ir.Imm32(-32768), ir.Imm32(32767)); + } else { + result = ir.UMin(result, ir.Imm32(0xFFFF)); + } + } + } + + const auto cast = [&](auto value) -> IR::U32 { + if (value.Type() == IR::Type::F32) { + return ir.UConvert(32, ir.BitCast(IR::F16{ir.FPConvert(16, value)})); + } else if (value.Type() == IR::Type::U32) { + return value; + } else { + UNREACHABLE(); + } + }; + + const auto op_sel = operand.op_sel.op_sel; + + switch (operand.field) { + case OperandField::ScalarGPR: { + const auto prev_dst = ir.GetScalarReg(IR::ScalarReg(operand.code)); + const auto result_16 = cast(result); + const auto new_dst = + ir.BitFieldInsert(prev_dst, result_16, ir.Imm32(op_sel ? 16 : 0), ir.Imm32(16)); + return ir.SetScalarReg(IR::ScalarReg(operand.code), new_dst); + } + case OperandField::VectorGPR: { + const auto prev_dst = ir.GetVectorReg(IR::VectorReg(operand.code)); + const auto result_16 = cast(result); + const auto new_dst = + ir.BitFieldInsert(prev_dst, result_16, ir.Imm32(op_sel ? 16 : 0), ir.Imm32(16)); + return ir.SetVectorReg(IR::VectorReg(operand.code), new_dst); + } + case OperandField::VccLo: + UNREACHABLE(); + case OperandField::VccHi: + UNREACHABLE(); + case OperandField::M0: + UNREACHABLE(); + default: + UNREACHABLE(); + } +} + +template void Translator::SetDst16(const InstOperand&, const IR::U32F32& value); +template void Translator::SetDst16(const InstOperand&, const IR::U32F32& value); + void Translator::SetDst64(const InstOperand& operand, const IR::U64F64& value_raw) { IR::U64F64 value_untyped = value_raw; diff --git a/src/shader_recompiler/frontend/translate/translate.h b/src/shader_recompiler/frontend/translate/translate.h index c4180d843..11622f8a3 100644 --- a/src/shader_recompiler/frontend/translate/translate.h +++ b/src/shader_recompiler/frontend/translate/translate.h @@ -183,6 +183,11 @@ public: void V_CVT_PKNORM_U16_F32(const GcnInst& inst); void V_CVT_PKNORM_I16_F32(const GcnInst& inst); void V_CVT_PKRTZ_F16_F32(const GcnInst& inst); + void V_ADD_F16(const GcnInst& inst); + void V_SUB_F16(const GcnInst& inst); + void V_MUL_F16(const GcnInst& inst); + void V_MAX_F16(const GcnInst& inst); + void V_MIN_F16(const GcnInst& inst); // VOP1 void V_MOV(const GcnInst& inst); @@ -310,10 +315,14 @@ private: IR::U1 GetSrc1(const InstOperand& operand); template [[nodiscard]] T GetSrc(const InstOperand& operand); + template + [[nodiscard]] T GetSrc16(const InstOperand& operand); template [[nodiscard]] T GetSrc64(const InstOperand& operand); void SetDst1(const InstOperand& operand, const IR::U1& value); void SetDst(const InstOperand& operand, const IR::U32F32& value); + template + void SetDst16(const InstOperand& operand, const IR::U32F32& value); void SetDst64(const InstOperand& operand, const IR::U64F64& value_raw); // Vector ALU Helpers diff --git a/src/shader_recompiler/frontend/translate/vector_alu.cpp b/src/shader_recompiler/frontend/translate/vector_alu.cpp index 4c35dc435..7427d4b9e 100644 --- a/src/shader_recompiler/frontend/translate/vector_alu.cpp +++ b/src/shader_recompiler/frontend/translate/vector_alu.cpp @@ -100,6 +100,16 @@ void Translator::EmitVectorAlu(const GcnInst& inst) { return V_CVT_PKNORM_I16_F32(inst); case Opcode::V_CVT_PKRTZ_F16_F32: return V_CVT_PKRTZ_F16_F32(inst); + case Opcode::V_ADD_F16: + return V_ADD_F16(inst); + case Opcode::V_SUB_F16: + return V_SUB_F16(inst); + case Opcode::V_MUL_F16: + return V_MUL_F16(inst); + case Opcode::V_MAX_F16: + return V_MAX_F16(inst); + case Opcode::V_MIN_F16: + return V_MIN_F16(inst); // VOP1 case Opcode::V_MOV_B32: @@ -746,6 +756,51 @@ void Translator::V_CVT_PKRTZ_F16_F32(const GcnInst& inst) { SetDst(inst.dst[0], ir.Pack2x16(AmdGpu::NumberFormat::Float, vec_f32)); } +void Translator::V_ADD_F16(const GcnInst& inst) { + const auto src0 = GetSrc16(inst.src[0]); + const auto src1 = GetSrc16(inst.src[1]); + + const auto result = ir.FPAdd(src0, src1); + + SetDst16(inst.dst[0], result); +} + +void Translator::V_SUB_F16(const GcnInst& inst) { + const auto src0 = GetSrc16(inst.src[0]); + const auto src1 = GetSrc16(inst.src[1]); + + const auto result = ir.FPSub(src0, src1); + + SetDst16(inst.dst[0], result); +} + +void Translator::V_MUL_F16(const GcnInst& inst) { + const auto src0 = GetSrc16(inst.src[0]); + const auto src1 = GetSrc16(inst.src[1]); + + const auto result = ir.FPMul(src0, src1); + + SetDst16(inst.dst[0], result); +} + +void Translator::V_MAX_F16(const GcnInst& inst) { + const auto src0 = GetSrc16(inst.src[0]); + const auto src1 = GetSrc16(inst.src[1]); + + const auto result = ir.FPMax(src0, src1); + + SetDst16(inst.dst[0], result); +} + +void Translator::V_MIN_F16(const GcnInst& inst) { + const auto src0 = GetSrc16(inst.src[0]); + const auto src1 = GetSrc16(inst.src[1]); + + const auto result = ir.FPMin(src0, src1); + + SetDst16(inst.dst[0], result); +} + // VOP1 void Translator::V_MOV(const GcnInst& inst) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 43fc3b62f..63a13863b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -156,6 +156,7 @@ target_link_libraries(shadps4_gcn_test PRIVATE sirit SDL3::SDL3 spdlog::spdlog + half::half ) if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR diff --git a/tests/gcn/instructions.hpp b/tests/gcn/instructions.hpp index 2e43ba332..8620b29fb 100644 --- a/tests/gcn/instructions.hpp +++ b/tests/gcn/instructions.hpp @@ -3,12 +3,15 @@ #include "shader_recompiler/frontend/opcodes.h" +#include using OpcodeSOP1 = Shader::Gcn::OpcodeSOP1; using OpcodeSOP2 = Shader::Gcn::OpcodeSOP2; using OpcodeSOPK = Shader::Gcn::OpcodeSOPK; using OpcodeVOP1 = Shader::Gcn::OpcodeVOP1; using OpcodeVOP2 = Shader::Gcn::OpcodeVOP2; +using OpcodeVOP3 = Shader::Gcn::OpcodeVOP3; +using OpcodeVOP3P = Shader::Gcn::OpcodeVOP3P; enum class VOperand8 : u8 { V0 = 0, @@ -1185,3 +1188,164 @@ private: static_assert(sizeof(VOP2Internal) == sizeof(u32)); }; + +enum class Omod : u8 { + None = 0, + Mul2 = 1, + Mul4 = 2, + Div2 = 3, +}; + +class VOP3A { +public: + explicit constexpr VOP3A(OpcodeVOP3 op, VOperand8 vdst, SOperand9 src0, SOperand9 src1, SOperand9 src2 = SOperand9::S0) { + i.src0 = std::to_underlying(src0); + i.src1 = std::to_underlying(src1); + i.src2 = std::to_underlying(src2); + i.vdst = std::to_underlying(vdst); + i.op = std::to_underlying(op) & 0x1FF; + i.op_msb = std::to_underlying(op) >> 9; + i.encoding = 0b110100; + } + + VOP3A& SetAbs(const std::array& abs) { + u8 a = 0; + for (auto x : abs | std::views::reverse) { + a |= x; + a <<= 1; + } + i.abs = a >> 1; + return *this; + } + + VOP3A& SetClamp(bool clamp) { + i.clmp = clamp; + return *this; + } + + VOP3A& SetNeg(const std::array& neg) { + u8 n = 0; + for (auto x : neg | std::views::reverse) { + n |= x; + n <<= 1; + } + i.neg = n >> 1; + return *this; + } + + VOP3A& SetOmod(Omod omod) { + i.omod = std::to_underlying(omod); + return *this; + } + + VOP3A& SetOpSel(const std::array& op_sel) { + u8 o = 0; + for (auto x : op_sel | std::views::reverse) { + o |= x; + o <<= 1; + } + i.op_sel = o >> 1; + return *this; + } + + u64 Get() { + return std::bit_cast(i); + } + +private: + struct VOP3Internal { + u64 vdst : 8; + u64 abs : 3; + u64 clmp : 1; + u64 op_sel : 4; + u64 op_msb : 1; + u64 op : 9; + u64 encoding : 6; + u64 src0 : 9; + u64 src1 : 9; + u64 src2 : 9; + u64 omod : 2; + u64 neg : 3; + } i; + + static_assert(sizeof(VOP3Internal) == sizeof(u64)); +}; + +class VOP3P { +public: + explicit constexpr VOP3P(OpcodeVOP3P op, VOperand8 vdst, SOperand9 src0, SOperand9 src1, SOperand9 src2 = SOperand9::S0) { + i.src0 = std::to_underlying(src0); + i.src1 = std::to_underlying(src1); + i.src2 = std::to_underlying(src2); + i.vdst = std::to_underlying(vdst); + i.op = std::to_underlying(op); + // enable op_sel_hi as a sensible default, so both 16-bit chunks are processed + i.op_sel_hi01 = 0b11; + i.op_sel_hi2 = 0b1; + i.encoding = 0b110011; + } + + VOP3P& SetClamp(bool clamp) { + i.clmp = clamp; + return *this; + } + + VOP3P& SetNeg(const std::array& neg) { + u8 n = 0; + for (auto x : neg | std::views::reverse) { + n |= x; + n <<= 1; + } + i.neg = n >> 1; + return *this; + } + + VOP3P& SetNegHi(const std::array& neg) { + u8 n = 0; + for (auto x : neg | std::views::reverse) { + n |= x; + n <<= 1; + } + i.neg_hi = n >> 1; + return *this; + } + + VOP3P& SetOpSel(const std::array& op_sel) { + u8 o = 0; + for (auto x : op_sel | std::views::reverse) { + o |= x; + o <<= 1; + } + i.op_sel = o >> 1; + return *this; + } + + VOP3P& SetOpSelHi(const std::array& op_sel) { + i.op_sel_hi2 = op_sel[2]; + i.op_sel_hi01 = (op_sel[1] << 1) | op_sel[0]; + return *this; + } + + u64 Get() { + return std::bit_cast(i); + } + +private: + struct VOP3PInternal { + u64 vdst : 8; + u64 neg_hi : 3; + u64 op_sel : 3; + u64 op_sel_hi2 : 1; + u64 clmp : 1; + u64 op : 7; + u64 : 3; + u64 encoding : 6; + u64 src0 : 9; + u64 src1 : 9; + u64 src2 : 9; + u64 op_sel_hi01 : 2; + u64 neg : 3; + } i; + + static_assert(sizeof(VOP3PInternal) == sizeof(u64)); +}; diff --git a/tests/gcn/test_gcn_instructions.cpp b/tests/gcn/test_gcn_instructions.cpp index c7e40ad48..d88c8d6a5 100644 --- a/tests/gcn/test_gcn_instructions.cpp +++ b/tests/gcn/test_gcn_instructions.cpp @@ -4,6 +4,7 @@ #include #include +#include #include "gcn_test_runner.hpp" #include "instructions.hpp" @@ -58,3 +59,65 @@ TEST_F(GcnTest, add_nan) { EXPECT_TRUE(result.has_value()); EXPECT_TRUE(std::isnan(*result)); } + +using half = half_float::half; + +struct F16x2 { + half a; + half b = half(0.0f); + + bool operator==(const F16x2& rhs) const = default; +}; + +static_assert(sizeof(F16x2) == sizeof(float)); + +TEST_F(GcnTest, add_f16) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP2(OpcodeVOP2::V_ADD_F16, VOperand8::V0, SOperand9::V0, VOperand8::V1).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f)}, F16x2{half(1.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, F16x2{half(2.0f)}); +} + +TEST_F(GcnTest, add_f16_clamp) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP3A(OpcodeVOP3::V_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetClamp(true).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f)}, F16x2{half(1.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, F16x2{half(1.0f)}); //confirmed with neo +} + +TEST_F(GcnTest, add_f16_neg) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP3A(OpcodeVOP3::V_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetNeg({true, true, false}).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f)}, F16x2{half(1.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ((*result).a, half(-2.0f)); //confirmed with neo +} + +TEST_F(GcnTest, add_f16_opsel_hi) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP3A(OpcodeVOP3::V_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetOpSel({true, true, false, true}).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(1.0f), half(2.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ((*result).a, half(1.0f)); + EXPECT_EQ((*result).b, half(4.0f)); +} + +TEST_F(GcnTest, sub_f16) { + auto runner = gcn_test::Runner::instance().value(); + + auto spirv = TranslateToSpirv(VOP2(OpcodeVOP2::V_SUB_F16, VOperand8::V0, SOperand9::V0, VOperand8::V1).Get()); + auto result = runner->run(spirv, std::array{F16x2{half(0.0f)}, F16x2{half(1.0f)}}); + + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(*result, F16x2{half(-1.0f)}); //confirmed with neo +}