opencl: remove small-alloc support and fix build errors for non-opencl platforms

parent 8ad0bb30df
commit 671c7af6b9

8 changed files with 4 additions and 382 deletions

.github/workflows/build.yml (vendored): 4 changes

@@ -663,7 +663,7 @@ jobs:
 - build: 'msvc-arm64'
   defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON'
 - build: 'llvm-arm64-opencl-adreno'
-  defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH=${{github.workspace}}/opencl-x64-release -DGGML_OPENCL=ON -DGGML_OPENCL_SMALL_ALLOC=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON -DGGML_OPENCL_EMBED_KERNELS=ON'
+  defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH=${{github.workspace}}/opencl-arm64-release -DGGML_OPENCL=ON -DGGML_OPENCL_SMALL_ALLOC=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON -DGGML_OPENCL_EMBED_KERNELS=ON'

 steps:
 - name: Clone

@@ -716,7 +716,7 @@ jobs:
   -DBUILD_TESTING=OFF `
   -DOPENCL_HEADERS_BUILD_TESTING=OFF `
   -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF `
-  -DCMAKE_INSTALL_PREFIX=${{github.workspace}}/opencl-x64-release
+  -DCMAKE_INSTALL_PREFIX=${{github.workspace}}/opencl-arm64-release
 cmake --build . --target install
 git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader
 cd OpenCL-ICD-Loader

@@ -180,7 +180,6 @@ set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
 "ggml: sycl device architecture")

 option(GGML_OPENCL "ggml: use OpenCL" OFF)
-option(GGML_OPENCL_SMALL_ALLOC "ggml: use small allocation for tensors" ON)
 option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF)
 option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
 option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)

@@ -69,7 +69,6 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i
 // Utils
 // Create a buffer and allocate all the tensors in a ggml_context
 GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
-GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft_for_weights(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
 GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);

 #ifdef __cplusplus

@@ -298,15 +298,9 @@ else ()
     ggml_add_cpu_backend_variant_impl("")
 endif()

-# TODO: This is intrusive. We intend to remove the SMALL_ALLOC path once we fully
-# migrate to the non SMALL_ALLOC path. Also need to converge on the backend name
-# so we don't need this name conversion.
 if (GGML_OPENCL)
     set(GGML_OPENCL2 ON)
     add_compile_definitions(GGML_USE_OPENCL)
-    if (GGML_OPENCL_SMALL_ALLOC)
-        add_compile_definitions(GGML_OPENCL_SMALL_ALLOC)
-    endif ()
 else ()
     set(GGML_OPENCL2 OFF)
 endif ()

@@ -1033,92 +1033,6 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
     return buffer;
 }

-ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft_for_weights(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
-#ifndef GGML_OPENCL_SMALL_ALLOC
-    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-#else
-    // Small allocation allocates a separate buffer for each tensor. Instead of
-    // collecting multiple tensors to allocate a large buffer, each tensor is
-    // allocated a buffer immediately. This is only supposed to be used for
-    // weight tensors (note that weights can be f32).
-    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
-
-    size_t alignment = ggml_backend_buft_get_alignment(buft);
-
-    ggml_backend_buffer_t * buffers = NULL;
-    size_t n_buffers = 0;
-
-    struct ggml_tensor * first_view = NULL;
-    struct ggml_tensor * first = ggml_get_first_tensor(ctx);
-    for (struct ggml_tensor * t = first; t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-        size_t this_size = 0;
-        if (t->data == NULL && t->view_src == NULL) {
-            // Tensor size must be properly padded.
-            this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
-        }
-
-        // The allocation logic here has gone beyond the original intention in order to make
-        // `test-backend-ops` work. The very initial intention was to allocate
-        // memory for weights - each weight tensor gets its own buffer object.
-        // The original function should be used to allocate for intermediate tensors.
-        // There are usually no view tensors for weights; this is not true for
-        // intermediate tensors. However, in `test-backend-ops` there is no
-        // differentiation between weight tensors and intermediate tensors.
-        // This function is used for general allocation when small allocation is
-        // enabled in the test. This requires the function to also handle view
-        // tensors, which do not require actual allocation. In the original function,
-        // view tensors are allocated with other non-view tensors since view tensor
-        // sizes are 0.
-        // Here, we try to identify view tensors and allocate them with the next
-        // non-view tensor. View tensors cannot be allocated (alone) but must be
-        // initialized (together with non-view tensors).
-
-        // This is a view tensor if its size is 0. Record its location if it is the
-        // first one after a non-view tensor. If the next tensor is still a view,
-        // simply go to the next. We want to allocate all consecutive view tensors
-        // together with the next non-view tensor.
-        if (this_size == 0 && first_view == NULL) {
-            first_view = t;
-            continue;
-        }
-
-        if (first_view) {
-            // This is a non-view tensor. If there are any view tensors before
-            // this non-view tensor, we want to allocate these view tensors and
-            // this non-view tensor together.
-            // The first tensor to allocate is the first view tensor.
-            first = first_view;
-        } else {
-            // Otherwise, allocate this non-view tensor immediately.
-            first = t;
-        }
-
-        if (!alloc_tensor_range(ctx, first, ggml_get_next_tensor(ctx, t), buft, this_size, &buffers, &n_buffers)) {
-            return NULL;
-        }
-
-        // Always reset first_view after a non-view tensor.
-        first_view = NULL;
-    }
-
-    if (n_buffers == 0) {
-#ifndef NDEBUG
-        fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
-#endif
-        return NULL;
-    }
-
-    ggml_backend_buffer_t buffer;
-    if (n_buffers == 1) {
-        buffer = buffers[0];
-    } else {
-        buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);
-    }
-    free(buffers);
-    return buffer;
-#endif
-}
-
 ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
     return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
 }

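For context, the removed allocator walked the tensor list and gave every non-view tensor its own buffer, folding any run of view tensors (size 0) into the allocation of the next non-view tensor. A minimal, self-contained sketch of that grouping walk, using made-up types rather than the ggml API:

// Simplified model of the removed grouping logic. FakeTensor and alloc_range
// are hypothetical; a size of 0 stands in for a view tensor.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct FakeTensor { size_t size; };

// One "buffer" per call; views riding along are covered by the same range.
static void alloc_range(const std::vector<FakeTensor> & ts, size_t first, size_t last) {
    size_t bytes = 0;
    for (size_t i = first; i <= last; ++i) bytes += ts[i].size;
    std::printf("one buffer for tensors [%zu, %zu], %zu bytes\n", first, last, bytes);
}

int main() {
    std::vector<FakeTensor> ts = {{128}, {0}, {0}, {256}, {64}};
    size_t first_view = SIZE_MAX;               // index of the first pending view, if any
    for (size_t i = 0; i < ts.size(); ++i) {
        if (ts[i].size == 0) {                  // view: defer, group with the next non-view tensor
            if (first_view == SIZE_MAX) first_view = i;
            continue;
        }
        size_t first = (first_view != SIZE_MAX) ? first_view : i;
        alloc_range(ts, first, i);              // allocate deferred views + this tensor together
        first_view = SIZE_MAX;                  // always reset after a non-view tensor
    }
    return 0;
}

On the example list this prints one buffer per non-view tensor, with the two views grouped into the 256-byte tensor's range.
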
@@ -505,10 +505,6 @@ static ggml_backend_opencl2_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     fprintf(stderr, "ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n");
 #endif // GGML_OPENCL_SOA_Q

-#ifdef GGML_OPENCL_SMALL_ALLOC
-    fprintf(stderr, "ggml_opencl: allocating a separate buffer object for each tensor (GGML_OPENCL_SMALL_ALLOC)\n");
-#endif // GGML_OPENCL_SMALL_ALLOC
-
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     fprintf(stderr, "ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n");
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS

@@ -828,23 +824,15 @@ struct ggml_tensor_extra_cl_q4_0 {
     }

     void reset() {
-        // When SMALL_ALLOC is not enabled, q and d are subbuffers into
-        // the bigger buffer allocated in ggml_backend_buffer.
+        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
         // They must be properly released so that the original buffer can be
         // properly released to avoid a memory leak.
-        // When SMALL_ALLOC is enabled, q and d point to the buffers in
-        // ggml_backend_opencl2_buffer_context. These buffers get released when
-        // the context is deleted, so there is no need to release them here.
         if (q != nullptr) {
-#ifndef GGML_OPENCL_SMALL_ALLOC
             CL_CHECK(clReleaseMemObject(q));
-#endif
             q = nullptr;
         }
         if (d != nullptr) {
-#ifndef GGML_OPENCL_SMALL_ALLOC
             CL_CHECK(clReleaseMemObject(d));
-#endif
             d = nullptr;
         }
         // Currently, q_img and d_img are only initialized when SMALL_ALLOC is

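The comment kept above is about lifetime: q and d are sub-buffers carved out of one parent cl_mem, and the parent cannot really go away until those sub-buffers are released. A hedged sketch of that ownership pattern (not the backend's code; the function and sizes are illustrative only):

// Two sub-buffers over one parent buffer; releasing the sub-buffers first lets
// the parent's reference count reach zero when it is released afterwards.
#include <CL/cl.h>

static void subbuffer_lifetime_demo(cl_context ctx, size_t size_d, size_t size_q) {
    cl_int err = CL_SUCCESS;
    cl_mem parent = clCreateBuffer(ctx, CL_MEM_READ_WRITE, size_d + size_q, NULL, &err);

    // Sub-buffer origins must respect CL_DEVICE_MEM_BASE_ADDR_ALIGN; this sketch
    // assumes size_d is already padded accordingly.
    cl_buffer_region d_region = { 0,      size_d };
    cl_buffer_region q_region = { size_d, size_q };
    cl_mem d = clCreateSubBuffer(parent, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &d_region, &err);
    cl_mem q = clCreateSubBuffer(parent, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &q_region, &err);

    // A sub-buffer keeps its parent alive, so release q and d (as reset() does)
    // before releasing the parent buffer.
    clReleaseMemObject(q);
    clReleaseMemObject(d);
    clReleaseMemObject(parent);
}
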
@@ -1168,70 +1156,6 @@ static void ggml_backend_opencl2_buffer_init_tensor(ggml_backend_buffer_t buffer
         // there could be other places that need fix.
         tensor->extra = view_extra;
     } else {
-#if defined(GGML_OPENCL_SOA_Q) && defined(GGML_OPENCL_SMALL_ALLOC)
-        // When small alloc is enabled with flattening, we create separate
-        // buffers for quants and scales (and potentially other components).
-        // These separate buffers are stored in ctx->buffer. To avoid double
-        // allocation, the buffer originally allocated is first released.
-        // Note that when this function is called, the buffer in the context has
-        // been created by ggml_backend_buft_alloc_buffer in alloc_tensor_range
-        // (llm_load_tensors -> ggml_backend_alloc_ctx_tensors_from_buft ->
-        // alloc_tensor_range -> ggml_backend_buft_alloc_buffer).
-        if (tensor->type == GGML_TYPE_Q4_0) {
-            CL_CHECK(clReleaseMemObject(ctx->buffer[0]));
-            size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
-            size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
-            GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
-
-            cl_int err;
-            ctx->buffer.resize(2);
-            ctx->buffer[0] = clCreateBuffer(context, CL_MEM_READ_WRITE,
-                size_d, NULL, &err);
-            CL_CHECK(err);
-
-            ctx->buffer[1] = clCreateBuffer(context, CL_MEM_READ_WRITE,
-                size_q, NULL, &err);
-            CL_CHECK(err);
-
-            // Populate images.
-            ctx->img.resize(2);
-
-            cl_image_format fmt;
-            cl_image_desc desc;
-
-            desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
-            desc.image_row_pitch = 0;
-            desc.image_slice_pitch = 0;
-            desc.num_mip_levels = 0;
-            desc.num_samples = 0;
-
-            fmt.image_channel_data_type = CL_HALF_FLOAT;
-            fmt.image_channel_order = CL_R;
-            desc.buffer = ctx->buffer[0];
-            desc.image_width = size_d / 2;
-            ctx->img[0] = clCreateImage(context, CL_MEM_READ_ONLY, &fmt, &desc, NULL, &err);
-            CL_CHECK(err);
-
-            // Not checking if size_q is a multiple of 16 - for legitimate Q4_0
-            // the quants must have at least one block, so there must be at least
-            // 16 bytes.
-            fmt.image_channel_data_type = CL_FLOAT;
-            fmt.image_channel_order = CL_RGBA;
-            desc.buffer = ctx->buffer[1];
-            desc.image_width = size_q / 16;
-            ctx->img[1] = clCreateImage(context, CL_MEM_READ_ONLY, &fmt, &desc, NULL, &err);
-            CL_CHECK(err);
-
-            ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl2_alloc_temp_tensor_extra_q4_0();
-            extra->d = ctx->buffer[0];
-            extra->q = ctx->buffer[1];
-            extra->d_img = ctx->img[0];
-            extra->q_img = ctx->img[1];
-            extra->size_d = size_d;
-            extra->size_q = size_q;
-            tensor->extra = extra;
-        } else
-#endif
         {
             size_t offset = (char *)tensor->data - (char *)cl_ptr_base;

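The removed init path splits a Q4_0 tensor into a scale buffer (d) and a quant buffer (q). The arithmetic follows directly from the Q4_0 layout: 32 elements per block, each block carrying one 2-byte fp16 scale plus 16 bytes of packed 4-bit quants. A small stand-alone check of that split (the shape below is just an example):

// Size split used by the removed path: size_d for fp16 scales, size_q for quants.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nelements = 4096 * 4096;   // example weight matrix
    const int64_t blck      = 32;            // QK4_0 block size
    const size_t  size_d    = (size_t)(nelements / blck) * sizeof(uint16_t); // one fp16 scale per block
    const size_t  size_q    = (size_t)nelements / 2;                         // 4 bits per element
    std::printf("size_d = %zu bytes, size_q = %zu bytes, total = %zu bytes\n",
                size_d, size_q, size_d + size_q);
    // The total matches ggml_nbytes() for Q4_0: nelements/32 blocks of 18 bytes (2 + 16),
    // which is what the GGML_ASSERT in the removed code checked.
    return 0;
}
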
@@ -1265,190 +1189,6 @@ static void ggml_backend_opencl2_buffer_set_tensor(ggml_backend_buffer_t buffer,
     // buffers for quantized bits and scales, which are then populated by the
     // conversion kernel.
     if (tensor->type == GGML_TYPE_Q4_0) {
-#ifdef GGML_OPENCL_SMALL_ALLOC
-        // When small alloc is enabled with quant flattening, each tensor has
-        // been initialized with ggml_tensor_extra_cl_q4_0.
-        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra;
-        GGML_ASSERT(extra && "Tensors in OpenCL backend should have been allocated and initialized");
-
-        cl_int err;
-        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
-            ggml_nbytes(tensor), NULL, &err);
-        CL_CHECK(err);
-        CL_CHECK(clEnqueueWriteBuffer(
-            queue, data_device, CL_TRUE, 0,
-            ggml_nbytes(tensor), data, 0, NULL, NULL));
-
-#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
-        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
-
-        // The optimized kernels need weights in natural order, so unshuffle.
-        if (use_adreno_kernels(tensor)) {
-            kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle;
-        }
-#else
-        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
-#endif // GGML_OPENCL_USE_ADRENO_KERNELS
-        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
-        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
-        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
-
-        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
-        size_t local_work_size[] = {64, 1, 1};
-
-        cl_event evt;
-        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
-        CL_CHECK(clWaitForEvents(1, &evt));
-        CL_CHECK(clReleaseMemObject(data_device));
-
-        // transpose the weights and scales
-#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
-        // Only do transpose for large, non batched matrix
-        // TODO: use preallocated images instead of sub-buffer then image
-        if (use_adreno_kernels(tensor)) {
-            // <----------------------------------------------------------------------------------> //
-            // start transpose
-            // <----------------------------------------------------------------------------------> //
-            int M = tensor->ne[1]; // ne01
-            int K = tensor->ne[0]; // ne00
-
-            // transpose is out of place, so we need to allocate transposed buffers
-            // <----------------------------------------------------------------------------------> //
-            // use sub_buffer of max buffer size instead
-
-            size_t q_size_bytes = K * M / 8 * sizeof(float);
-            cl_buffer_region region;
-            region.origin = 0;
-            region.size = q_size_bytes;
-            cl_mem qT_d = clCreateSubBuffer(
-                backend_ctx->A_q_d_max,
-                0,
-                CL_BUFFER_CREATE_TYPE_REGION,
-                &region,
-                &err);
-            // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
-            CL_CHECK(err);
-
-            // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float);
-            size_t d_size_bytes = M * (K / 32) * 2;
-            region.origin = 0;
-            region.size = d_size_bytes;
-            cl_mem dT_d = clCreateSubBuffer(
-                backend_ctx->A_s_d_max,
-                0,
-                CL_BUFFER_CREATE_TYPE_REGION,
-                &region,
-                &err);
-            // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
-            CL_CHECK(err);
-
-            // <----------------------------------------------------------------------------------> //
-
-            // create images from the buffers
-            // <----------------------------------------------------------------------------------> //
-            cl_mem q_d_image1D;
-            cl_mem d_d_image1D;
-            cl_mem qT_d_image1D;
-            cl_mem dT_d_image1D;
-
-            cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT };
-            cl_image_desc img_desc_1d;
-
-            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
-            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
-            img_desc_1d.image_width = M * K / 8 / 4;
-            img_desc_1d.buffer = extra->q;
-            q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
-            CL_CHECK(err);
-
-            img_fmt_1d = { CL_RGBA, CL_FLOAT };
-            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
-            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
-            img_desc_1d.image_width = M * K / 8 / 4;
-            img_desc_1d.buffer = qT_d;
-            qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
-            CL_CHECK(err);
-
-            img_fmt_1d = { CL_RGBA, CL_FLOAT };
-            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
-            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
-            img_desc_1d.image_width = M * K / 32 / 4 / 2;
-            img_desc_1d.buffer = extra->d;
-            d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
-            CL_CHECK(err);
-
-            img_fmt_1d = { CL_RGBA, CL_FLOAT };
-            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
-            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
-            img_desc_1d.image_width = M * K / 32 / 4 / 2;
-            img_desc_1d.buffer = dT_d;
-            dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
-            CL_CHECK(err);
-            // <----------------------------------------------------------------------------------> //
-
-            // set up and call the transpose kernels
-            // <----------------------------------------------------------------------------------> //
-            // weights
-            int height_q = M / 8;
-            int width_q = K / 8 / 4;
-            kernel = backend_ctx->kernel_transpose_16;
-
-            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));
-            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));
-            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q));
-            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q));
-
-            size_t local_size_q[3] = {4, 16, 1};
-            size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};
-            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));
-            CL_CHECK(clWaitForEvents(1, &evt));
-
-            // scales
-            int height_s = M / 8;
-            int width_s = K / 32 / 8;
-
-            kernel = backend_ctx->kernel_transpose_16;
-            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
-            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
-            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
-            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));
-
-            size_t local_size_s[3] = {4, 16, 1};
-            size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};
-            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));
-            CL_CHECK(clWaitForEvents(1, &evt));
-            // <----------------------------------------------------------------------------------> //
-
-            // copy transposed buffer contents to original buffers
-            // <----------------------------------------------------------------------------------> //
-            // weights
-            CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));
-            CL_CHECK(clWaitForEvents(1, &evt));
-
-            // scales
-            CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt));
-            CL_CHECK(clWaitForEvents(1, &evt));
-            // <----------------------------------------------------------------------------------> //
-
-            // deallocate transpose buffers
-            // <----------------------------------------------------------------------------------> //
-            CL_CHECK(clReleaseMemObject(qT_d));
-            CL_CHECK(clReleaseMemObject(dT_d));
-
-            // deallocate temporary images
-            CL_CHECK(clReleaseMemObject(q_d_image1D));
-            CL_CHECK(clReleaseMemObject(d_d_image1D));
-            CL_CHECK(clReleaseMemObject(qT_d_image1D));
-            CL_CHECK(clReleaseMemObject(dT_d_image1D));
-            // <----------------------------------------------------------------------------------> //
-            // end transpose
-            // <----------------------------------------------------------------------------------> //
-        }
-#endif // GGML_OPENCL_USE_ADRENO_KERNELS
-
-        return;
-#else // GGML_OPENCL_SMALL_ALLOC
         // Tensors should have been preallocated, therefore they should
         // already have ggml_tensor_extra_cl as extra.
         ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;

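The removed Adreno transpose path does not copy the weights to build images; it wraps the existing cl_mem buffers in CL_MEM_OBJECT_IMAGE1D_BUFFER views, with the image width expressed in RGBA texels. A sketch of that wrapping (helper name and flag choice are illustrative, not the backend's API):

// Wrap an existing buffer in a 1D image view, as the transpose path did before
// running kernel_transpose_16. With CL_RGBA + CL_FLOAT, one texel spans 16 bytes.
#include <CL/cl.h>
#include <cstring>

static cl_mem make_image1d_view(cl_context ctx, cl_mem buffer, size_t bytes, cl_int * err) {
    cl_image_format fmt = { CL_RGBA, CL_FLOAT };    // 4 x float = 16 bytes per texel

    cl_image_desc desc;
    std::memset(&desc, 0, sizeof(desc));
    desc.image_type  = CL_MEM_OBJECT_IMAGE1D_BUFFER;
    desc.image_width = bytes / 16;                  // e.g. M*K/2 bytes of q4 data -> M*K/8/4 texels
    desc.buffer      = buffer;                      // the image aliases the buffer, no copy is made

    return clCreateImage(ctx, CL_MEM_READ_ONLY, &fmt, &desc, NULL, err);
}
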
@@ -1677,7 +1417,6 @@ static void ggml_backend_opencl2_buffer_set_tensor(ggml_backend_buffer_t buffer,
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS

         return;
-#endif // GGML_OPENCL_SMALL_ALLOC
     }
 #endif // GGML_OPENCL_SOA_Q

@@ -9296,11 +9296,7 @@ static bool llm_load_tensors(
             }
         }
         else {
-#ifdef GGML_USE_OPENCL
-            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft_for_weights(ctx, buft);
-#else // GGML_USE_OPENCL
             ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-#endif
             if (buf == nullptr) {
                 throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
             }

@@ -18,7 +18,6 @@
 #include <ggml.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
-#include <ggml-opencl2.h>

 #include <algorithm>
 #include <array>

@@ -466,19 +465,7 @@ struct test_case {
         // post-graph sentinel
         add_sentinel(ctx);

         // allocate
-        // We need to use this function to properly support GGML_OPENCL_SMALL_ALLOC.
-        // In fact, `ggml_backend_alloc_ctx_tensors_from_buft_for_weights` is a
-        // bit of a misnomer. It was initially created for allocating weights, but
-        // it can be used for allocating any tensors that need small alloc.
-        // Something like `ggml_backend_alloc_ctx_tensors_from_buft_2` or
-        // `ggml_backend_alloc_ctx_tensors_from_buft_small_alloc` would be better.
-        //
-        // This is intrusive. We intend to remove the SMALL_ALLOC path once we fully
-        // migrate to the non SMALL_ALLOC path.
-        ggml_backend_buffer_t buf = ggml_backend_is_opencl2(backend1) == false ? ggml_backend_alloc_ctx_tensors(ctx, backend1) :
-                                        ggml_backend_alloc_ctx_tensors_from_buft_for_weights(
-                                            ctx, ggml_backend_get_default_buffer_type(backend1));
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
         if (buf == NULL) {
             printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
             ggml_free(ctx);

@@ -630,13 +617,7 @@ struct test_case {
         printf("%*s", last - len, "");

         // allocate
-#ifdef GGML_USE_OPENCL
-        ggml_backend_buffer_t buf =
-            ggml_backend_alloc_ctx_tensors_from_buft_for_weights(
-                ctx, ggml_backend_get_default_buffer_type(backend));
-#else
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
-#endif
         if (buf == NULL) {
             printf("failed to allocate tensors\n");
             ggml_free(ctx);