Neo: three-operand min/max/med for 16-bit floats (#4395)

This commit is contained in:
Marcin Mikołajczyk 2026-05-11 23:06:00 +02:00 committed by GitHub
parent 201aaf232b
commit 89b886348e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 46 additions and 3 deletions

View File

@ -2283,7 +2283,9 @@ constexpr std::array<InstFormat, 883> InstructionFormatVOP3 = {
ScalarType::Float16},
{},
{},
{},
// 855 = V_MED3_F16
{InstClass::VectorFpArith16, InstCategory::VectorALU, 3, 1, ScalarType::Float16,
ScalarType::Float16},
{},
{},
{},

View File

@ -709,6 +709,7 @@ enum class OpcodeVOP3 : u32 {
V_ADD_LSHL_U32 = 839,
V_MIN3_F16 = 849,
V_MAX3_F16 = 852,
V_MED3_F16 = 855,
V_MAD_I16 = 862,
V_ADD3_U32 = 877,
V_LSHL_OR_B32 = 879,
@ -2020,6 +2021,7 @@ enum class Opcode : u32 {
V_ADD_LSHL_U32 = 839 + (u32)OpcodeMap::OP_MAP_VOP3,
V_MIN3_F16 = 849 + (u32)OpcodeMap::OP_MAP_VOP3,
V_MAX3_F16 = 852 + (u32)OpcodeMap::OP_MAP_VOP3,
V_MED3_F16 = 855 + (u32)OpcodeMap::OP_MAP_VOP3,
V_MAD_I16 = 862 + (u32)OpcodeMap::OP_MAP_VOP3,
V_ADD3_U32 = 877 + (u32)OpcodeMap::OP_MAP_VOP3,
V_LSHL_OR_B32 = 879 + (u32)OpcodeMap::OP_MAP_VOP3,

View File

@ -283,6 +283,9 @@ public:
void V_LSHLREV_B16(const GcnInst& inst);
void V_LSHL_ADD_U32(const GcnInst& inst);
void V_ADD_LSHL_U32(const GcnInst& inst);
void V_MIN3_F16(const GcnInst& inst);
void V_MAX3_F16(const GcnInst& inst);
void V_MED3_F16(const GcnInst& inst);
void V_ADD3_U32(const GcnInst& inst);
void V_LSHL_OR_B32(const GcnInst& inst);
void V_AND_OR_B32(const GcnInst& inst);

View File

@ -462,12 +462,18 @@ void Translator::EmitVectorAlu(const GcnInst& inst) {
return V_ASHRREV_I16(inst);
case Opcode::V_LSHLREV_B16:
return V_LSHLREV_B16(inst);
case Opcode::V_ADD3_U32:
return V_ADD3_U32(inst);
case Opcode::V_ADD_LSHL_U32:
return V_ADD_LSHL_U32(inst);
case Opcode::V_LSHL_ADD_U32:
return V_LSHL_ADD_U32(inst);
case Opcode::V_MIN3_F16:
return V_MIN3_F16(inst);
case Opcode::V_MAX3_F16:
return V_MAX3_F16(inst);
case Opcode::V_MED3_F16:
return V_MED3_F16(inst);
case Opcode::V_ADD3_U32:
return V_ADD3_U32(inst);
case Opcode::V_LSHL_OR_B32:
return V_LSHL_OR_B32(inst);
case Opcode::V_AND_OR_B32:
@ -1636,6 +1642,36 @@ void Translator::V_LSHL_ADD_U32(const GcnInst& inst) {
SetDst(inst.dst[0], result);
}
// V_MIN3_F16: three-operand minimum on 16-bit float operands.
// Reads the three sources through GetSrc16<IR::F32>, reduces them with
// ir.FPMinTri and stores the outcome in the first destination via SetDst16.
// NOTE(review): operands are requested as IR::F32, so GetSrc16/SetDst16
// presumably perform the half<->float conversion - confirm against those helpers.
void Translator::V_MIN3_F16(const GcnInst& inst) {
    // Fetch each operand in its own statement so the fetch order stays src0, src1, src2.
    const auto a = GetSrc16<IR::F32>(inst.src[0]);
    const auto b = GetSrc16<IR::F32>(inst.src[1]);
    const auto c = GetSrc16<IR::F32>(inst.src[2]);
    SetDst16(inst.dst[0], ir.FPMinTri(a, b, c));
}
// V_MAX3_F16: three-operand maximum on 16-bit float operands.
// Reads the three sources through GetSrc16<IR::F32>, reduces them with
// ir.FPMaxTri and stores the outcome in the first destination via SetDst16.
// NOTE(review): operands are requested as IR::F32, so GetSrc16/SetDst16
// presumably perform the half<->float conversion - confirm against those helpers.
void Translator::V_MAX3_F16(const GcnInst& inst) {
    // Fetch each operand in its own statement so the fetch order stays src0, src1, src2.
    const auto a = GetSrc16<IR::F32>(inst.src[0]);
    const auto b = GetSrc16<IR::F32>(inst.src[1]);
    const auto c = GetSrc16<IR::F32>(inst.src[2]);
    SetDst16(inst.dst[0], ir.FPMaxTri(a, b, c));
}
// V_MED3_F16: three-operand median on 16-bit float operands.
// Reads the three sources through GetSrc16<IR::F32>, reduces them with
// ir.FPMedTri and stores the outcome in the first destination via SetDst16.
// NOTE(review): operands are requested as IR::F32, so GetSrc16/SetDst16
// presumably perform the half<->float conversion - confirm against those helpers.
void Translator::V_MED3_F16(const GcnInst& inst) {
    // Fetch each operand in its own statement so the fetch order stays src0, src1, src2.
    const auto a = GetSrc16<IR::F32>(inst.src[0]);
    const auto b = GetSrc16<IR::F32>(inst.src[1]);
    const auto c = GetSrc16<IR::F32>(inst.src[2]);
    SetDst16(inst.dst[0], ir.FPMedTri(a, b, c));
}
void Translator::V_ADD3_U32(const GcnInst& inst) {
const auto src0 = GetSrc<IR::U32>(inst.src[0]);
const auto src1 = GetSrc<IR::U32>(inst.src[1]);