Apply suggestions from code review
These changes, plus ```c++ #define cublasGemmBatchedEx hipblasGemmBatchedEx ```, are needed to compile with ROCm. I haven't done performance testing, but it seems to work. I couldn't figure out how to propose a change for lines outside what the pull request changed. Also, this is my first time trying to create a multi-part review, so please forgive me if I mess something up.
This commit is contained in:
parent
c13fcfbfc0
commit
878aa4f209
1 changed files with 6 additions and 6 deletions
12
ggml-cuda.cu
12
ggml-cuda.cu
|
@ -7154,9 +7154,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
}
|
}
|
||||||
|
|
||||||
// allocate device memory for pointers
|
// allocate device memory for pointers
|
||||||
void ** src0_ptrs_as = nullptr;
|
const void ** src0_ptrs_as = nullptr;
|
||||||
void ** src1_ptrs_as = nullptr;
|
const void ** src1_ptrs_as = nullptr;
|
||||||
void ** dst_ptrs_as = nullptr;
|
void ** dst_ptrs_as = nullptr;
|
||||||
|
|
||||||
CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *)));
|
CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *)));
|
||||||
CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *)));
|
CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *)));
|
||||||
|
@ -7170,9 +7170,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
&alpha_f16, (void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half),
|
&alpha_f16, (const void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half),
|
||||||
(void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float),
|
(const void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float),
|
||||||
&beta_f16, (void **) dst_ptrs_as, CUDA_R_16F, ne01,
|
&beta_f16, ( void **) dst_ptrs_as, CUDA_R_16F, ne01,
|
||||||
ne23,
|
ne23,
|
||||||
CUBLAS_COMPUTE_16F,
|
CUBLAS_COMPUTE_16F,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue