mirror of
https://github.com/shadps4-emu/shadPS4.git
synced 2026-05-12 16:09:42 -06:00
Neo: Float16 packed math (#4354)
This commit is contained in:
parent
26eaa3e3af
commit
34b35b526e
@ -768,6 +768,134 @@ T Translator::GetSrc64(const InstOperand& operand) {
|
||||
template IR::U64 Translator::GetSrc64<IR::U64>(const InstOperand&);
|
||||
template IR::F64 Translator::GetSrc64<IR::F64>(const InstOperand&);
|
||||
|
||||
template <typename T>
|
||||
pk_type<T> Translator::GetSrcPk(const InstOperand& operand) {
|
||||
constexpr bool is_float = std::is_same_v<T, IR::F32>;
|
||||
|
||||
const auto get_imm = [&](auto value) -> pk_type<T> {
|
||||
if constexpr (is_float) {
|
||||
auto imm = ir.Imm32(std::bit_cast<float>(value));
|
||||
return {operand.op_sel.op_sel ? ir.Imm32(0.f) : imm,
|
||||
operand.op_sel.op_sel_hi ? ir.Imm32(0.f) : imm};
|
||||
} else {
|
||||
auto imm = ir.Imm32(std::bit_cast<u32>(value));
|
||||
return {operand.op_sel.op_sel ? ir.Imm32(0U) : imm,
|
||||
operand.op_sel.op_sel_hi ? ir.Imm32(0U) : imm};
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto number_format = [&]() {
|
||||
if constexpr (is_float) {
|
||||
return AmdGpu::NumberFormat::Float;
|
||||
} else {
|
||||
return AmdGpu::NumberFormat::Uint;
|
||||
}
|
||||
}();
|
||||
|
||||
const auto cast = [&](auto value) -> T {
|
||||
if constexpr (is_float) {
|
||||
return value;
|
||||
} else {
|
||||
return ir.BitCast<IR::U32>(value);
|
||||
}
|
||||
};
|
||||
|
||||
const auto extract = [&](auto value) -> pk_type<T> {
|
||||
auto v_unpacked = ir.Unpack2x16(number_format, value);
|
||||
return {cast(IR::F32{ir.CompositeExtract(v_unpacked, operand.op_sel.op_sel)}),
|
||||
cast(IR::F32{ir.CompositeExtract(v_unpacked, operand.op_sel.op_sel_hi)})};
|
||||
};
|
||||
|
||||
pk_type<T> value{};
|
||||
switch (operand.field) {
|
||||
case OperandField::ScalarGPR: {
|
||||
value = extract(ir.GetScalarReg<IR::U32>(IR::ScalarReg(operand.code)));
|
||||
break;
|
||||
}
|
||||
case OperandField::VectorGPR: {
|
||||
value = extract(ir.GetVectorReg<IR::U32>(IR::VectorReg(operand.code)));
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstZero: {
|
||||
value = get_imm(0U);
|
||||
break;
|
||||
}
|
||||
case OperandField::SignedConstIntPos: {
|
||||
value = get_imm(operand.code - SignedConstIntPosMin + 1);
|
||||
break;
|
||||
}
|
||||
case OperandField::SignedConstIntNeg: {
|
||||
value = get_imm(-s32(operand.code) + SignedConstIntNegMin - 1);
|
||||
break;
|
||||
}
|
||||
case OperandField::LiteralConst: {
|
||||
value = get_imm(operand.code);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatPos_1_0: {
|
||||
value = get_imm(1.f);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatPos_0_5: {
|
||||
value = get_imm(0.5f);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatPos_2_0: {
|
||||
value = get_imm(2.0f);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatPos_4_0: {
|
||||
value = get_imm(4.0f);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatNeg_0_5: {
|
||||
value = get_imm(-0.5f);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatNeg_1_0: {
|
||||
value = get_imm(-1.0f);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatNeg_2_0: {
|
||||
value = get_imm(-2.0f);
|
||||
break;
|
||||
}
|
||||
case OperandField::ConstFloatNeg_4_0: {
|
||||
value = get_imm(-4.0f);
|
||||
break;
|
||||
}
|
||||
case OperandField::Inv2Pi: {
|
||||
value = get_imm(1.0f / (2.0f * std::numbers::pi_v<float>));
|
||||
break;
|
||||
}
|
||||
case OperandField::VccLo:
|
||||
value = extract(ir.GetVccLo());
|
||||
break;
|
||||
default:
|
||||
UNREACHABLE_MSG("unexpected operand: {}", std::to_underlying(operand.field));
|
||||
}
|
||||
|
||||
if constexpr (is_float) {
|
||||
if (operand.input_modifier.neg) {
|
||||
value.first = ir.FPNeg(value.first);
|
||||
}
|
||||
if (operand.input_modifier.neg_hi) {
|
||||
value.second = ir.FPNeg(value.second);
|
||||
}
|
||||
} else {
|
||||
if (operand.input_modifier.neg) {
|
||||
value.first = ir.INeg(value.first);
|
||||
}
|
||||
if (operand.input_modifier.neg_hi) {
|
||||
value.second = ir.INeg(value.second);
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
template pk_type<IR::U32> Translator::GetSrcPk<IR::U32>(const InstOperand&);
|
||||
template pk_type<IR::F32> Translator::GetSrcPk<IR::F32>(const InstOperand&);
|
||||
|
||||
void Translator::SetDst1(const InstOperand& operand, const IR::U1& value) {
|
||||
switch (operand.field) {
|
||||
case OperandField::VccLo:
|
||||
|
||||
@ -54,6 +54,9 @@ enum class NegateMode : u32 {
|
||||
Result,
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using pk_type = std::pair<T, T>;
|
||||
|
||||
static constexpr size_t MaxInterpVgpr = 16;
|
||||
|
||||
class Translator {
|
||||
@ -286,6 +289,11 @@ public:
|
||||
void V_OR3_B32(const GcnInst& inst);
|
||||
|
||||
// VOP3P
|
||||
void V_PK_FMA_F16(const GcnInst& inst);
|
||||
void V_PK_ADD_F16(const GcnInst& inst);
|
||||
void V_PK_MUL_F16(const GcnInst& inst);
|
||||
void V_PK_MIN_F16(const GcnInst& inst);
|
||||
void V_PK_MAX_F16(const GcnInst& inst);
|
||||
void V_MAD_MIX_F32(const GcnInst& inst);
|
||||
void V_MAD_MIXLO_F16(const GcnInst& inst);
|
||||
void V_MAD_MIXHI_F16(const GcnInst& inst);
|
||||
@ -334,6 +342,8 @@ private:
|
||||
template <typename T = IR::U64>
|
||||
[[nodiscard]] T GetSrc64(const InstOperand& operand);
|
||||
[[nodiscard]] IR::F32 GetSrcMix(const InstOperand& operand);
|
||||
template <typename T = IR::U32>
|
||||
[[nodiscard]] pk_type<T> GetSrcPk(const InstOperand& operand);
|
||||
void SetDst1(const InstOperand& operand, const IR::U1& value);
|
||||
void SetDst(const InstOperand& operand, const IR::U32F32& value);
|
||||
template <bool is_signed = false>
|
||||
|
||||
@ -478,6 +478,16 @@ void Translator::EmitVectorAlu(const GcnInst& inst) {
|
||||
return;
|
||||
|
||||
// VOP3P
|
||||
case Opcode::V_PK_FMA_F16:
|
||||
return V_PK_FMA_F16(inst);
|
||||
case Opcode::V_PK_ADD_F16:
|
||||
return V_PK_ADD_F16(inst);
|
||||
case Opcode::V_PK_MUL_F16:
|
||||
return V_PK_MUL_F16(inst);
|
||||
case Opcode::V_PK_MIN_F16:
|
||||
return V_PK_MIN_F16(inst);
|
||||
case Opcode::V_PK_MAX_F16:
|
||||
return V_PK_MAX_F16(inst);
|
||||
case Opcode::V_MAD_MIX_F32:
|
||||
return V_MAD_MIX_F32(inst);
|
||||
case Opcode::V_MAD_MIXLO_F16:
|
||||
@ -1634,6 +1644,62 @@ void Translator::V_ADD3_U32(const GcnInst& inst) {
|
||||
SetDst(inst.dst[0], ir.IAdd(src0, ir.IAdd(src1, src2)));
|
||||
}
|
||||
|
||||
void Translator::V_PK_FMA_F16(const GcnInst& inst) {
|
||||
const auto src0 = GetSrcPk<IR::F32>(inst.src[0]);
|
||||
const auto src1 = GetSrcPk<IR::F32>(inst.src[1]);
|
||||
const auto src2 = GetSrcPk<IR::F32>(inst.src[2]);
|
||||
|
||||
const auto result_lo = ir.FPFma(src0.first, src1.first, src2.first);
|
||||
const auto result_hi = ir.FPFma(src0.second, src1.second, src2.second);
|
||||
|
||||
SetDst(inst.dst[0],
|
||||
ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi)));
|
||||
}
|
||||
|
||||
void Translator::V_PK_ADD_F16(const GcnInst& inst) {
|
||||
const auto src0 = GetSrcPk<IR::F32>(inst.src[0]);
|
||||
const auto src1 = GetSrcPk<IR::F32>(inst.src[1]);
|
||||
|
||||
const auto result_lo = ir.FPAdd(src0.first, src1.first);
|
||||
const auto result_hi = ir.FPAdd(src0.second, src1.second);
|
||||
|
||||
SetDst(inst.dst[0],
|
||||
ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi)));
|
||||
}
|
||||
|
||||
void Translator::V_PK_MUL_F16(const GcnInst& inst) {
|
||||
const auto src0 = GetSrcPk<IR::F32>(inst.src[0]);
|
||||
const auto src1 = GetSrcPk<IR::F32>(inst.src[1]);
|
||||
|
||||
const auto result_lo = ir.FPMul(src0.first, src1.first);
|
||||
const auto result_hi = ir.FPMul(src0.second, src1.second);
|
||||
|
||||
SetDst(inst.dst[0],
|
||||
ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi)));
|
||||
}
|
||||
|
||||
void Translator::V_PK_MIN_F16(const GcnInst& inst) {
|
||||
const auto src0 = GetSrcPk<IR::F32>(inst.src[0]);
|
||||
const auto src1 = GetSrcPk<IR::F32>(inst.src[1]);
|
||||
|
||||
const auto result_lo = ir.FPMin(src0.first, src1.first);
|
||||
const auto result_hi = ir.FPMin(src0.second, src1.second);
|
||||
|
||||
SetDst(inst.dst[0],
|
||||
ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi)));
|
||||
}
|
||||
|
||||
void Translator::V_PK_MAX_F16(const GcnInst& inst) {
|
||||
const auto src0 = GetSrcPk<IR::F32>(inst.src[0]);
|
||||
const auto src1 = GetSrcPk<IR::F32>(inst.src[1]);
|
||||
|
||||
const auto result_lo = ir.FPMax(src0.first, src1.first);
|
||||
const auto result_hi = ir.FPMax(src0.second, src1.second);
|
||||
|
||||
SetDst(inst.dst[0],
|
||||
ir.Pack2x16(AmdGpu::NumberFormat::Float, ir.CompositeConstruct(result_lo, result_hi)));
|
||||
}
|
||||
|
||||
void Translator::V_LSHL_OR_B32(const GcnInst& inst) {
|
||||
const auto src0 = GetSrc<IR::U32>(inst.src[0]);
|
||||
const auto src1 = GetSrc<IR::U32>(inst.src[1]);
|
||||
|
||||
@ -424,3 +424,87 @@ TEST_F(GcnTest, ashrrev_i16_1) {
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, 0x1234F111);
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_1) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).Get());
|
||||
auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, (F16x2{half(4.0f), half(6.0f)}));
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_2) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0, SOperand9::ConstInv2Pi).Get();
|
||||
auto spirv = TranslateToSpirv(inst);
|
||||
auto result = runner->run<u32>(spirv, 0U);
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, 0x00003118);
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_3) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0, SOperand9::ConstInv2Pi).SetOpSel({0,1,1}).Get();
|
||||
auto spirv = TranslateToSpirv(inst);
|
||||
auto result = runner->run<u32>(spirv, 0U);
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, 0);
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_4) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0p5, SOperand9::Const0p5).Get();
|
||||
auto spirv = TranslateToSpirv(inst);
|
||||
auto result = runner->run<u32>(spirv, 0U);
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, 0x3C00);
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_5) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto inst = VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::Const0, SOperand9::ConstInv2Pi).SetOpSelHi({0,0,0}).Get();
|
||||
auto spirv = TranslateToSpirv(inst);
|
||||
auto result = runner->run<u32>(spirv, 0U);
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, 0x31183118);
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_neg_lo) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetNeg({1,1,0}).Get());
|
||||
auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, (F16x2{half(-4.0f), half(6.0f)}));
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_neg_hi) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetNegHi({1,1,0}).Get());
|
||||
auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, (F16x2{half(4.0f), half(-6.0f)}));
|
||||
}
|
||||
|
||||
TEST_F(GcnTest, pk_add_f16_op_sel_reversed) {
|
||||
auto runner = gcn_test::Runner::instance().value();
|
||||
|
||||
auto spirv = TranslateToSpirv(VOP3P(OpcodeVOP3P::V_PK_ADD_F16, VOperand8::V0, SOperand9::V0, SOperand9::V1).SetOpSel({1,1,1}).SetOpSelHi({0,0,0}).Get());
|
||||
auto result = runner->run<F16x2>(spirv, std::array{F16x2{half(1.0f), half(2.0f)}, F16x2{half(3.0f), half(4.0f)}});
|
||||
|
||||
EXPECT_TRUE(result.has_value());
|
||||
EXPECT_EQ(*result, (F16x2{half(6.0f), half(4.0f)}));
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user