ggml: fix ROCm build for cublasGemmBatchedEx reserve wrapper

Add the missing cublasGemmAlgo_t -> hipblasGemmAlgo_t type mapping, and cast away the const
qualifiers on the Aarray/Barray/Carray pointer-array arguments, which hipblasGemmBatchedEx doesn't accept.
This commit is contained in:
Jesse Gross
2026-04-03 13:26:50 -07:00
parent c8e0878814
commit 3cd2b03a5e
3 changed files with 22 additions and 5 deletions

View File

@@ -229,7 +229,7 @@ diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9fcb2f9fd..e800ee8f6 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -37,6 +37,62 @@
@@ -37,6 +37,64 @@
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)
@@ -273,8 +273,10 @@ index 9fcb2f9fd..e800ee8f6 100644
+ cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
+ if (!reserving_graph) {
+ return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
+ alpha, Aarray, Atype, lda, Barray, Btype, ldb,
+ beta, Carray, Ctype, ldc, batchCount, computeType, algo);
+ alpha, const_cast<const void **>(Aarray), Atype, lda,
+ const_cast<const void **>(Barray), Btype, ldb,
+ beta, const_cast<void **>(Carray), Ctype, ldc,
+ batchCount, computeType, algo);
+ } else {
+ return CUBLAS_STATUS_SUCCESS;
+ }
@@ -343,6 +345,18 @@ index 9fcb2f9fd..e800ee8f6 100644
};
struct ggml_cuda_mm_fusion_args_host {
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index d89e35a8e..fe4b5349f 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -39,6 +39,7 @@
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
+#define cublasGemmAlgo_t hipblasGemmAlgo_t
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 25548629d..eeaae3fe4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu

View File

@@ -77,8 +77,10 @@ static cublasStatus_t cublasGemmBatchedExReserve(
cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
if (!reserving_graph) {
return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
alpha, Aarray, Atype, lda, Barray, Btype, ldb,
beta, Carray, Ctype, ldc, batchCount, computeType, algo);
alpha, const_cast<const void **>(Aarray), Atype, lda,
const_cast<const void **>(Barray), Btype, ldb,
beta, const_cast<void **>(Carray), Ctype, ldc,
batchCount, computeType, algo);
} else {
return CUBLAS_STATUS_SUCCESS;
}

View File

@@ -39,6 +39,7 @@
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmAlgo_t hipblasGemmAlgo_t
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t