Neo: V_*_F16 arithmetic ops (#4311)

This commit is contained in:
Marcin Mikołajczyk 2026-04-25 12:51:02 +02:00 committed by GitHub
parent 2c9eee1eb5
commit 99f2480e21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 524 additions and 9 deletions

View File

@ -1464,16 +1464,30 @@ constexpr std::array<InstFormat, 883> InstructionFormatVOP3 = {
{InstClass::VectorConv, InstCategory::VectorALU, 2, 1, ScalarType::Uint32, ScalarType::Uint32},
// 305 = V_CVT_PK_I16_I32
{InstClass::VectorConv, InstCategory::VectorALU, 2, 1, ScalarType::Sint32, ScalarType::Sint32},
// 306 = V_ADD_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16,
ScalarType::Float16},
// 307 = V_SUB_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16,
ScalarType::Float16},
// 308 = V_SUBREV_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16,
ScalarType::Float16},
// 309 = V_MUL_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16,
ScalarType::Float16},
{},
{},
{},
{},
{},
{},
{},
{},
{},
{},
// 313 = V_MAX_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16,
ScalarType::Float16},
// 314 = V_MIN_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16,
ScalarType::Float16},
// 315 = V_LDEXP_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 2, 1, ScalarType::Float16,
ScalarType::Float16},
{},
{},
{},
@ -1854,8 +1868,12 @@ constexpr std::array<InstFormat, 883> InstructionFormatVOP3 = {
{},
{},
{},
{},
{},
// 84 = V_RCP_F16
{InstClass::VectorFpTran16, InstCategory::VectorALU, 1, 1, ScalarType::Float16,
ScalarType::Float16},
// 85 = V_SQRT_F16
{InstClass::VectorFpTran16, InstCategory::VectorALU, 1, 1, ScalarType::Float16,
ScalarType::Float16},
{},
{},
{},

View File

@ -396,6 +396,150 @@ T Translator::GetSrc(const InstOperand& operand) {
template IR::U32 Translator::GetSrc<IR::U32>(const InstOperand&);
template IR::F32 Translator::GetSrc<IR::F32>(const InstOperand&);
// Fetches one 16-bit half of a source operand and returns it as a 32-bit IR value of type T.
//
// Register operands (scalar/vector GPR, VCC_LO) are split with Unpack2x16 and `op_sel` picks the
// element (0 or 1). For integer requests the selected element is additionally narrowed with
// BitFieldExtract(offset 0, count 16, is_signed). Inline constants and literals are emitted as
// 32-bit immediates. The abs/neg input modifiers are applied last, abs before neg.
template <typename T, bool is_signed>
T Translator::GetSrc16(const InstOperand& operand) {
    constexpr bool is_float = std::is_same_v<T, IR::F32>;
    constexpr AmdGpu::NumberFormat unpack_format =
        is_float ? AmdGpu::NumberFormat::Float : AmdGpu::NumberFormat::Uint;

    const auto op_sel = operand.op_sel.op_sel;
    const int half_index = op_sel ? 1 : 0;

    // Builds an immediate of type T from the raw bit pattern of `raw`.
    const auto make_imm = [&](auto raw) -> T {
        if constexpr (is_float) {
            return ir.Imm32(std::bit_cast<float>(raw));
        } else {
            return ir.Imm32(std::bit_cast<u32>(raw));
        }
    };
    // Produces the U32 view of a register value that Unpack2x16 is fed with.
    const auto as_u32 = [&](auto reg) -> IR::U32 {
        if constexpr (is_float) {
            return ir.BitCast<IR::U32>(reg);
        } else {
            return reg;
        }
    };
    // Float requests pass through; integer requests keep only the low 16 bits.
    const auto narrow = [&](auto element) -> T {
        if constexpr (is_float) {
            return element;
        } else {
            return ir.BitFieldExtract(ir.BitCast<IR::U32>(element), ir.Imm32(0), ir.Imm32(16),
                                      is_signed);
        }
    };
    // Unpacks a 32-bit register value and returns the half chosen by op_sel.
    const auto select_half = [&](auto reg) -> T {
        return narrow(IR::F32{
            ir.CompositeExtract(ir.Unpack2x16(unpack_format, as_u32(reg)), half_index)});
    };

    T value{};
    switch (operand.field) {
    case OperandField::ScalarGPR:
        value = select_half(ir.GetScalarReg<T>(IR::ScalarReg(operand.code)));
        break;
    case OperandField::VectorGPR:
        value = select_half(ir.GetVectorReg<T>(IR::VectorReg(operand.code)));
        break;
    case OperandField::ConstZero:
        value = make_imm(0U);
        break;
    case OperandField::SignedConstIntPos:
        value = make_imm(operand.code - SignedConstIntPosMin + 1);
        break;
    case OperandField::SignedConstIntNeg:
        value = make_imm(-s32(operand.code) + SignedConstIntNegMin - 1);
        break;
    case OperandField::LiteralConst:
        value = make_imm(operand.code);
        break;
    case OperandField::ConstFloatPos_1_0:
        value = make_imm(1.f);
        break;
    case OperandField::ConstFloatPos_0_5:
        value = make_imm(0.5f);
        break;
    case OperandField::ConstFloatPos_2_0:
        value = make_imm(2.0f);
        break;
    case OperandField::ConstFloatPos_4_0:
        value = make_imm(4.0f);
        break;
    case OperandField::ConstFloatNeg_0_5:
        value = make_imm(-0.5f);
        break;
    case OperandField::ConstFloatNeg_1_0:
        value = make_imm(-1.0f);
        break;
    case OperandField::ConstFloatNeg_2_0:
        value = make_imm(-2.0f);
        break;
    case OperandField::ConstFloatNeg_4_0:
        value = make_imm(-4.0f);
        break;
    case OperandField::Inv2Pi:
        value = make_imm(static_cast<float>(1.0f / (2.0f * std::numbers::pi)));
        break;
    case OperandField::Sdwa:
        // Not implemented: log and fall back to a zero immediate.
        LOG_ERROR(Render_Recompiler, "unhandled SDWA");
        value = make_imm(0U);
        break;
    case OperandField::Dpp:
        // Not implemented: log and fall back to a zero immediate.
        LOG_ERROR(Render_Recompiler, "unhandled DPP");
        value = make_imm(0U);
        break;
    case OperandField::VccLo:
        if constexpr (is_float) {
            // VCC_LO is fed to Unpack2x16 directly, so no bitcast is emitted.
            value = IR::F32{
                ir.CompositeExtract(ir.Unpack2x16(unpack_format, ir.GetVccLo()), half_index)};
        } else {
            value = select_half(ir.GetVccLo());
        }
        break;
    case OperandField::VccHi:
    case OperandField::M0:
    case OperandField::Scc:
        UNREACHABLE();
        break;
    default:
        UNREACHABLE_MSG("unexpected operand: {}", std::to_underlying(operand.field));
    }

    if (operand.input_modifier.abs) {
        if constexpr (is_float) {
            value = ir.FPAbs(value);
        } else {
            value = ir.BitwiseAnd(value, ir.Imm32(0x7FFFFFFFu));
        }
    }
    if (operand.input_modifier.neg) {
        if constexpr (is_float) {
            value = ir.FPNeg(value);
        } else {
            value = ir.BitwiseXor(value, ir.Imm32(0x80000000u));
        }
    }
    return value;
}
template IR::U32 Translator::GetSrc16<IR::U32, false>(const InstOperand&);
template IR::U32 Translator::GetSrc16<IR::U32, true>(const InstOperand&);
template IR::F32 Translator::GetSrc16<IR::F32, false>(const InstOperand&);
template <typename T>
T Translator::GetSrc64(const InstOperand& operand) {
constexpr bool is_float = std::is_same_v<T, IR::F64>;
@ -552,6 +696,67 @@ void Translator::SetDst(const InstOperand& operand, const IR::U32F32& value) {
}
}
// Writes a 32-bit IR value into one 16-bit half of a scalar or vector destination register.
//
// F32 inputs get the output multiplier (when non-zero) and the saturate clamp, then are converted
// to F16. U32 inputs are optionally clamped to the signed or unsigned 16-bit range. The resulting
// 16 bits are inserted at bit 0 or bit 16 of the previous register contents, chosen by op_sel.
// Any destination other than a scalar or vector GPR is treated as unreachable.
template <bool is_signed>
void Translator::SetDst16(const InstOperand& operand, const IR::U32F32& value) {
    IR::U32F32 adjusted = value;
    if (value.Type() == IR::Type::F32) {
        if (operand.output_modifier.multiplier != 0.f) {
            adjusted = ir.FPMul(adjusted, ir.Imm32(operand.output_modifier.multiplier));
        }
        if (operand.output_modifier.clamp) {
            adjusted = ir.FPSaturate(adjusted);
        }
    } else if (operand.output_modifier.clamp) {
        if constexpr (is_signed) {
            adjusted = ir.SClamp(adjusted, ir.Imm32(-32768), ir.Imm32(32767));
        } else {
            adjusted = ir.UMin(adjusted, ir.Imm32(0xFFFF));
        }
    }

    // Yields the U32 whose low 16 bits are the payload to store.
    const auto to_low16 = [&](auto v) -> IR::U32 {
        if (v.Type() == IR::Type::F32) {
            return ir.UConvert(32, ir.BitCast<IR::U16>(IR::F16{ir.FPConvert(16, v)}));
        }
        if (v.Type() == IR::Type::U32) {
            return v;
        }
        UNREACHABLE();
    };

    const auto op_sel = operand.op_sel.op_sel;
    // Converts the payload and splices it into the previously read register contents.
    const auto merge_into = [&](const auto& previous) {
        const auto payload = to_low16(adjusted);
        return ir.BitFieldInsert(previous, payload, ir.Imm32(op_sel ? 16 : 0), ir.Imm32(16));
    };

    switch (operand.field) {
    case OperandField::ScalarGPR: {
        const auto previous = ir.GetScalarReg<IR::U32>(IR::ScalarReg(operand.code));
        ir.SetScalarReg(IR::ScalarReg(operand.code), merge_into(previous));
        return;
    }
    case OperandField::VectorGPR: {
        const auto previous = ir.GetVectorReg<IR::U32>(IR::VectorReg(operand.code));
        ir.SetVectorReg(IR::VectorReg(operand.code), merge_into(previous));
        return;
    }
    default:
        // VccLo, VccHi, M0 and every other field are not valid 16-bit destinations here.
        UNREACHABLE();
    }
}
template void Translator::SetDst16<false>(const InstOperand&, const IR::U32F32& value);
template void Translator::SetDst16<true>(const InstOperand&, const IR::U32F32& value);
void Translator::SetDst64(const InstOperand& operand, const IR::U64F64& value_raw) {
IR::U64F64 value_untyped = value_raw;

View File

@ -183,6 +183,11 @@ public:
void V_CVT_PKNORM_U16_F32(const GcnInst& inst);
void V_CVT_PKNORM_I16_F32(const GcnInst& inst);
void V_CVT_PKRTZ_F16_F32(const GcnInst& inst);
void V_ADD_F16(const GcnInst& inst);
void V_SUB_F16(const GcnInst& inst);
void V_MUL_F16(const GcnInst& inst);
void V_MAX_F16(const GcnInst& inst);
void V_MIN_F16(const GcnInst& inst);
// VOP1
void V_MOV(const GcnInst& inst);
@ -310,10 +315,14 @@ private:
IR::U1 GetSrc1(const InstOperand& operand);
template <typename T = IR::U32>
[[nodiscard]] T GetSrc(const InstOperand& operand);
template <typename T = IR::U32, bool is_signed = false>
[[nodiscard]] T GetSrc16(const InstOperand& operand);
template <typename T = IR::U64>
[[nodiscard]] T GetSrc64(const InstOperand& operand);
void SetDst1(const InstOperand& operand, const IR::U1& value);
void SetDst(const InstOperand& operand, const IR::U32F32& value);
template <bool is_signed = false>
void SetDst16(const InstOperand& operand, const IR::U32F32& value);
void SetDst64(const InstOperand& operand, const IR::U64F64& value_raw);
// Vector ALU Helpers

View File

@ -100,6 +100,16 @@ void Translator::EmitVectorAlu(const GcnInst& inst) {
return V_CVT_PKNORM_I16_F32(inst);
case Opcode::V_CVT_PKRTZ_F16_F32:
return V_CVT_PKRTZ_F16_F32(inst);
case Opcode::V_ADD_F16:
return V_ADD_F16(inst);
case Opcode::V_SUB_F16:
return V_SUB_F16(inst);
case Opcode::V_MUL_F16:
return V_MUL_F16(inst);
case Opcode::V_MAX_F16:
return V_MAX_F16(inst);
case Opcode::V_MIN_F16:
return V_MIN_F16(inst);
// VOP1
case Opcode::V_MOV_B32:
@ -746,6 +756,51 @@ void Translator::V_CVT_PKRTZ_F16_F32(const GcnInst& inst) {
SetDst(inst.dst[0], ir.Pack2x16(AmdGpu::NumberFormat::Float, vec_f32));
}
void Translator::V_ADD_F16(const GcnInst& inst) {
const auto src0 = GetSrc16<IR::F32>(inst.src[0]);
const auto src1 = GetSrc16<IR::F32>(inst.src[1]);
const auto result = ir.FPAdd(src0, src1);
SetDst16(inst.dst[0], result);
}
void Translator::V_SUB_F16(const GcnInst& inst) {
const auto src0 = GetSrc16<IR::F32>(inst.src[0]);
const auto src1 = GetSrc16<IR::F32>(inst.src[1]);
const auto result = ir.FPSub(src0, src1);
SetDst16(inst.dst[0], result);
}
void Translator::V_MUL_F16(const GcnInst& inst) {
const auto src0 = GetSrc16<IR::F32>(inst.src[0]);
const auto src1 = GetSrc16<IR::F32>(inst.src[1]);
const auto result = ir.FPMul(src0, src1);
SetDst16(inst.dst[0], result);
}
void Translator::V_MAX_F16(const GcnInst& inst) {
const auto src0 = GetSrc16<IR::F32>(inst.src[0]);
const auto src1 = GetSrc16<IR::F32>(inst.src[1]);
const auto result = ir.FPMax(src0, src1);
SetDst16(inst.dst[0], result);
}
void Translator::V_MIN_F16(const GcnInst& inst) {
const auto src0 = GetSrc16<IR::F32>(inst.src[0]);
const auto src1 = GetSrc16<IR::F32>(inst.src[1]);
const auto result = ir.FPMin(src0, src1);
SetDst16(inst.dst[0], result);
}
// VOP1
void Translator::V_MOV(const GcnInst& inst) {

View File

@ -156,6 +156,7 @@ target_link_libraries(shadps4_gcn_test PRIVATE
sirit
SDL3::SDL3
spdlog::spdlog
half::half
)
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR

View File

@ -3,12 +3,15 @@
#include "shader_recompiler/frontend/opcodes.h"
#include <ranges>
using OpcodeSOP1 = Shader::Gcn::OpcodeSOP1;
using OpcodeSOP2 = Shader::Gcn::OpcodeSOP2;
using OpcodeSOPK = Shader::Gcn::OpcodeSOPK;
using OpcodeVOP1 = Shader::Gcn::OpcodeVOP1;
using OpcodeVOP2 = Shader::Gcn::OpcodeVOP2;
using OpcodeVOP3 = Shader::Gcn::OpcodeVOP3;
using OpcodeVOP3P = Shader::Gcn::OpcodeVOP3P;
enum class VOperand8 : u8 {
V0 = 0,
@ -1185,3 +1188,164 @@ private:
static_assert(sizeof(VOP2Internal) == sizeof(u32));
};
// Raw values for a 2-bit output-modifier field.
// NOTE(review): the enumerator names suggest x2 / x4 / /2 result scaling; confirm against the ISA
// documentation.
enum class Omod : u8 {
    None = 0b00,
    Mul2 = 0b01,
    Mul4 = 0b10,
    Div2 = 0b11,
};
class VOP3A {
public:
explicit constexpr VOP3A(OpcodeVOP3 op, VOperand8 vdst, SOperand9 src0, SOperand9 src1, SOperand9 src2 = SOperand9::S0) {
i.src0 = std::to_underlying(src0);
i.src1 = std::to_underlying(src1);
i.src2 = std::to_underlying(src2);
i.vdst = std::to_underlying(vdst);
i.op = std::to_underlying(op) & 0x1FF;
i.op_msb = std::to_underlying(op) >> 9;
i.encoding = 0b110100;
}
VOP3A& SetAbs(const std::array<bool, 3>& abs) {
u8 a = 0;
for (auto x : abs | std::views::reverse) {
a |= x;
a <<= 1;
}
i.abs = a >> 1;
return *this;
}
VOP3A& SetClamp(bool clamp) {
i.clmp = clamp;
return *this;
}
VOP3A& SetNeg(const std::array<bool, 3>& neg) {
u8 n = 0;
for (auto x : neg | std::views::reverse) {
n |= x;
n <<= 1;
}
i.neg = n >> 1;
return *this;
}
VOP3A& SetOmod(Omod omod) {
i.omod = std::to_underlying(omod);
return *this;
}
VOP3A& SetOpSel(const std::array<bool, 4>& op_sel) {
u8 o = 0;
for (auto x : op_sel | std::views::reverse) {
o |= x;
o <<= 1;
}
i.op_sel = o >> 1;
return *this;
}
u64 Get() {
return std::bit_cast<u64>(i);
}
private:
struct VOP3Internal {
u64 vdst : 8;
u64 abs : 3;
u64 clmp : 1;
u64 op_sel : 4;
u64 op_msb : 1;
u64 op : 9;
u64 encoding : 6;
u64 src0 : 9;
u64 src1 : 9;
u64 src2 : 9;
u64 omod : 2;
u64 neg : 3;
} i;
static_assert(sizeof(VOP3Internal) == sizeof(u64));
};
class VOP3P {
public:
explicit constexpr VOP3P(OpcodeVOP3P op, VOperand8 vdst, SOperand9 src0, SOperand9 src1, SOperand9 src2 = SOperand9::S0) {
i.src0 = std::to_underlying(src0);
i.src1 = std::to_underlying(src1);
i.src2 = std::to_underlying(src2);
i.vdst = std::to_underlying(vdst);
i.op = std::to_underlying(op);
// enable op_sel_hi as a sensible default, so both 16-bit chunks are processed
i.op_sel_hi01 = 0b11;
i.op_sel_hi2 = 0b1;
i.encoding = 0b110011;
}
VOP3P& SetClamp(bool clamp) {
i.clmp = clamp;
return *this;
}
VOP3P& SetNeg(const std::array<bool, 3>& neg) {
u8 n = 0;
for (auto x : neg | std::views::reverse) {
n |= x;
n <<= 1;
}
i.neg = n >> 1;
return *this;
}
VOP3P& SetNegHi(const std::array<bool, 3>& neg) {
u8 n = 0;
for (auto x : neg | std::views::reverse) {
n |= x;
n <<= 1;
}
i.neg_hi = n >> 1;
return *this;
}
VOP3P& SetOpSel(const std::array<bool, 3>& op_sel) {
u8 o = 0;
for (auto x : op_sel | std::views::reverse) {
o |= x;
o <<= 1;
}
i.op_sel = o >> 1;
return *this;
}
VOP3P& SetOpSelHi(const std::array<bool, 3>& op_sel) {
i.op_sel_hi2 = op_sel[2];
i.op_sel_hi01 = (op_sel[1] << 1) | op_sel[0];
return *this;
}
u64 Get() {
return std::bit_cast<u64>(i);
}
private:
struct VOP3PInternal {
u64 vdst : 8;
u64 neg_hi : 3;
u64 op_sel : 3;
u64 op_sel_hi2 : 1;
u64 clmp : 1;
u64 op : 7;
u64 : 3;
u64 encoding : 6;
u64 src0 : 9;
u64 src1 : 9;
u64 src2 : 9;
u64 op_sel_hi01 : 2;
u64 neg : 3;
} i;
static_assert(sizeof(VOP3PInternal) == sizeof(u64));
};

View File

@ -4,6 +4,7 @@
#include <cmath>
#include <gtest/gtest.h>
#include <half.hpp>
#include "gcn_test_runner.hpp"
#include "instructions.hpp"
@ -58,3 +59,65 @@ TEST_F(GcnTest, add_nan) {
EXPECT_TRUE(result.has_value());
EXPECT_TRUE(std::isnan(*result));
}
// Shorthand for the half-precision type provided by <half.hpp>.
using half = half_float::half;
// Pair of half-precision values used as the per-register input/output type of the F16 tests.
// The static_assert below pins its size to that of a float.
struct F16x2 {
// First element; must be supplied when brace-initializing.
half a;
// Second element; defaults to 0.0 so tests that only care about `a` can omit it.
half b = half(0.0f);
// Member-wise equality, used by EXPECT_EQ on whole results.
bool operator==(const F16x2& rhs) const = default;
};
static_assert(sizeof(F16x2) == sizeof(float));
// V_ADD_F16 in VOP2 encoding: 1.0 + 1.0 in element `a` must give 2.0, with `b` left at 0.0.
TEST_F(GcnTest, add_f16) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP2(OpcodeVOP2::V_ADD_F16, VOperand8::V0, SOperand9::V0, VOperand8::V1).Get());
    auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(1.0f)}, F16x2{half(1.0f)}});
    // Fatal assertion: the dereference below must not run when there is no value.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, F16x2{half(2.0f)});
}
// V_ADD_F16 in VOP3A encoding with the clamp bit set: 1.0 + 1.0 is expected to come back as 1.0.
TEST_F(GcnTest, add_f16_clamp) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP3A(OpcodeVOP3::V_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1)
            .SetClamp(true)
            .Get());
    auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(1.0f)}, F16x2{half(1.0f)}});
    // Fatal assertion: the dereference below must not run when there is no value.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, F16x2{half(1.0f)}); // confirmed with neo
}
// V_ADD_F16 in VOP3A encoding with neg set on both sources: (-1.0) + (-1.0) must give -2.0 in `a`.
TEST_F(GcnTest, add_f16_neg) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP3A(OpcodeVOP3::V_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1)
            .SetNeg({true, true, false})
            .Get());
    auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(1.0f)}, F16x2{half(1.0f)}});
    // Fatal assertion: the dereference below must not run when there is no value.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ((*result).a, half(-2.0f)); // confirmed with neo
}
// V_ADD_F16 in VOP3A encoding with op_sel set for src0, src1 and the destination: element `b` of
// both inputs is added (2.0 + 2.0) and written to `b`, while `a` must keep its input value 1.0.
TEST_F(GcnTest, add_f16_opsel_hi) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP3A(OpcodeVOP3::V_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1)
            .SetOpSel({true, true, false, true})
            .Get());
    auto result = runner->run<F16x2>(
        spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(1.0f), half(2.0f)}});
    // Fatal assertion: the dereferences below must not run when there is no value.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ((*result).a, half(1.0f));
    EXPECT_EQ((*result).b, half(4.0f));
}
// V_SUB_F16 in VOP2 encoding: 0.0 - 1.0 in element `a` must give -1.0, with `b` left at 0.0.
TEST_F(GcnTest, sub_f16) {
    auto runner = gcn_test::Runner::instance().value();
    auto spirv = TranslateToSpirv(
        VOP2(OpcodeVOP2::V_SUB_F16, VOperand8::V0, SOperand9::V0, VOperand8::V1).Get());
    auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(0.0f)}, F16x2{half(1.0f)}});
    // Fatal assertion: the dereference below must not run when there is no value.
    ASSERT_TRUE(result.has_value());
    EXPECT_EQ(*result, F16x2{half(-1.0f)}); // confirmed with neo
}