From 4ff1046d75e64f0e556d8dcd930ea25c23eb8b18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Nov 2023 16:22:30 +0200 Subject: [PATCH 1/6] gguf : print error for GGUFv1 files (#3908) --- ggml.c | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml.c b/ggml.c index d5a49d8e4..605a27940 100644 --- a/ggml.c +++ b/ggml.c @@ -18884,6 +18884,13 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); + if (ctx->header.version == 1) { + fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + if (!ok) { fprintf(stderr, "%s: failed to read header\n", __func__); fclose(file); From d6069051de7165a4e06662c89257f5d2905bb156 Mon Sep 17 00:00:00 2001 From: Oleksii Maryshchenko Date: Thu, 2 Nov 2023 18:10:39 +0100 Subject: [PATCH 2/6] cuda : use CUDA memory pool with async memory allocation/deallocation when available (#3903) * Using cuda memory pools for async alloc/dealloc. * If cuda device doesnt support memory pool than use old implementation. * Removed redundant cublasSetStream --------- Co-authored-by: Oleksii Maryshchenko --- ggml-cuda.cu | 130 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 52 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e46295126..58b58f331 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -181,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cudaError_t err_ = (err); \ if (err_ != cudaSuccess) { \ - int id; \ - cudaGetDevice(&id); \ + int dev_id; \ + cudaGetDevice(&dev_id); \ fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ cudaGetErrorString(err_)); \ - fprintf(stderr, "current device: %d\n", id); \ + fprintf(stderr, "current device: %d\n", dev_id); \ exit(1); \ } \ } while (0) @@ -195,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cublasStatus_t err_ = (err); \ if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int id; \ - cudaGetDevice(&id); \ + int dev_id; \ + cudaGetDevice(&dev_id); \ fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ - fprintf(stderr, "current device: %d\n", id); \ + fprintf(stderr, "current device: %d\n", dev_id); \ exit(1); \ } \ } while (0) @@ -465,6 +465,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MAX_STREAMS 8 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; +static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr }; struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors @@ -5772,6 +5773,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } +static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) { + if (g_cudaMemPools[id] == nullptr) { + return ggml_cuda_pool_malloc(size, actual_size); + } + void *ptr; + CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream)); + *actual_size = size; + return ptr; +} + static void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; @@ -5790,6 +5801,13 @@ static 
void ggml_cuda_pool_free(void * ptr, size_t size) { } +static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) { + if (g_cudaMemPools[id] == nullptr) { + return ggml_cuda_pool_free(ptr, actual_size); + } + CUDA_CHECK(cudaFreeAsync(ptr, stream)); +} + void ggml_init_cublas() { static bool initialized = false; @@ -5844,6 +5862,13 @@ void ggml_init_cublas() { // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); + + // configure memory pool + cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id); + if (err == cudaSuccess) { + size_t treshold = UINT64_MAX; + CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold)); + } } // configure logging to stdout @@ -6437,7 +6462,7 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = row_diff*ne00; - src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); + src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream); to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); } const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; @@ -6448,13 +6473,12 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = src1_ncols*ne10; - src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); + src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream); to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; - - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); + size_t dst_f16_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream); const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; @@ -6472,14 +6496,15 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); - ggml_cuda_pool_free(dst_f16, dst_as); - - if (src0_as != 0) { - ggml_cuda_pool_free(src0_as_f16, src0_as); + if (dst_f16_as != 0) { + ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream); } + if (src0_as != 0) { + ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream); + } if (src1_as != 0) { - ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream); } } else { @@ -6489,7 +6514,7 @@ inline void ggml_cuda_op_mul_mat_cublas( if (src0->type != GGML_TYPE_F32) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); GGML_ASSERT(to_fp32_cuda != nullptr); - src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT + src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); } const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? 
(const float *) src0_dd_i : src0_ddq_as_f32; @@ -6506,7 +6531,7 @@ inline void ggml_cuda_op_mul_mat_cublas( &beta, dst_dd_i, ldc)); if (src0_as != 0) { - ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); + ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream); } } @@ -6929,21 +6954,22 @@ static void ggml_cuda_op_mul_mat( src0_dd[id] = (char *) src0_extra->data_device[id]; } else { const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); - src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); + src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream); } if (src1_on_device && src1_is_contiguous) { src1_ddf[id] = (float *) src1_extra->data_device[id]; } else { - src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); + src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream); } if (convert_src1_to_q8_1) { - src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); + const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream); if (src1_on_device && src1_is_contiguous) { quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); - CUDA_CHECK(cudaGetLastError()); + // CUDA_CHECK(cudaGetLastError()); } } @@ -6951,7 +6977,7 @@ static void ggml_cuda_op_mul_mat( dst_dd[id] = (float *) dst_extra->data_device[id]; } else { const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); - dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); + dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream); } } @@ -7077,24 +7103,6 @@ static void ggml_cuda_op_mul_mat( } } - for (int64_t id = 0; id < g_device_count; ++id) { - CUDA_CHECK(ggml_cuda_set_device(id)); - - // free buffers again when done - if (src0_as[id] > 0) { - ggml_cuda_pool_free(src0_dd[id], src0_as[id]); - } - if (src1_asf[id] > 0) { - ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); - } - if (src1_asq[id] > 0) { - ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); - } - if (dst_as[id] > 0) { - ggml_cuda_pool_free(dst_dd[id], dst_as[id]); - } - } - // main device waits for all other devices to be finished if (split && g_device_count > 1) { int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; @@ -7112,6 +7120,21 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaDeviceSynchronize()); } + + for (int64_t id = 0; id < g_device_count; ++id) { + if (src0_as[id] > 0) { + ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]); + } + if (src1_asf[id] > 0) { + ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]); + } + if (src1_asq[id] > 0) { + ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]); + } + if (dst_as[id] > 0) { + ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]); + } + } } static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7298,11 +7321,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const GGML_ASSERT(to_fp16_cuda != nullptr); size_t src1_as = 0; - half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * 
sizeof(half), &src1_as); + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream); to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); + half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream); GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -7349,10 +7372,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const } else { // use cublasGemmBatchedEx const int ne23 = ne12*ne13; - - void ** ptrs_as = nullptr; + // allocate device memory for pointers size_t ptrs_s = 0; - ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s); + void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream); dim3 block_dims(ne13, ne12); k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( @@ -7365,7 +7387,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const dst->nb[2], dst->nb[3], r2, r3); CUDA_CHECK(cudaGetLastError()); - CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, @@ -7375,16 +7396,21 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ne23, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - ggml_cuda_pool_free(ptrs_as, ptrs_s); + // free device memory for pointers + if (ptrs_s != 0) { + ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream); + } } #endif const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); - - ggml_cuda_pool_free(src1_as_f16, src1_as); - ggml_cuda_pool_free(dst_f16, dst_as); + if (src1_as != 0) { + ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream); + } + if (dst_as != 0) { + ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream); + } } static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From c7743fe1c1cbda5a886362aa371480360580fdf0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Nov 2023 20:32:11 +0200 Subject: [PATCH 3/6] cuda : fix const ptrs warning causing ROCm build issues (#3913) --- ggml-cuda.cu | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 58b58f331..06c28f565 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7248,7 +7248,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor __global__ void k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, - void ** ptrs, + const void ** ptrs_src, void ** ptrs_dst, int ne12, int ne13, int ne23, int nb02, int nb03, @@ -7265,9 +7265,9 @@ __global__ void k_compute_batched_ptrs( int i03 = i13 / r3; int i02 = i12 / r2; - ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02 + i03*nb03; - ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2; - ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* nb2/2 + i13* nb3/2; + ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; + ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2; + ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; } static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, 
ggml_tensor * dst) { @@ -7372,14 +7372,20 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const } else { // use cublasGemmBatchedEx const int ne23 = ne12*ne13; - // allocate device memory for pointers - size_t ptrs_s = 0; - void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream); + + const void ** ptrs_src = nullptr; + void ** ptrs_dst = nullptr; + + size_t ptrs_src_s = 0; + size_t ptrs_dst_s = 0; + + ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream); + ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream); dim3 block_dims(ne13, ne12); k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( src0_as_f16, src1_as_f16, dst_f16, - ptrs_as, + ptrs_src, ptrs_dst, ne12, ne13, ne23, nb02, nb03, @@ -7390,15 +7396,18 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half), - (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float), - &beta_f16, ( void ** ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01, + &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half), + (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float), + &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01, ne23, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // free device memory for pointers - if (ptrs_s != 0) { - ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream); + + if (ptrs_src_s != 0) { + ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream); + } + if (ptrs_dst_s != 0) { + ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream); } } #endif From 224e7d5b14cbabab7ae45c64db2cfde979c8455d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Nov 2023 20:44:12 +0200 Subject: [PATCH 4/6] readme : add notice about #3912 --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index b56ecaec7..9c9e36ad0 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ ![llama](https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png) -[![Actions Status](https://github.com/ggerganov/llama.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/llama.cpp/actions) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [Roadmap](https://github.com/users/ggerganov/projects/7) / [Project status](https://github.com/ggerganov/llama.cpp/discussions/3471) / [Manifesto](https://github.com/ggerganov/llama.cpp/discussions/205) / [ggml](https://github.com/ggerganov/ggml) @@ -11,8 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ ### Hot topics -- LLaVA support: https://github.com/ggerganov/llama.cpp/pull/3436 -- ‼️ BPE tokenizer update: existing Falcon and Starcoder `.gguf` models will need to be reconverted: [#3252](https://github.com/ggerganov/llama.cpp/pull/3252) +- ⚠️ **Upcoming change that might break functionality. 
Help with testing is needed:** https://github.com/ggerganov/llama.cpp/pull/3912 ---- From 51b2fc11f7f605fff49725a4540e9a6ef7b51b70 Mon Sep 17 00:00:00 2001 From: Andrei Date: Thu, 2 Nov 2023 15:40:31 -0400 Subject: [PATCH 5/6] cmake : fix relative path to git submodule index (#3915) --- common/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 0150114e3..ac594b2ca 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -11,7 +11,7 @@ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git") if(NOT IS_DIRECTORY "${GIT_DIR}") file(READ ${GIT_DIR} REAL_GIT_DIR_LINK) string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK}) - set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/${REAL_GIT_DIR}") + set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${REAL_GIT_DIR}") endif() set(GIT_INDEX "${GIT_DIR}/index") From 629f917cd6b96ba1274c49a8aab163b1b189229d Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:58:22 -0600 Subject: [PATCH 6/6] cuda : add ROCM aliases for CUDA pool stuff (#3918) --- ggml-cuda.cu | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 06c28f565..baf02df2b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -39,6 +39,10 @@ #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceGetMemPool hipDeviceGetMemPool +#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold +#define cudaMemPoolSetAttribute hipMemPoolSetAttribute +#define cudaMemPool_t hipMemPool_t #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t @@ -48,6 +52,7 @@ #define cudaEvent_t hipEvent_t #define cudaEventDestroy hipEventDestroy #define cudaFree hipFree +#define cudaFreeAsync hipFreeAsync #define cudaFreeHost hipHostFree #define cudaGetDevice hipGetDevice #define cudaGetDeviceCount hipGetDeviceCount @@ -55,6 +60,7 @@ #define cudaGetErrorString hipGetErrorString #define cudaGetLastError hipGetLastError #define cudaMalloc hipMalloc +#define cudaMallocFromPoolAsync hipMallocFromPoolAsync #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMemcpy hipMemcpy #define cudaMemcpy2DAsync hipMemcpy2DAsync
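
The series above (patches 2, 3 and 6) centers on CUDA's stream-ordered allocator: obtain the device's default memory pool, raise its release threshold so freed blocks stay cached, then allocate and free asynchronously on the compute stream, keeping the old blocking pool path as a fallback when memory pools are unavailable. Below is a minimal standalone sketch of that CUDA runtime API pattern, not code from ggml-cuda.cu: the helper names (init_pool, pool_alloc, pool_free) are invented for illustration, and it gates on cudaDevAttrMemoryPoolsSupported where the patch instead checks whether cudaDeviceGetMemPool succeeds.

    // Minimal sketch of the stream-ordered allocation pattern (assumes CUDA 11.2+).
    #include <cstdio>
    #include <cstdlib>
    #include <cstdint>
    #include <cuda_runtime.h>

    #define CHECK(err) do { cudaError_t e_ = (err); if (e_ != cudaSuccess) { \
        fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(e_), __FILE__, __LINE__); \
        exit(1); } } while (0)

    static cudaMemPool_t g_pool = nullptr; // nullptr -> fall back to blocking cudaMalloc/cudaFree

    static void init_pool(int device) {
        int supported = 0;
        CHECK(cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, device));
        if (!supported) return;                       // keep the legacy allocation path
        CHECK(cudaDeviceGetMemPool(&g_pool, device));
        uint64_t threshold = UINT64_MAX;              // never trim freed blocks back to the driver
        CHECK(cudaMemPoolSetAttribute(g_pool, cudaMemPoolAttrReleaseThreshold, &threshold));
    }

    static void * pool_alloc(size_t size, cudaStream_t stream) {
        void * ptr = nullptr;
        if (g_pool) { CHECK(cudaMallocFromPoolAsync(&ptr, size, g_pool, stream)); }
        else        { CHECK(cudaMalloc(&ptr, size)); }  // blocking fallback
        return ptr;
    }

    static void pool_free(void * ptr, cudaStream_t stream) {
        if (g_pool) { CHECK(cudaFreeAsync(ptr, stream)); }
        else        { CHECK(cudaFree(ptr)); }
    }

    int main() {
        init_pool(0);
        cudaStream_t stream;
        CHECK(cudaStreamCreate(&stream));
        void * buf = pool_alloc(1 << 20, stream);          // 1 MiB scratch buffer
        CHECK(cudaMemsetAsync(buf, 0, 1 << 20, stream));   // use it on the same stream
        pool_free(buf, stream);                            // freed in stream order, block stays pooled
        CHECK(cudaStreamSynchronize(stream));
        CHECK(cudaStreamDestroy(stream));
        return 0;
    }

Setting cudaMemPoolAttrReleaseThreshold to UINT64_MAX is what makes the pool behave like the hand-rolled buffer pool it replaces: blocks released with cudaFreeAsync are retained for reuse on the same stream instead of being returned to the driver after each synchronization. Patch 6 then aliases these calls (cudaDeviceGetMemPool, cudaMallocFromPoolAsync, cudaFreeAsync, etc.) to their HIP equivalents so the ROCm build keeps compiling.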