drm/amdkfd: Apply VGPR bank state fixup on gfx12.1 trap exit

- Identify co-issue of S_SET_VGPR_MSB and VALU with banked VGPR
- Restore previous bank setting when exiting the trap

v2:
- Refine VOP3PX2 detection
- Improve load pipelining
- Fix a comment typo

Signed-off-by: Jay Cornwall <jay.cornwall@amd.com>
Reviewed-by: Lancelot Six <lancelot.six@amd.com>
Cc: Joseph Greathouse <joseph.greathouse@amd.com>
Signed-off-by: Alex Deucher <alexander.deucher@amd.com>
This commit is contained in:
Jay Cornwall
2025-10-23 15:33:04 -05:00
committed by Alex Deucher
parent 1005ab86cf
commit ba80939fec
2 changed files with 712 additions and 421 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -73,6 +73,7 @@ var SQ_WAVE_LDS_ALLOC_GRANULARITY = 10
#endif
var SQ_WAVE_EXCP_FLAG_PRIV_ADDR_WATCH_MASK = 0xF
var SQ_WAVE_EXCP_FLAG_PRIV_MEM_VIOL_SHIFT = 4
var SQ_WAVE_EXCP_FLAG_PRIV_MEM_VIOL_MASK = 0x10
var SQ_WAVE_EXCP_FLAG_PRIV_SAVE_CONTEXT_SHIFT = 5
var SQ_WAVE_EXCP_FLAG_PRIV_SAVE_CONTEXT_MASK = 0x20
@@ -362,6 +363,15 @@ L_TRAP_CASE:
L_EXIT_TRAP:
s_and_b32 ttmp1, ttmp1, ADDRESS_HI32_MASK
#if HAVE_BANKED_VGPRS
s_getreg_b32 s_save_excp_flag_priv, hwreg(HW_REG_WAVE_EXCP_FLAG_PRIV)
fixup_vgpr_bank_selection()
#endif
#if HAVE_XNACK
restore_xnack_state_priv(s_save_tmp)
#endif
// Restore SQ_WAVE_STATUS.
s_and_b64 exec, exec, exec // Restore STATUS.EXECZ, not writable by s_setreg_b32
s_and_b64 vcc, vcc, vcc // Restore STATUS.VCCZ, not writable by s_setreg_b32
@@ -390,6 +400,14 @@ L_HAVE_VGPRS:
s_mov_b32 s_save_tmp, 0
s_setreg_b32 hwreg(HW_REG_WAVE_EXCP_FLAG_PRIV, SQ_WAVE_EXCP_FLAG_PRIV_SAVE_CONTEXT_SHIFT, 1), s_save_tmp //clear saveCtx bit
#if HAVE_XNACK
save_and_clear_xnack_state_priv(s_save_tmp)
#endif
#if HAVE_BANKED_VGPRS
fixup_vgpr_bank_selection()
#endif
/* inform SPI the readiness and wait for SPI's go signal */
s_mov_b32 s_save_exec_lo, exec_lo //save EXEC and use EXEC for the go signal from SPI
s_mov_b32 s_save_exec_hi, exec_hi
@@ -404,7 +422,6 @@ L_HAVE_VGPRS:
s_or_b32 s_save_pc_hi, s_save_pc_hi, s_save_tmp
#if HAVE_XNACK
save_and_clear_xnack_state_priv(s_save_tmp)
s_getreg_b32 s_save_xnack_mask, hwreg(HW_REG_WAVE_XNACK_MASK)
s_setreg_imm32_b32 hwreg(HW_REG_WAVE_XNACK_MASK), 0
#endif
@@ -1328,3 +1345,150 @@ L_BARRIER_RESTORE_LOOP:
L_BARRIER_RESTORE_DONE:
end
#if HAVE_BANKED_VGPRS
function fixup_vgpr_bank_selection
// PC read may fault if memory violation has been asserted.
// In this case no further progress is expected so fixup is not needed.
s_bitcmp1_b32 s_save_excp_flag_priv, SQ_WAVE_EXCP_FLAG_PRIV_MEM_VIOL_SHIFT
s_cbranch_scc1 L_FIXUP_DONE
// ttmp[0:1]: {7b'0} PC[56:0]
// ttmp2, 3, 10, 13, 14, 15: free
s_load_b64 [ttmp14, ttmp15], [ttmp0, ttmp1], 0 scope:SCOPE_CU // Load the 2 instruction DW we are returning to
s_load_b64 [ttmp2, ttmp3], [ttmp0, ttmp1], 8 scope:SCOPE_CU // Load the next 2 instruction DW, just in case
s_wait_kmcnt 1
s_and_b32 ttmp10, ttmp14, 0x80000000 // Check bit 31 in the first DWORD
// SCC set if ttmp10 is != 0, i.e. if bit 31 == 1
s_cbranch_scc1 L_FIXUP_NOT_VOP12C // If bit 31 is 1, we are not VOP1, VOP2, or VOP3C
// Fall through here means bit 31 == 0, meaning we are VOP1, VOP2, or VOPC
// Size of instruction depends on Opcode or SRC0_9
// Check for VOP2 opcode
s_bfe_u32 ttmp10, ttmp14, (25 | (6 << 0x10)) // Check bits 30:25 for VOP2 Opcode
// VOP2 V_FMAMK_F64 of V_FMAAK_F64 has implied 64-bit literature, 3 DW
s_sub_co_i32 ttmp13, ttmp10, 0x23 // V_FMAMK_F64 is 0x23, V_FMAAK_F64 is 0x24
s_cmp_le_u32 ttmp13, 0x1 // 0==0x23, 1==0x24
s_cbranch_scc1 L_FIXUP_THREE_DWORD // If either, this is 3 DWORD inst
// VOP2 V_FMAMK_F32, V_FMAAK_F32, V_FMAMK_F16, V_FMAAK_F16, 2 DW
s_sub_co_i32 ttmp13, ttmp10, 0x2c // V_FMAMK_F32 is 0x2c, V_FMAAK_F32 is 0x2d
s_cmp_le_u32 ttmp13, 0x1 // 0==0x2c, 1==0x2d
s_cbranch_scc1 L_FIXUP_TWO_DWORD // If either, this is 2 DWORD inst
s_sub_co_i32 ttmp13, ttmp10, 0x37 // V_FMAMK_F16 is 0x37, V_FMAAK_F16 is 0x38
s_cmp_le_u32 ttmp13, 0x1 // 0==0x37, 1==0x38
s_cbranch_scc1 L_FIXUP_TWO_DWORD // If either, this is 2 DWORD inst
// Check SRC0_9 for VOP1, VOP2, and VOPC
s_and_b32 ttmp10, ttmp14, 0x1ff // Check bits 8:0 for SRC0_9
// Literal constant 64 is 3 DWORDs
s_cmp_eq_u32 ttmp10, 0xfe // 0xfe == 254 == Literal constant64
s_cbranch_scc1 L_FIXUP_THREE_DWORD // 3 DWORD inst
// Literal constant 32, DPP16, DPP8, and DPP8FI are 2 DWORDs
s_cmp_eq_u32 ttmp10, 0xff // 0xff == 255 = Literal constant32
s_cbranch_scc1 L_FIXUP_TWO_DWORD // 2 DWORD inst
s_cmp_eq_u32 ttmp10, 0xfa // 0xfa == 250 = DPP16
s_cbranch_scc1 L_FIXUP_TWO_DWORD // 2 DWORD inst
s_sub_co_i32 ttmp13, ttmp10, 0xe9 // DPP8 is 0xe9, DPP8FI is 0xea
s_cmp_le_u32 ttmp13, 0x1 // 0==0xe9, 1==0xea
s_cbranch_scc1 L_FIXUP_TWO_DWORD // If either, this is 2 DWORD inst
// Instruction is 1 DWORD otherwise
L_FIXUP_ONE_DWORD:
// Check if TTMP15 contains the value for S_SET_VGPR_MSB instruction
s_and_b32 ttmp10, ttmp15, 0xffff0000 // Check encoding in upper 16 bits
s_cmp_eq_u32 ttmp10, 0xbf860000 // Check if SOPP (9b'10_1111111) and S_SET_VGPR_MSB (7b'0000110)
s_cbranch_scc0 L_FIXUP_DONE // No problem, no fixup needed
// VALU op followed by a S_SET_VGPR_MSB. Need to pull SIMM[15:8] to fix up MODE.*_VGPR_MSB
s_bfe_u32 ttmp10, ttmp15, (14 | (2 << 0x10)) // Shift SIMM[15:14] over to 1:0, Dst
s_and_b32 ttmp13, ttmp15, 0x3f00 // Mask to get SIMM[13:8] only
s_lshr_b32 ttmp13, ttmp13, 6 // Shift SIMM[13:8] into 7:2, Src2, Src1, Src0
s_or_b32 ttmp10, ttmp10, ttmp13 // Src2, Src1, Src0, Dst --> format in MODE register
s_setreg_b32 hwreg(HW_REG_WAVE_MODE, 12, 8), ttmp10 // Write value into MODE[19:12]
s_branch L_FIXUP_DONE
L_FIXUP_NOT_VOP12C:
// ttmp[0:1]: {7b'0} PC[56:0]
// ttmp2: PC+2 value (not waitcnt'ed yet)
// ttmp3: PC+3 value (not waitcnt'ed yet)
// ttmp10, ttmp13: free
// ttmp14: PC+O value
// ttmp15: PC+1 value
// Not VOP1, VOP2, or VOPC.
// Check if we are VOP3 or VOP3SD
s_and_b32 ttmp10, ttmp14, 0xfc000000 // Bits 31:26
s_cmp_eq_u32 ttmp10, 0xd4000000 // If 31:26 = 0x35, this is VOP3 or VOP3SD
s_cbranch_scc1 L_FIXUP_CHECK_VOP3 // If VOP3 or VOP3SD, need to check SRC2_9, SRC1_9, SRC0_9
// Not VOP1, VOP2, VOPC, VOP3, or VOP3SD.
// Check for VOPD
s_cmp_eq_u32 ttmp10, 0xc8000000 // If 31:26 = 0x32, this is VOPD
s_cbranch_scc1 L_FIXUP_CHECK_VOPD // If VOPD, need to check OpX, OpY, SRCX0 and SRCY0
// Not VOP1, VOP2, VOPC, VOP3, VOP3SD, VOPD.
// Check if we are VOPD3
s_and_b32 ttmp10, ttmp14, 0xff000000 // Bits 31:24
s_cmp_eq_u32 ttmp10, 0xcf000000 // If 31:24 = 0xcf, this is VOPD3
s_cbranch_scc1 L_FIXUP_THREE_DWORD // If VOPD3, 3 DWORD inst
// Not VOP1, VOP2, VOPC, VOP3, VOP3SD, VOPD, or VOPD3.
// Might be in VOP3P, but we must ensure we are not VOP3PX2
s_and_b32 ttmp13, ttmp14, 0xffff0000 // Bits 31:16
s_cmp_eq_u32 ttmp13, 0xcc350000 // If 31:16 = 0xcc35, this is VOP3PX2
s_cbranch_scc1 L_FIXUP_DONE // If VOP3PX2, no fixup needed
s_cmp_eq_u32 ttmp13, 0xcc3a0000 // If 31:16 = 0xcc3a, this is VOP3PX2
s_cbranch_scc1 L_FIXUP_DONE // If VOP3PX2, no fixup needed
// Check if we are VOP3P
s_cmp_eq_u32 ttmp10, 0xcc000000 // If 31:24 = 0xcc, this is VOP3P
s_cbranch_scc0 L_FIXUP_DONE // Not in VOP3P, so instruction is not VOP1, VOP2,
// VOPC, VOP3, VOP3SD, VOP3P, VOPD, or VOPD3
// No fixup needed.
// Fall-through if we are in VOP3P to check SRC2_9, SRC1_9, and SRC0_9
L_FIXUP_CHECK_VOP3:
// Start with Src0, which is in bits 8:0 of second instruction DW, ttmp15
s_and_b32 ttmp10, ttmp15, 0x1ff // Mask out unused bits
// Src0_9 == Literal constant 32, DPP16, DPP8, and DPP8FI means 3 DWORDs
s_cmp_eq_u32 ttmp10, 0xff // 0xff == 255 = Literal constant32
s_cbranch_scc1 L_FIXUP_THREE_DWORD // 3 DWORD inst
s_cmp_eq_u32 ttmp10, 0xfa // 0xfa == 250 = DPP16
s_cbranch_scc1 L_FIXUP_THREE_DWORD // 3 DWORD inst
s_sub_co_i32 ttmp10, ttmp10, 0xe9 // DPP8 is 0xe9, DPP8FI is 0xea
s_cmp_le_u32 ttmp10, 0x1 // 0==0xe9, 1==0xea
s_cbranch_scc1 L_FIXUP_THREE_DWORD // If either, this is 3 DWORD inst
s_and_b32 ttmp10, ttmp15, 0x3fe00 // Next is Src1, which is in 17:9
s_cmp_eq_u32 ttmp10, 0x1fe00 // 0xff == 255 = Literal constant32
s_cbranch_scc1 L_FIXUP_THREE_DWORD // 3 DWORD inst
s_and_b32 ttmp10, ttmp15, 0x7fc0000 // Next is Src2, which is in 26:18
s_cmp_eq_u32 ttmp10, 0x3fc0000 // 0xff == 255 = Literal constant32
s_cbranch_scc1 L_FIXUP_THREE_DWORD // 3 DWORD inst
s_branch L_FIXUP_TWO_DWORD // No special encodings, VOP3* is 2 Dword
L_FIXUP_CHECK_VOPD:
// OpX being V_DUAL_FMA*K_F32 means 3 DWORDs
s_bfe_u32 ttmp10, ttmp14, (22 | (4 << 0x10)) // OPX is bits 25:22
s_sub_co_i32 ttmp10, ttmp10, 0x1 // V_DUAL_FMAAK_F32 is 0x1, V_DUAL_FMAMK_F32 is 0x2
s_cmp_le_u32 ttmp10, 0x1 // 0==0x1, 1==0x2
s_cbranch_scc1 L_FIXUP_THREE_DWORD // If either, this is 3 DWORD inst
// OpY being V_DUAL_FMA*K_F32 means 3 DWORDs
s_bfe_u32 ttmp10, ttmp14, (17 | (5 << 0x10)) // OPX is bits 21:17
s_sub_co_i32 ttmp10, ttmp10, 0x1 // V_DUAL_FMAAK_F32 is 0x1, V_DUAL_FMAMK_F32 is 0x2
s_cmp_le_u32 ttmp10, 0x1 // 0==0x1, 1==0x2
s_cbranch_scc1 L_FIXUP_THREE_DWORD // If either, this is 3 DWORD inst
// SRCX0 == Literal constant 32 means 3 DWORDs
s_and_b32 ttmp10, ttmp14, 0x1ff // SRCX0 is in bits 8:0 of 1st DWORD
s_cmp_eq_u32 ttmp10, 0xff // 0xff == 255 = Literal constant32
s_cbranch_scc1 L_FIXUP_THREE_DWORD // 3 DWORD inst
// SRCY0 == Literal constant 32 means 3 DWORDs
s_and_b32 ttmp10, ttmp15, 0x1ff // SRCY0 is in bits 8:0 of 2nd DWORD
s_cmp_eq_u32 ttmp10, 0xff // 0xff == 255 = Literal constant32
s_cbranch_scc1 L_FIXUP_THREE_DWORD // 3 DWORD inst
// If otherwise, no special encodings. Default VOPD is 2 Dword
// Fall-thru if true, because this is a 2 DWORD inst
L_FIXUP_TWO_DWORD:
s_wait_kmcnt 0 // Wait for PC+2 and PC+3 to arrive in ttmp2 and ttmp3
s_mov_b32 ttmp15, ttmp2 // Move possible S_SET_VGPR_MSB into ttmp15
s_branch L_FIXUP_ONE_DWORD // Go to common logic that checks if it is S_SET_VGPR_MSB
L_FIXUP_THREE_DWORD:
s_wait_kmcnt 0 // Wait for PC+2 and PC+3 to arrive in ttmp2 and ttmp3
s_mov_b32 ttmp15, ttmp3 // Move possible S_SET_VGPR_MSB into ttmp15
s_branch L_FIXUP_ONE_DWORD // Go to common logic that checks if it is S_SET_VGPR_MSB
L_FIXUP_DONE:
s_wait_kmcnt 0 // Ensure load of ttmp2 and ttmp3 is done
end
#endif