mirror of
https://github.com/ollama/ollama.git
synced 2026-04-18 09:03:35 -04:00
ggml: fix ROCm build for cublasGemmBatchedEx reserve wrapper
Add missing cublasGemmAlgo_t to hipblasGemmAlgo_t type mapping and cast away const qualifiers that hipblasGemmBatchedEx doesn't accept.
This commit is contained in:
@@ -229,7 +229,7 @@ diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
|
||||
index 9fcb2f9fd..e800ee8f6 100644
|
||||
--- a/ggml/src/ggml-cuda/common.cuh
|
||||
+++ b/ggml/src/ggml-cuda/common.cuh
|
||||
@@ -37,6 +37,62 @@
|
||||
@@ -37,6 +37,64 @@
|
||||
#include "vendors/cuda.h"
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
|
||||
@@ -273,8 +273,10 @@ index 9fcb2f9fd..e800ee8f6 100644
|
||||
+ cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
|
||||
+ if (!reserving_graph) {
|
||||
+ return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
|
||||
+ alpha, Aarray, Atype, lda, Barray, Btype, ldb,
|
||||
+ beta, Carray, Ctype, ldc, batchCount, computeType, algo);
|
||||
+ alpha, const_cast<const void **>(Aarray), Atype, lda,
|
||||
+ const_cast<const void **>(Barray), Btype, ldb,
|
||||
+ beta, const_cast<void **>(Carray), Ctype, ldc,
|
||||
+ batchCount, computeType, algo);
|
||||
+ } else {
|
||||
+ return CUBLAS_STATUS_SUCCESS;
|
||||
+ }
|
||||
@@ -343,6 +345,18 @@ index 9fcb2f9fd..e800ee8f6 100644
|
||||
};
|
||||
|
||||
struct ggml_cuda_mm_fusion_args_host {
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index d89e35a8e..fe4b5349f 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -39,6 +39,7 @@
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
+#define cublasGemmAlgo_t hipblasGemmAlgo_t
|
||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||
#define cublasHandle_t hipblasHandle_t
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 25548629d..eeaae3fe4 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
|
||||
@@ -77,8 +77,10 @@ static cublasStatus_t cublasGemmBatchedExReserve(
|
||||
cublasComputeType_t computeType, cublasGemmAlgo_t algo) {
|
||||
if (!reserving_graph) {
|
||||
return cublasGemmBatchedEx(handle, transa, transb, m, n, k,
|
||||
alpha, Aarray, Atype, lda, Barray, Btype, ldb,
|
||||
beta, Carray, Ctype, ldc, batchCount, computeType, algo);
|
||||
alpha, const_cast<const void **>(Aarray), Atype, lda,
|
||||
const_cast<const void **>(Barray), Btype, ldb,
|
||||
beta, const_cast<void **>(Carray), Ctype, ldc,
|
||||
batchCount, computeType, algo);
|
||||
} else {
|
||||
return CUBLAS_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
#define cublasGemmAlgo_t hipblasGemmAlgo_t
|
||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||
#define cublasHandle_t hipblasHandle_t
|
||||
|
||||
Reference in New Issue
Block a user