diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 9788297cc..e23d345cd 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -890,6 +890,7 @@ func (f GGML) FlashAttention() bool { return slices.Contains([]string{ "bert", "gemma3", + "gemma4", "glm4moelite", "glmocr", "gptoss", "gpt-oss", diff --git a/llama/patches/0020-ggml-No-alloc-mode.patch b/llama/patches/0020-ggml-No-alloc-mode.patch index 92953f611..0070e46c4 100644 --- a/llama/patches/0020-ggml-No-alloc-mode.patch +++ b/llama/patches/0020-ggml-No-alloc-mode.patch @@ -8,12 +8,13 @@ it does not allocate memory for tensors or operations. This can be used for calculating memory requirements. Tensors and graphs must be recreated with no-alloc set to false before loading data. --- - ggml/include/ggml-backend.h | 1 + - ggml/src/ggml-backend-impl.h | 16 +++ - ggml/src/ggml-backend.cpp | 75 ++++++++++- - ggml/src/ggml-cuda/common.cuh | 83 +++++++++++- - ggml/src/ggml-cuda/ggml-cuda.cu | 224 ++++++++++++++++++++++++++------ - 5 files changed, 354 insertions(+), 45 deletions(-) + ggml/include/ggml-backend.h | 1 + + ggml/src/ggml-backend-impl.h | 16 +++ + ggml/src/ggml-backend.cpp | 75 ++++++++++- + ggml/src/ggml-cuda/common.cuh | 85 +++++++++++- + ggml/src/ggml-cuda/ggml-cuda.cu | 224 +++++++++++++++++++++++++------ + ggml/src/ggml-cuda/vendors/hip.h | 1 + + 6 files changed, 357 insertions(+), 45 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 93c95602d..dbbb61d9c 100644 @@ -226,7 +227,7 @@ index 498186a7c..7746e8b92 100644 void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh -index 9fcb2f9fd..e800ee8f6 100644 +index 9fcb2f9fd..79673ac2e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -37,6 +37,64 @@ @@ -294,7 +295,7 @@ index 9fcb2f9fd..e800ee8f6 100644 #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) -@@ -941,6 +976,9 @@ struct ggml_cuda_pool { +@@ -941,6 +999,9 @@ struct ggml_cuda_pool { virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void free(void * ptr, size_t size) = 0; @@ -304,7 +305,7 @@ index 9fcb2f9fd..e800ee8f6 100644 }; template -@@ -1232,11 +1270,15 @@ struct ggml_backend_cuda_context { +@@ -1232,11 +1293,15 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; @@ -322,7 +323,7 @@ index 9fcb2f9fd..e800ee8f6 100644 } return *pools[device][curr_stream_no]; } -@@ -1244,6 +1286,22 @@ struct ggml_backend_cuda_context { +@@ -1244,6 +1309,22 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } @@ -345,18 +346,6 @@ index 9fcb2f9fd..e800ee8f6 100644 }; struct ggml_cuda_mm_fusion_args_host { -diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index d89e35a8e..fe4b5349f 100644 ---- a/ggml/src/ggml-cuda/vendors/hip.h -+++ b/ggml/src/ggml-cuda/vendors/hip.h -@@ -39,6 +39,7 @@ - #define cublasCreate hipblasCreate - #define cublasDestroy hipblasDestroy - #define cublasGemmEx hipblasGemmEx -+#define cublasGemmAlgo_t hipblasGemmAlgo_t - #define cublasGemmBatchedEx hipblasGemmBatchedEx - #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx - #define cublasHandle_t hipblasHandle_t diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 25548629d..eeaae3fe4 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu @@ -717,3 +706,15 @@ index 25548629d..eeaae3fe4 100644 }; static ggml_guid_t ggml_backend_cuda_guid() { +diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h +index 951a88d56..b9627eb8b 100644 +--- a/ggml/src/ggml-cuda/vendors/hip.h ++++ b/ggml/src/ggml-cuda/vendors/hip.h +@@ -37,6 +37,7 @@ + #define cublasCreate hipblasCreate + #define cublasDestroy hipblasDestroy + #define cublasGemmEx hipblasGemmEx ++#define cublasGemmAlgo_t hipblasGemmAlgo_t + #define cublasGemmBatchedEx hipblasGemmBatchedEx + #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx + #define cublasHandle_t hipblasHandle_t diff --git a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch index 3197f94e8..a1321e727 100644 --- a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch +++ b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch @@ -110,10 +110,10 @@ index eeaae3fe4..6852d2e20 100644 // backend reg diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index 951a88d56..4e162258d 100644 +index b9627eb8b..e3310bdf3 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h -@@ -49,6 +49,7 @@ +@@ -50,6 +50,7 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceProp hipDeviceProp_t diff --git a/llama/patches/0024-GPU-discovery-enhancements.patch b/llama/patches/0024-GPU-discovery-enhancements.patch index 6e4ef2394..e989ad89b 100644 --- a/llama/patches/0024-GPU-discovery-enhancements.patch +++ b/llama/patches/0024-GPU-discovery-enhancements.patch @@ -183,7 +183,7 @@ index 6852d2e20..334a30135 100644 /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index 4e162258d..d89e35a8e 100644 +index e3310bdf3..fe4b5349f 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -5,6 +5,8 @@ @@ -195,7 +195,7 @@ index 4e162258d..d89e35a8e 100644 #if defined(GGML_HIP_ROCWMMA_FATTN) #include -@@ -51,6 +53,7 @@ +@@ -52,6 +54,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize diff --git a/llama/patches/0036-backport-kernels-for-gemma4.patch b/llama/patches/0036-backport-kernels-for-gemma4.patch new file mode 100644 index 000000000..86123ccc5 --- /dev/null +++ b/llama/patches/0036-backport-kernels-for-gemma4.patch @@ -0,0 +1,416 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Mon, 6 Apr 2026 19:56:36 -0700 +Subject: [PATCH] backport kernels for gemma4 + +--- + ggml/src/ggml-cuda/fattn-mma-f16.cuh | 26 ++++++++++++- + ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ + ggml/src/ggml-cuda/fattn-tile.cuh | 37 +++++++++++++++---- + ggml/src/ggml-cuda/fattn.cu | 11 +++++- + ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 1 + + ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 1 + + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 1 + + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 1 + + ggml/src/ggml-metal/ggml-metal-device.m | 1 + + ggml/src/ggml-metal/ggml-metal.metal | 19 ++++++++++ + 14 files changed, 96 insertions(+), 10 deletions(-) + +diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh +index 3dea2205e..b4282d3cb 100644 +--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh ++++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh +@@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); ++ + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); +@@ -81,6 +86,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); ++ + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); +@@ -91,6 +101,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co + } + + static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false); ++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false); ++ + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); +@@ -1361,7 +1376,7 @@ static __global__ void flash_attn_ext_f16( + #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) + + // Skip unused kernel variants for faster compilation: +- if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { ++ if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { + NO_DEVICE_CODE; + return; + } +@@ -1600,6 +1615,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) + DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) + DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) + ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); ++extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); ++ + // The number of viable configurations for Deepseek is very limited: + extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); + extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu +index 3fcb09b7a..25b16e83c 100644 +--- a/ggml/src/ggml-cuda/fattn-tile.cu ++++ b/ggml/src/ggml-cuda/fattn-tile.cu +@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; ++ case 512: { ++ GGML_ASSERT(V->ne[0] == K->ne[0]); ++ ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); ++ } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); +diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh +index 371be7442..43906bd73 100644 +--- a/ggml/src/ggml-cuda/fattn-tile.cuh ++++ b/ggml/src/ggml-cuda/fattn-tile.cuh +@@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) ++ + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) +@@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) ++ + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) +@@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64) ++ + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) +@@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) ++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64) ++ + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) +@@ -767,7 +785,7 @@ static __global__ void flash_attn_tile( + #ifdef GGML_USE_WMMA_FATTN + (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) || + #endif // GGML_USE_WMMA_FATTN +- (use_logit_softcap && !(DV == 128 || DV == 256)) ++ (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512)) + ) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, +@@ -1190,7 +1208,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm + const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + +- if constexpr (DV == 512) { ++ if constexpr (DKQ == 576) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; +@@ -1205,7 +1223,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm + } + } + +- if constexpr (DV <= 256) { ++ if constexpr (DKQ <= 512) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; +@@ -1216,13 +1234,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm + return; + } + +- if (use_gqa_opt && gqa_ratio % 2 == 0) { +- launch_fattn_tile_switch_ncols1(ctx, dst); ++ if constexpr (DV <= 256) { ++ if (use_gqa_opt && gqa_ratio % 2 == 0) { ++ launch_fattn_tile_switch_ncols1(ctx, dst); ++ return; ++ } ++ ++ launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } +- +- launch_fattn_tile_switch_ncols1(ctx, dst); +- return; + } + GGML_ABORT("fatal error"); + } +@@ -1257,4 +1277,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96); + extern DECL_FATTN_TILE_CASE(112, 112); + extern DECL_FATTN_TILE_CASE(128, 128); + extern DECL_FATTN_TILE_CASE(256, 256); ++extern DECL_FATTN_TILE_CASE(512, 512); + extern DECL_FATTN_TILE_CASE(576, 512); +diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu +index 1693479cb..0b7f9ae07 100644 +--- a/ggml/src/ggml-cuda/fattn.cu ++++ b/ggml/src/ggml-cuda/fattn.cu +@@ -110,6 +110,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + break; ++ case 512: ++ GGML_ASSERT(V->ne[0] == 512); ++ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); ++ break; + case 576: { + // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + GGML_ASSERT(V->ne[0] == 512); +@@ -251,6 +255,11 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const + return BEST_FATTN_KERNEL_NONE; + } + break; ++ case 512: ++ if (!gqa_opt_applies) { ++ return BEST_FATTN_KERNEL_NONE; ++ } ++ break; + case 576: + if (V->ne[0] != 512) { + return BEST_FATTN_KERNEL_NONE; +@@ -334,7 +343,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const + } + + // Use the WMMA kernel if possible: +- if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { ++ if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (can_use_vector_kernel && Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +index dc1682902..22d383173 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); + DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); + DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); + DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); ++DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +index 517993cb0..d2415bfa9 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); ++DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); + DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +index 97b19c67a..8eec1d74e 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); ++DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); + DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +index 163b1d939..84b674cd0 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); + DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); + DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); + DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); ++DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +index 989626dfa..3475dfea0 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); ++DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); + DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +index bad296b41..5906398db 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); + DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); + DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); + DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); ++DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +index 173de7aac..684cd25ce 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); + DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); + DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); + DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); ++DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); + DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +index 680a13ca6..4bc60d62f 100644 +--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu ++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); + DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); + DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); + DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); ++DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); +diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m +index 4e5acfbe5..11457f2b1 100644 +--- a/ggml/src/ggml-metal/ggml-metal-device.m ++++ b/ggml/src/ggml-metal/ggml-metal-device.m +@@ -1083,6 +1083,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te + op->src[0]->ne[0] != 128 && + op->src[0]->ne[0] != 192 && + op->src[0]->ne[0] != 256 && ++ op->src[0]->ne[0] != 512 && + op->src[0]->ne[0] != 576) { + return false; + } +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index 4f338aa13..8be0c1f0c 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -6276,6 +6276,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + + template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +@@ -6290,6 +6291,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + + #if defined(GGML_METAL_HAS_BF16) +@@ -6305,6 +6307,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + #endif + +@@ -6320,6 +6323,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + + template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +@@ -6334,6 +6338,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + + template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +@@ -6348,6 +6353,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + + template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +@@ -6362,6 +6368,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + + template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +@@ -6376,6 +6383,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at + template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; ++template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + + #undef FA_TYPES +@@ -6990,6 +6998,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas + template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + ++template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++#if defined(GGML_METAL_HAS_BF16) ++template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++#endif ++template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; ++ + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + #if defined(GGML_METAL_HAS_BF16) diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3dea2205e..b4282d3cb 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); @@ -81,6 +86,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); @@ -91,6 +101,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); @@ -1361,7 +1376,7 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { NO_DEVICE_CODE; return; } @@ -1600,6 +1615,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); + // The number of viable configurations for Deepseek is very limited: extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu index 3fcb09b7a..25b16e83c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh index 371be7442..43906bd73 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) @@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) @@ -767,7 +785,7 @@ static __global__ void flash_attn_tile( #ifdef GGML_USE_WMMA_FATTN (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) || #endif // GGML_USE_WMMA_FATTN - (use_logit_softcap && !(DV == 128 || DV == 256)) + (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512)) ) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1190,7 +1208,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DV == 512) { + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1205,7 +1223,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DV <= 256) { + if constexpr (DKQ <= 512) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1216,13 +1234,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); return; } - - launch_fattn_tile_switch_ncols1(ctx, dst); - return; } GGML_ABORT("fatal error"); } @@ -1257,4 +1277,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu index 1693479cb..0b7f9ae07 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu @@ -110,6 +110,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 512: + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); + break; case 576: { // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. GGML_ASSERT(V->ne[0] == 512); @@ -251,6 +255,11 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 512: + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 576: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; @@ -334,7 +343,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index dc1682902..22d383173 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 517993cb0..d2415bfa9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 97b19c67a..8eec1d74e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 163b1d939..84b674cd0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 989626dfa..3475dfea0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index bad296b41..5906398db 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 173de7aac..684cd25ce 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 680a13ca6..4bc60d62f 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m index 4e5acfbe5..11457f2b1 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -1083,6 +1083,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && op->src[0]->ne[0] != 256 && + op->src[0]->ne[0] != 512 && op->src[0]->ne[0] != 576) { return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index df243edcb..38ce1fdc6 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -9098,6 +9098,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -9112,6 +9113,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) @@ -9127,6 +9129,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -9142,6 +9145,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -9156,6 +9160,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -9170,6 +9175,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -9184,6 +9190,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -9198,6 +9205,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -9812,6 +9820,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index 4f338aa13..8be0c1f0c 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -6276,6 +6276,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6290,6 +6291,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) @@ -6305,6 +6307,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -6320,6 +6323,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6334,6 +6338,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6348,6 +6353,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6362,6 +6368,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6376,6 +6383,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -6990,6 +6998,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16)