frontend: Add helper methods for thread bit getters and setters (#4243)

Co-authored-by: georgemoralis <giorgosmrls@gmail.com>
This commit is contained in:
TheTurtle 2026-04-09 23:32:21 +03:00 committed by GitHub
parent 0d3b6f7dd0
commit 1f50aa3172
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 63 additions and 243 deletions

View File

@ -311,33 +311,10 @@ void Translator::S_CSELECT_B32(const GcnInst& inst) {
}
// Translates S_CSELECT_B64 on 1-bit thread-mask values: reads src[0] and src[1] as
// IR::U1, combines them with ir.Select keyed on SCC, and stores the result in dst[0].
// NOTE(review): this span looks like a rendered diff hunk without +/- markers — the
// local get_src lambda and the inline destination switch appear to be the lines being
// replaced by the GetSrc1/SetDst1 calls that follow them; confirm against the commit.
void Translator::S_CSELECT_B64(const GcnInst& inst) {
// Local operand reader: accepts VCC, EXEC, a thread-bit SGPR or constant zero.
const auto get_src = [&](const InstOperand& operand) {
switch (operand.field) {
case OperandField::VccLo:
return ir.GetVcc();
case OperandField::ExecLo:
return ir.GetExec();
case OperandField::ScalarGPR:
return ir.GetThreadBitScalarReg(IR::ScalarReg(operand.code));
case OperandField::ConstZero:
return ir.Imm1(false);
default:
UNREACHABLE();
}
};
const IR::U1 src0{get_src(inst.src[0])};
const IR::U1 src1{get_src(inst.src[1])};
// Same two reads expressed through the shared Translator::GetSrc1 helper.
const IR::U1 src0{GetSrc1(inst.src[0])};
const IR::U1 src1{GetSrc1(inst.src[1])};
// SCC is the selector between the two sources.
const IR::U1 result{ir.Select(ir.GetScc(), src0, src1)};
// Inline destination write: only VCC and thread-bit SGPR destinations are handled here.
switch (inst.dst[0].field) {
case OperandField::VccLo:
ir.SetVcc(result);
break;
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[0].code), result);
break;
default:
UNREACHABLE();
}
// Same write expressed through the shared Translator::SetDst1 helper.
SetDst1(inst.dst[0], result);
}
void Translator::S_AND_B32(NegateMode negate, const GcnInst& inst) {
@ -355,30 +332,8 @@ void Translator::S_AND_B32(NegateMode negate, const GcnInst& inst) {
}
void Translator::S_AND_B64(NegateMode negate, const GcnInst& inst) {
const auto get_src = [&](const InstOperand& operand) {
switch (operand.field) {
case OperandField::VccLo:
return ir.GetVcc();
case OperandField::ExecLo:
return ir.GetExec();
case OperandField::ScalarGPR:
return ir.GetThreadBitScalarReg(IR::ScalarReg(operand.code));
case OperandField::ConstZero:
return ir.Imm1(false);
case OperandField::SignedConstIntNeg:
ASSERT_MSG(-s32(operand.code) + SignedConstIntNegMin - 1 == -1,
"SignedConstIntNeg must be -1");
return ir.Imm1(true);
case OperandField::LiteralConst:
ASSERT_MSG(operand.code == 0 || operand.code == std::numeric_limits<u32>::max(),
"Unsupported literal {:#x}", operand.code);
return ir.Imm1(operand.code & 1);
default:
UNREACHABLE();
}
};
const IR::U1 src0{get_src(inst.src[0])};
IR::U1 src1{get_src(inst.src[1])};
const IR::U1 src0{GetSrc1(inst.src[0])};
IR::U1 src1{GetSrc1(inst.src[1])};
if (negate == NegateMode::Src1) {
src1 = ir.LogicalNot(src1);
}
@ -387,19 +342,7 @@ void Translator::S_AND_B64(NegateMode negate, const GcnInst& inst) {
result = ir.LogicalNot(result);
}
ir.SetScc(result);
switch (inst.dst[0].field) {
case OperandField::VccLo:
ir.SetVcc(result);
break;
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[0].code), result);
break;
case OperandField::ExecLo:
ir.SetExec(result);
break;
default:
UNREACHABLE();
}
SetDst1(inst.dst[0], result);
}
void Translator::S_OR_B32(const GcnInst& inst) {
@ -411,21 +354,8 @@ void Translator::S_OR_B32(const GcnInst& inst) {
}
void Translator::S_OR_B64(NegateMode negate, bool is_xor, const GcnInst& inst) {
const auto get_src = [&](const InstOperand& operand) {
switch (operand.field) {
case OperandField::ExecLo:
return ir.GetExec();
case OperandField::VccLo:
return ir.GetVcc();
case OperandField::ScalarGPR:
return ir.GetThreadBitScalarReg(IR::ScalarReg(operand.code));
default:
UNREACHABLE();
}
};
const IR::U1 src0{get_src(inst.src[0])};
IR::U1 src1{get_src(inst.src[1])};
const IR::U1 src0{GetSrc1(inst.src[0])};
IR::U1 src1{GetSrc1(inst.src[1])};
if (negate == NegateMode::Src1) {
src1 = ir.LogicalNot(src1);
}
@ -434,16 +364,7 @@ void Translator::S_OR_B64(NegateMode negate, bool is_xor, const GcnInst& inst) {
result = ir.LogicalNot(result);
}
ir.SetScc(result);
switch (inst.dst[0].field) {
case OperandField::VccLo:
ir.SetVcc(result);
break;
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[0].code), result);
break;
default:
UNREACHABLE();
}
SetDst1(inst.dst[0], result);
}
void Translator::S_XOR_B32(const GcnInst& inst) {
@ -604,66 +525,14 @@ void Translator::S_MOV_B64(const GcnInst& inst) {
ir.SetScalarReg(IR::ScalarReg(inst.dst[0].code + 1),
ir.GetScalarReg(IR::ScalarReg(inst.src[0].code + 1)));
}
const IR::U1 src = [&] {
switch (inst.src[0].field) {
case OperandField::VccLo:
return ir.GetVcc();
case OperandField::ExecLo:
return ir.GetExec();
case OperandField::ScalarGPR:
return ir.GetThreadBitScalarReg(IR::ScalarReg(inst.src[0].code));
case OperandField::ConstZero:
return ir.Imm1(false);
default:
UNREACHABLE();
}
}();
switch (inst.dst[0].field) {
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[0].code), src);
break;
case OperandField::ExecLo:
ir.SetExec(src);
break;
case OperandField::VccLo:
ir.SetVcc(src);
break;
default:
UNREACHABLE();
}
SetDst1(inst.dst[0], GetSrc1(inst.src[0]));
}
// Translates S_NOT_B64 on a 1-bit thread-mask value: reads src[0] as IR::U1, applies
// ir.LogicalNot, copies the result into SCC and stores it in dst[0].
// NOTE(review): this span looks like a rendered diff hunk without +/- markers — the
// local get_src lambda and the inline destination switch appear to be the lines being
// replaced by the GetSrc1/SetDst1 calls that follow them; confirm against the commit.
void Translator::S_NOT_B64(const GcnInst& inst) {
// Local operand reader: accepts VCC, EXEC, a thread-bit SGPR or constant zero.
const auto get_src = [&](const InstOperand& operand) {
switch (operand.field) {
case OperandField::VccLo:
return ir.GetVcc();
case OperandField::ExecLo:
return ir.GetExec();
case OperandField::ScalarGPR:
return ir.GetThreadBitScalarReg(IR::ScalarReg(operand.code));
case OperandField::ConstZero:
return ir.Imm1(false);
default:
UNREACHABLE();
}
};
const IR::U1 src0{get_src(inst.src[0])};
// Same read expressed through the shared Translator::GetSrc1 helper.
const IR::U1 src0{GetSrc1(inst.src[0])};
const IR::U1 result = ir.LogicalNot(src0);
// SCC receives the same 1-bit value that is written to the destination.
ir.SetScc(result);
// Inline destination write: VCC, thread-bit SGPR and EXEC destinations are handled here.
switch (inst.dst[0].field) {
case OperandField::VccLo:
ir.SetVcc(result);
break;
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[0].code), result);
break;
case OperandField::ExecLo:
ir.SetExec(result);
break;
default:
UNREACHABLE();
}
// Same write expressed through the shared Translator::SetDst1 helper.
SetDst1(inst.dst[0], result);
}
void Translator::S_BREV_B32(const GcnInst& inst) {
@ -689,21 +558,7 @@ void Translator::S_FF1_I32_B32(const GcnInst& inst) {
}
// Translates S_FF1_I32_B64: reads src[0] as a 1-bit thread-mask value, passes it through
// ir.Ballot and ir.BallotFindLsb, and stores the resulting value in dst[0] via SetDst.
// NOTE(review): this span looks like a rendered diff hunk without +/- markers — the
// immediately-invoked lambda and the two statements after it appear to be the lines
// being replaced by the final single-statement form; confirm against the commit.
void Translator::S_FF1_I32_B64(const GcnInst& inst) {
// Inline operand read: accepts a thread-bit SGPR, VCC or EXEC and reports the
// offending field name for anything else.
const auto src = [&] {
switch (inst.src[0].field) {
case OperandField::ScalarGPR:
return ir.GetThreadBitScalarReg(IR::ScalarReg(inst.src[0].code));
case OperandField::VccLo:
return ir.GetVcc();
case OperandField::ExecLo:
return ir.GetExec();
default:
UNREACHABLE_MSG("unhandled operand type {}", magic_enum::enum_name(inst.src[0].field));
}
}();
const IR::U32 result{ir.BallotFindLsb(ir.Ballot(src))};
SetDst(inst.dst[0], result);
// Same computation with the operand read delegated to Translator::GetSrc1.
SetDst(inst.dst[0], ir.BallotFindLsb(ir.Ballot(GetSrc1(inst.src[0]))));
}
void Translator::S_FLBIT_I32_B32(const GcnInst& inst) {
@ -740,29 +595,8 @@ void Translator::S_SAVEEXEC_B64(NegateMode negate, bool is_or, const GcnInst& in
// However here we flatten it to 1-bit EXEC and 1-bit VCC. For the destination
// SGPR we have a special IR opcode for SGPRs that act as thread masks.
IR::U1 exec{ir.GetExec()};
const IR::U1 src = [&] {
switch (inst.src[0].field) {
case OperandField::VccLo:
return ir.GetVcc();
case OperandField::ScalarGPR:
return ir.GetThreadBitScalarReg(IR::ScalarReg(inst.src[0].code));
case OperandField::ExecLo:
return ir.GetExec();
default:
UNREACHABLE();
}
}();
switch (inst.dst[0].field) {
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[0].code), exec);
break;
case OperandField::VccLo:
ir.SetVcc(exec);
break;
default:
UNREACHABLE();
}
const IR::U1 src{GetSrc1(inst.src[0])};
SetDst1(inst.dst[0], exec);
// Update EXEC.
if (negate == NegateMode::Src1) {

View File

@ -255,6 +255,29 @@ IR::VectorReg Translator::GetScratchVgpr(u32 offset) {
return it->second;
};
/// Reads an instruction operand as a 1-bit (IR::U1) value.
///
/// Handled operand fields:
///  - VccLo / ExecLo: the current VCC / EXEC value from the IR emitter.
///  - ScalarGPR: the thread-bit view of the scalar register named by `operand.code`.
///  - ConstZero: immediate false.
///  - SignedConstIntNeg: immediate true; asserts that the encoded constant is -1.
///  - LiteralConst: asserts the literal is 0 or all-ones, then uses its lowest bit.
/// Any other field hits UNREACHABLE_MSG with the numeric field value.
IR::U1 Translator::GetSrc1(const InstOperand& operand) {
    const auto field = operand.field;

    if (field == OperandField::VccLo) {
        return ir.GetVcc();
    }
    if (field == OperandField::ExecLo) {
        return ir.GetExec();
    }
    if (field == OperandField::ScalarGPR) {
        return ir.GetThreadBitScalarReg(IR::ScalarReg(operand.code));
    }
    if (field == OperandField::ConstZero) {
        return ir.Imm1(false);
    }
    if (field == OperandField::SignedConstIntNeg) {
        // Only the encoding that decodes to -1 is accepted as a 1-bit source.
        ASSERT_MSG(-s32(operand.code) + SignedConstIntNegMin - 1 == -1,
                   "SignedConstIntNeg must be -1");
        return ir.Imm1(true);
    }
    if (field == OperandField::LiteralConst) {
        // A literal is only accepted when all of its bits agree (0 or 0xFFFFFFFF).
        ASSERT_MSG(operand.code == 0 || operand.code == std::numeric_limits<u32>::max(),
                   "Unsupported literal {:#x}", operand.code);
        return ir.Imm1(operand.code & 1);
    }

    UNREACHABLE_MSG("Unknown field {}", u32(operand.field));
}
template <typename T>
T Translator::GetSrc(const InstOperand& operand) {
constexpr bool is_float = std::is_same_v<T, IR::F32>;
@ -477,6 +500,22 @@ T Translator::GetSrc64(const InstOperand& operand) {
template IR::U64 Translator::GetSrc64<IR::U64>(const InstOperand&);
template IR::F64 Translator::GetSrc64<IR::F64>(const InstOperand&);
/// Writes a 1-bit (IR::U1) value to the destination described by `operand`.
///
/// A ScalarGPR destination is written through the thread-bit scalar register named by
/// `operand.code`; VccLo and ExecLo destinations update VCC and EXEC respectively.
/// Any other field hits UNREACHABLE_MSG with the numeric field value.
void Translator::SetDst1(const InstOperand& operand, const IR::U1& value) {
    const auto field = operand.field;

    if (field == OperandField::ScalarGPR) {
        ir.SetThreadBitScalarReg(IR::ScalarReg(operand.code), value);
    } else if (field == OperandField::VccLo) {
        ir.SetVcc(value);
    } else if (field == OperandField::ExecLo) {
        ir.SetExec(value);
    } else {
        UNREACHABLE_MSG("Unknown field {}", u32(operand.field));
    }
}
void Translator::SetDst(const InstOperand& operand, const IR::U32F32& value) {
IR::U32F32 result = value;
if (value.Type() == IR::Type::F32) {

View File

@ -307,10 +307,12 @@ public:
void IMAGE_GET_LOD(const GcnInst& inst);
private:
/// Reads `operand` as a 1-bit (IR::U1) value. Marked [[nodiscard]] like the sibling
/// GetSrc/GetSrc64 getters, since a discarded result means the read was pointless.
[[nodiscard]] IR::U1 GetSrc1(const InstOperand& operand);
template <typename T = IR::U32>
[[nodiscard]] T GetSrc(const InstOperand& operand);
template <typename T = IR::U64>
[[nodiscard]] T GetSrc64(const InstOperand& operand);
void SetDst1(const InstOperand& operand, const IR::U1& value);
void SetDst(const InstOperand& operand, const IR::U32F32& value);
void SetDst64(const InstOperand& operand, const IR::U64F64& value_raw);

View File

@ -1056,17 +1056,7 @@ void Translator::V_CMP_F32(ConditionOp op, bool set_exec, const GcnInst& inst) {
if (set_exec) {
ir.SetExec(result);
}
switch (inst.dst[1].field) {
case OperandField::VccLo:
ir.SetVcc(result);
break;
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[1].code), result);
break;
default:
UNREACHABLE();
}
SetDst1(inst.dst[1], result);
}
void Translator::V_CMP_F64(ConditionOp op, bool set_exec, const GcnInst& inst) {
@ -1097,17 +1087,7 @@ void Translator::V_CMP_F64(ConditionOp op, bool set_exec, const GcnInst& inst) {
if (set_exec) {
ir.SetExec(result);
}
switch (inst.dst[1].field) {
case OperandField::VccLo:
ir.SetVcc(result);
break;
case OperandField::ScalarGPR:
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[1].code), result);
break;
default:
UNREACHABLE();
}
SetDst1(inst.dst[1], result);
}
void Translator::V_CMP_U32(ConditionOp op, bool is_signed, bool set_exec, const GcnInst& inst) {
@ -1138,14 +1118,7 @@ void Translator::V_CMP_U32(ConditionOp op, bool is_signed, bool set_exec, const
if (set_exec) {
ir.SetExec(result);
}
switch (inst.dst[1].field) {
case OperandField::VccLo:
return ir.SetVcc(result);
case OperandField::ScalarGPR:
return ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[0].code), result);
default:
UNREACHABLE();
}
SetDst1(inst.dst[1], result);
}
void Translator::V_CMP_U64(ConditionOp op, bool is_signed, bool set_exec, const GcnInst& inst) {
@ -1187,15 +1160,7 @@ void Translator::V_CMP_U64(ConditionOp op, bool is_signed, bool set_exec, const
if (set_exec) {
UNREACHABLE_MSG("Exec setting for V_CMP_U64 is not supported");
}
switch (inst.dst[1].field) {
case OperandField::VccLo:
return ir.SetVcc(result);
case OperandField::ScalarGPR:
return ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[1].code), result);
default:
UNREACHABLE();
}
SetDst1(inst.dst[1], result);
}
void Translator::V_CMP_CLASS_F32(const GcnInst& inst) {
@ -1217,15 +1182,7 @@ void Translator::V_CMP_CLASS_F32(const GcnInst& inst) {
// We don't know the type yet, delay its resolution.
value = ir.FPCmpClass32(src0, src1);
}
switch (inst.dst[1].field) {
case OperandField::VccLo:
return ir.SetVcc(value);
case OperandField::ScalarGPR:
return ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[1].code), value);
default:
UNREACHABLE();
}
SetDst1(inst.dst[1], value);
}
// VOP3a
@ -1527,13 +1484,7 @@ void Translator::V_MAD_U64_U32(const GcnInst& inst) {
IR::U32 Translator::GetCarryIn(const GcnInst& inst) {
IR::U1 carry;
if (inst.src_count == 3) { // VOP3
if (inst.src[2].field == OperandField::VccLo) {
carry = ir.GetVcc();
} else if (inst.src[2].field == OperandField::ScalarGPR) {
carry = ir.GetThreadBitScalarReg(IR::ScalarReg(inst.src[2].code));
} else {
UNREACHABLE();
}
carry = GetSrc1(inst.src[2]);
} else { // VOP2
carry = ir.GetVcc();
}
@ -1543,13 +1494,7 @@ IR::U32 Translator::GetCarryIn(const GcnInst& inst) {
void Translator::SetCarryOut(const GcnInst& inst, const IR::U1& carry) {
if (inst.dst_count == 2) { // VOP3
if (inst.dst[1].field == OperandField::VccLo) {
ir.SetVcc(carry);
} else if (inst.dst[1].field == OperandField::ScalarGPR) {
ir.SetThreadBitScalarReg(IR::ScalarReg(inst.dst[1].code), carry);
} else {
UNREACHABLE();
}
SetDst1(inst.dst[1], carry);
} else { // VOP2
ir.SetVcc(carry);
}