From 671c7af6b96a854bf92221008fbac822005dff86 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 28 Nov 2024 15:29:08 -0800 Subject: [PATCH] opencl: remove small-alloc support and fix build errors for non-opencl platforms --- .github/workflows/build.yml | 4 +- ggml/CMakeLists.txt | 1 - ggml/include/ggml-alloc.h | 1 - ggml/src/CMakeLists.txt | 6 - ggml/src/ggml-alloc.c | 86 -------- ggml/src/ggml-opencl2/ggml-opencl2.cpp | 263 +------------------------ src/llama.cpp | 4 - tests/test-backend-ops.cpp | 21 +- 8 files changed, 4 insertions(+), 382 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 986049ff8..e99f493c2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -663,7 +663,7 @@ jobs: - build: 'msvc-arm64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' - build: 'llvm-arm64-opencl-adreno' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH=${{github.workspace}}/opencl-x64-release -DGGML_OPENCL=ON -DGGML_OPENCL_SMALL_ALLOC=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON -DGGML_OPENCL_EMBED_KERNELS=ON' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH=${{github.workspace}}/opencl-arm64-release -DGGML_OPENCL=ON -DGGML_OPENCL_SMALL_ALLOC=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON -DGGML_OPENCL_EMBED_KERNELS=ON' steps: - name: Clone @@ -716,7 +716,7 @@ jobs: -DBUILD_TESTING=OFF ` -DOPENCL_HEADERS_BUILD_TESTING=OFF ` -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` - -DCMAKE_INSTALL_PREFIX=${{github.workspace}}/opencl-x64-release + -DCMAKE_INSTALL_PREFIX=${{github.workspace}}/opencl-arm64-release cmake --build . --target install git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader cd OpenCL-ICD-Loader diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 73198bde8..3442142ad 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -180,7 +180,6 @@ set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING "ggml: sycl device architecture") option(GGML_OPENCL "ggml: use OpenCL" OFF) -option(GGML_OPENCL_SMALL_ALLOC "ggml: use small allocation for tensors" ON) option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF) option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 8db7f7460..23600eea9 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -69,7 +69,6 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i // Utils // Create a buffer and allocate all the tensors in a ggml_context GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); -GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft_for_weights(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); #ifdef __cplusplus diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 8fafdf759..272b043be 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -298,15 +298,9 @@ else () ggml_add_cpu_backend_variant_impl("") endif() -# TODO: This is intrusive. 
We intend to remove SMALL_ALLOC path once the we fully -# migrate to the non SMALL_ALLOC path. Also need to converge on the backend name -# so we don't need this name conversion. if (GGML_OPENCL) set(GGML_OPENCL2 ON) add_compile_definitions(GGML_USE_OPENCL) - if (GGML_OPENCL_SMALL_ALLOC) - add_compile_definitions(GGML_OPENCL_SMALL_ALLOC) - endif () else () set(GGML_OPENCL2 OFF) endif () diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index f2f39b125..2b2240be8 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -1033,92 +1033,6 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte return buffer; } -ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft_for_weights(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { -#ifndef GGML_OPENCL_SMALL_ALLOC - return ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); -#else - // Small allocation allocates a separate buffer for each tensor. Instead of - // collecting multiple tensors to allocate a large buffer, each tensor is - // allocated a buffer immediately. This is only supposed to be used for - // weights tensors (note that weights can be f32). - GGML_ASSERT(ggml_get_no_alloc(ctx) == true); - - size_t alignment = ggml_backend_buft_get_alignment(buft); - - ggml_backend_buffer_t * buffers = NULL; - size_t n_buffers = 0; - - struct ggml_tensor * first_view = NULL; - struct ggml_tensor * first = ggml_get_first_tensor(ctx); - for (struct ggml_tensor * t = first; t != NULL; t = ggml_get_next_tensor(ctx, t)) { - size_t this_size = 0; - if (t->data == NULL && t->view_src == NULL) { - // Tensor size must be properly padded. - this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment); - } - - // The allocation logic here has gone beyond intention in order to make - // `test-backend-ops` work. The very initial intention was to allocate - // memory for weights - each weight tensor gets its own buffer object. - // The original function should be used to allocate for intermediate tensors. - // There are usually no view tensors for weights; this is not true for - // intermediate tensors. However, in `test-backend-ops` there is no - // differetiation between weight tensors and intermediate tensors. - // This function is used for general allocation when small allocation is - // enabled in the test. This requires the function to also handle view - // tensors, which do no require actual allocation. In the original function, - // view tensors are allocated with other non-view tensors since view tensors - // sizes are 0. - // Here, we try to identify view tensors and allocate them with the next - // non-view tensor. View tensors cannot allocated (alone) but must be - // initialized (together with non-view tensors). - - // This is a view tensor of its size if 0. Record its location if it is the - // first one after a non-view tensor. If the next tensor is still a view, - // simply go to the next. We want to allocate all consecutive view tensors - // together with the next non-view tensor. - if (this_size == 0 && first_view == NULL) { - first_view = t; - continue; - } - - if (first_view) { - // This is a non-view tensor. If there are any view tensors before - // this non-view tensor, we want to allocate these view tensors and - // this non-view tensor together. - // The first tensor to allocate is the first view tensor. - first = first_view; - } else { - // Otherwise, allocate this non-view tensor immediately. 
- first = t; - } - - if (!alloc_tensor_range(ctx, first, ggml_get_next_tensor(ctx, t), buft, this_size, &buffers, &n_buffers)) { - return NULL; - } - - // Always reset first_view after a non-view tensor. - first_view = NULL; - } - - if (n_buffers == 0) { -#ifndef NDEBUG - fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__); -#endif - return NULL; - } - - ggml_backend_buffer_t buffer; - if (n_buffers == 1) { - buffer = buffers[0]; - } else { - buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers); - } - free(buffers); - return buffer; -#endif -} - ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) { return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend)); } diff --git a/ggml/src/ggml-opencl2/ggml-opencl2.cpp b/ggml/src/ggml-opencl2/ggml-opencl2.cpp index 64aa99cff..6df5625ad 100644 --- a/ggml/src/ggml-opencl2/ggml-opencl2.cpp +++ b/ggml/src/ggml-opencl2/ggml-opencl2.cpp @@ -505,10 +505,6 @@ static ggml_backend_opencl2_context * ggml_cl2_init(ggml_backend_dev_t dev) { fprintf(stderr, "ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); #endif // GGML_OPENCL_SOA_Q -#ifdef GGML_OPENCL_SMALL_ALLOC - fprintf(stderr, "ggml_opencl: allocating a separate buffer object for each tensor (GGML_OPENCL_SMALL_ALLOC)\n"); -#endif // GGML_OPENCL_SMALL_ALLOC - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS fprintf(stderr, "ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -828,23 +824,15 @@ struct ggml_tensor_extra_cl_q4_0 { } void reset() { - // When SMALL_ALLOC is not enabled, q and d are subbuffers into - // the bigger buffer allocated in ggml_backend_buffer. + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. // They must be properly released so that the original buffer can be // properly released to avoid memory leak. - // When SMALL_ALLOC is enabled, q and d point to the buffers in - // ggml_backend_opencl2_buffer_context. These buffers get released when - // the context is deleted, so there is no need to release them here. if (q != nullptr) { -#ifndef GGML_OPENCL_SMALL_ALLOC CL_CHECK(clReleaseMemObject(q)); -#endif q = nullptr; } if (d != nullptr) { -#ifndef GGML_OPENCL_SMALL_ALLOC CL_CHECK(clReleaseMemObject(d)); -#endif d = nullptr; } // Currently, q_img and d_img are only initialized when SMALL_ALLOC is @@ -1168,70 +1156,6 @@ static void ggml_backend_opencl2_buffer_init_tensor(ggml_backend_buffer_t buffer // there could be other places that need fix. tensor->extra = view_extra; } else { -#if defined(GGML_OPENCL_SOA_Q) && defined(GGML_OPENCL_SMALL_ALLOC) - // When small alloc is enabled with flattening, we create separate - // buffers for quants and scales (and potentially other components). - // These separate buffers are stored in ctx->buffer. To avoid double - // allocation, the buffer originally allocated is first released. - // Note that when this function is called, the buffer in context has - // been created byggml_backend_buft_alloc_buffer in alloc_tensor_range - // (llm_load_tensors -> ggml_backend_alloc_ctx_tensors_from_buft -> - // alloc_tensor_range -> ggml_backend_buft_alloc_buffer). 
- if (tensor->type == GGML_TYPE_Q4_0) { - CL_CHECK(clReleaseMemObject(ctx->buffer[0])); - size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; - GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); - - cl_int err; - ctx->buffer.resize(2); - ctx->buffer[0] = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_d, NULL, &err); - CL_CHECK(err); - - ctx->buffer[1] = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_q, NULL, &err); - CL_CHECK(err); - - // Populate images. - ctx->img.resize(2); - - cl_image_format fmt; - cl_image_desc desc; - - desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - desc.image_row_pitch = 0; - desc.image_slice_pitch = 0; - desc.num_mip_levels = 0; - desc.num_samples = 0; - - fmt.image_channel_data_type = CL_HALF_FLOAT; - fmt.image_channel_order = CL_R; - desc.buffer = ctx->buffer[0]; - desc.image_width = size_d / 2; - ctx->img[0] = clCreateImage(context, CL_MEM_READ_ONLY, &fmt, &desc, NULL, &err); - CL_CHECK(err); - - // Not checking if size_q is multiple of 16 - for legitimate Q4_0 - // the quants must have at least one block, so there must be at least - // 16 bytes. - fmt.image_channel_data_type = CL_FLOAT; - fmt.image_channel_order = CL_RGBA; - desc.buffer = ctx->buffer[1]; - desc.image_width = size_q / 16; - ctx->img[1] = clCreateImage(context, CL_MEM_READ_ONLY, &fmt, &desc, NULL, &err); - CL_CHECK(err); - - ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl2_alloc_temp_tensor_extra_q4_0(); - extra->d = ctx->buffer[0]; - extra->q = ctx->buffer[1]; - extra->d_img = ctx->img[0]; - extra->q_img = ctx->img[1]; - extra->size_d = size_d; - extra->size_q = size_q; - tensor->extra = extra; - } else -#endif { size_t offset = (char *)tensor->data - (char *)cl_ptr_base; @@ -1265,190 +1189,6 @@ static void ggml_backend_opencl2_buffer_set_tensor(ggml_backend_buffer_t buffer, // buffers for quantized bits and scales, which are then populated by the // conversion kernel. if (tensor->type == GGML_TYPE_Q4_0) { -#ifdef GGML_OPENCL_SMALL_ALLOC - // When small alloc is enabled with quant flattening, each tensor has - // been initialized with ggml_tensor_extra_cl_q4_0. - ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; - GGML_ASSERT(extra && "Tensors in OpenCL backend should have been allocated and initialized"); - - cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); - CL_CHECK(clEnqueueWriteBuffer( - queue, data_device, CL_TRUE, 0, - ggml_nbytes(tensor), data, 0, NULL, NULL)); - - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - - // The optimized kernels need weights in natural order, so unshuffle. 
- if (use_adreno_kernels(tensor)) { - kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; - } - #else - cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); - - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {64, 1, 1}; - - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clReleaseMemObject(data_device)); - - // transpose the weights and scales - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - // Only do transpose for large, non batched matrix - // TODO: use preallocated images instead of sub-buffer then image - if (use_adreno_kernels(tensor)) { - // <----------------------------------------------------------------------------------> // - // start transpose - // <----------------------------------------------------------------------------------> // - int M = tensor->ne[1]; // ne01 - int K = tensor->ne[0]; // ne00 - - // transpose is out of place, so we need to allocate transposed buffers - // <----------------------------------------------------------------------------------> // - // use sub_buffer of max buffer size instead - - size_t q_size_bytes = K * M / 8 * sizeof(float); - cl_buffer_region region; - region.origin = 0; - region.size = q_size_bytes; - cl_mem qT_d = clCreateSubBuffer( - backend_ctx->A_q_d_max, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err); - CL_CHECK(err); - - // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float); - size_t d_size_bytes = M * (K / 32) * 2; - region.origin = 0; - region.size = d_size_bytes; - cl_mem dT_d = clCreateSubBuffer( - backend_ctx->A_s_d_max, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err); - CL_CHECK(err); - - // <----------------------------------------------------------------------------------> // - - - // create images from the buffers - // <----------------------------------------------------------------------------------> // - cl_mem q_d_image1D; - cl_mem d_d_image1D; - cl_mem qT_d_image1D; - cl_mem dT_d_image1D; - - cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT }; - cl_image_desc img_desc_1d; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 8 / 4; - img_desc_1d.buffer = extra->q; - q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 8 / 4; - img_desc_1d.buffer = qT_d; - qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4 / 2; - img_desc_1d.buffer = extra->d; - d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - 
img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4 / 2; - img_desc_1d.buffer = dT_d; - dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - // <----------------------------------------------------------------------------------> // - - // set up and call the transpose kernels - // <----------------------------------------------------------------------------------> // - // weights - int height_q = M / 8; - int width_q = K / 8 / 4; - kernel = backend_ctx->kernel_transpose_16; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); - - size_t local_size_q[3] = {4, 16, 1}; - size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // scales - int height_s = M / 8; - int width_s = K / 32 / 8; - - kernel = backend_ctx->kernel_transpose_16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); - - size_t local_size_s[3] = {4, 16, 1}; - size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // - - // copy transposed buffer contents to original buffers - // <----------------------------------------------------------------------------------> // - // weights - CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // scales - CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // - - // deallocate transpose buffers - // <----------------------------------------------------------------------------------> // - CL_CHECK(clReleaseMemObject(qT_d)); - CL_CHECK(clReleaseMemObject(dT_d)); - - // deallocate temporary images - CL_CHECK(clReleaseMemObject(q_d_image1D)); - CL_CHECK(clReleaseMemObject(d_d_image1D)); - CL_CHECK(clReleaseMemObject(qT_d_image1D)); - CL_CHECK(clReleaseMemObject(dT_d_image1D)); - // <----------------------------------------------------------------------------------> // - // end transpose - // <----------------------------------------------------------------------------------> // - } - #endif // GGML_OPENCL_USE_ADRENO_KERNELS - - return; -#else // GGML_OPENCL_SMALL_ALLOC // Tensors should have been preallocated, therefore they should // already have ggml_tensor_extra_cl as extra. 
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -1677,7 +1417,6 @@ static void ggml_backend_opencl2_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; -#endif // GGML_OPENCL_SMALL_ALLOC } #endif // GGML_OPENCL_SOA_Q diff --git a/src/llama.cpp b/src/llama.cpp index eec2e2017..49ef5b78a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9296,11 +9296,7 @@ static bool llm_load_tensors( } } else { -#ifdef GGML_USE_OPENCL - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft_for_weights(ctx, buft); -#else // GGML_USE_OPENCL ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); -#endif if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 7ff0bb7f5..0770e7cd6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -18,7 +18,6 @@ #include #include #include -#include #include #include @@ -466,19 +465,7 @@ struct test_case { // post-graph sentinel add_sentinel(ctx); - // allocate - // We need to use this function to properly support GGML_OPENCL_SMALL_ALLOC. - // In fact, `ggml_backend_alloc_ctx_tensors_from_buft_for_weights` is a - // bit misdenomer. It is initially created for allocating weights. But - // it can be used for allocating any tensors that needs small alloc. - // Something like `ggml_backend_alloc_ctx_tensors_from_buft_2` or - // `ggml_backend_alloc_ctx_tensors_from_buft_small_alloc` would be better. - // - // This is intrusive. We intend to remove SMALL_ALLOC path once the we fully - // migrate to the non SMALL_ALLOC path. - ggml_backend_buffer_t buf = ggml_backend_is_opencl2(backend1) == false ? ggml_backend_alloc_ctx_tensors(ctx, backend1) : - ggml_backend_alloc_ctx_tensors_from_buft_for_weights( - ctx, ggml_backend_get_default_buffer_type(backend1)); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1); if (buf == NULL) { printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1)); ggml_free(ctx); @@ -630,13 +617,7 @@ struct test_case { printf("%*s", last - len, ""); // allocate -#ifdef GGML_USE_OPENCL - ggml_backend_buffer_t buf = - ggml_backend_alloc_ctx_tensors_from_buft_for_weights( - ctx, ggml_backend_get_default_buffer_type(backend)); -#else ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); -#endif if (buf == NULL) { printf("failed to allocate tensors\n"); ggml_free(ctx);
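With ggml_backend_alloc_ctx_tensors_from_buft_for_weights removed, every caller touched by this patch (llm_load_tensors in src/llama.cpp and the two allocation sites in tests/test-backend-ops.cpp) goes through the generic context allocators that remain declared in ggml/include/ggml-alloc.h. A minimal sketch of the resulting call pattern follows; the helper name and error handling are illustrative only and are not part of the patch itself.

/*
 * Sketch: unified allocation path after the small-alloc removal.
 * Assumes ctx was created with no_alloc = true and already contains the
 * tensors to be allocated (as in llm_load_tensors / test-backend-ops).
 */
#include <stdio.h>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static ggml_backend_buffer_t alloc_ctx_weights(struct ggml_context * ctx, ggml_backend_t backend) {
    // A single backend buffer now holds all tensors of the context; there is
    // no longer a per-tensor (small-alloc) buffer path for the OpenCL backend.
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
    if (buf == NULL) {
        fprintf(stderr, "unable to allocate %s buffer\n", ggml_backend_name(backend));
    }
    return buf;
}

When a specific buffer type is selected per tensor group, as in llm_load_tensors, the equivalent call is ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft), which is exactly what the former #ifdef GGML_USE_OPENCL branch in src/llama.cpp collapses to in this patch.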