diff --git a/Makefile b/Makefile index 8f57e406d..ebfc91d35 100644 --- a/Makefile +++ b/Makefile @@ -248,7 +248,6 @@ expose.o: expose.cpp expose.h gpttype_adapter.o: gpttype_adapter.cpp $(CXX) $(CXXFLAGS) -c $< -o $@ - gpttype_adapter_clblast.o: gpttype_adapter.cpp $(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@ diff --git a/ggml-opencl-legacy.h b/ggml-opencl-legacy.h index ff7f4f0c2..588a5bab6 100644 --- a/ggml-opencl-legacy.h +++ b/ggml-opencl-legacy.h @@ -6,17 +6,6 @@ extern "C" { #endif -enum ggml_blas_order { - GGML_BLAS_ORDER_ROW_MAJOR = 101, - GGML_BLAS_ORDER_COLUMN_MAJOR = 102, -}; - -enum ggml_blas_op { - GGML_BLAS_OP_N = 111, - GGML_BLAS_OP_T = 112, - GGML_BLAS_OP_C = 113, -}; - void ggml_cl_init_legacy(void); void ggml_cl_sgemm_wrapper_legacy(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 1367408b8..79b20f8d2 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -1,6 +1,8 @@ #include "ggml-opencl.h" +#include #include +#include #define CL_TARGET_OPENCL_VERSION 110 #include @@ -15,30 +17,30 @@ #define CL_DMMV_BLOCK_SIZE 32; #define MULTILINE_QUOTE(...) #__VA_ARGS__ -const char * clblast_dequant = MULTILINE_QUOTE( +std::string program_source = MULTILINE_QUOTE( typedef char int8_t; typedef uchar uint8_t; typedef int int32_t; typedef uint uint32_t; -constant uint GGML_TYPE_Q4_0 = 2; -constant uint GGML_TYPE_Q4_1 = 3; -constant uint GGML_TYPE_Q5_0 = 6; -constant uint GGML_TYPE_Q5_1 = 7; -constant uint GGML_TYPE_Q8_0 = 8; -constant uint GGML_TYPE_Q8_1 = 9; +const uint GGML_TYPE_Q4_0 = 2; +const uint GGML_TYPE_Q4_1 = 3; +const uint GGML_TYPE_Q5_0 = 6; +const uint GGML_TYPE_Q5_1 = 7; +const uint GGML_TYPE_Q8_0 = 8; +const uint GGML_TYPE_Q8_1 = 9; -constant uint QK4_0 = 32; -constant uint QR4_0 = 2; +const uint QK4_0 = 32; +const uint QR4_0 = 2; struct block_q4_0 { float d; uint8_t qs[QK4_0 / 2]; }; -constant uint QK4_1 = 32; -constant uint QR4_1 = 2; +const uint QK4_1 = 32; +const uint QR4_1 = 2; struct block_q4_1 { float d; @@ -46,8 +48,8 @@ struct block_q4_1 uint8_t qs[QK4_1 / 2]; }; -constant uint QK5_0 = 32; -constant uint QR5_0 = 2; +const uint QK5_0 = 32; +const uint QR5_0 = 2; struct __attribute__ ((packed)) block_q5_0 { half d; @@ -55,8 +57,8 @@ struct __attribute__ ((packed)) block_q5_0 uint8_t qs[QK5_0 / 2]; }; -constant uint QK5_1 = 32; -constant uint QR5_1 = 2; +const uint QK5_1 = 32; +const uint QR5_1 = 2; struct block_q5_1 { half d; @@ -65,8 +67,8 @@ struct block_q5_1 uint8_t qs[QK5_1 / 2]; }; -constant uint QK8_0 = 32; -constant uint QR8_0 = 1; +const uint QK8_0 = 32; +const uint QR8_0 = 1; struct block_q8_0 { float d; @@ -82,7 +84,7 @@ __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) { __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) { - constant uint qk = QK4_0; + const uint qk = QK4_0; const uint i = get_global_id(0) / qk; const uint j = get_local_id(0); @@ -97,7 +99,7 @@ __kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* } __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) { - constant uint qk = QK4_1; + const uint qk = QK4_1; const uint i = get_global_id(0) / qk; const uint j = get_local_id(0); @@ -113,7 +115,7 @@ __kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* } __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) { - constant uint qk = QK5_0; + const uint qk = QK5_0; const uint i = get_global_id(0) / qk; const uint j = get_local_id(0); @@ -133,7 +135,7 @@ __kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* } __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) { - constant uint qk = QK5_1; + const uint qk = QK5_1; const uint i = get_global_id(0) / qk; const uint j = get_local_id(0); @@ -154,7 +156,7 @@ __kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* } __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) { - constant uint qk = QK8_0; + const uint qk = QK8_0; const uint i = get_global_id(0) / qk; const uint j = get_local_id(0); @@ -173,47 +175,6 @@ void dequantize_q4_0(__global const struct block_q4_0* x, const int ib, const in *v0 = (vi0 - 8)*d; *v1 = (vi1 - 8)*d; } - -__kernel void dequantize_mul_mat_vec_q4_0(__global struct block_q4_0* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { - const int block_size = get_local_size(0); - const int row = get_global_id(0) / block_size; - const int tid = get_local_id(0); - - const uint qk = QK4_0; - const uint qr = QR4_0; - - const int y_offset = qr == 1 ? 1 : qk/2; - - tmp[tid] = 0; - - for (int i = 0; i < ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*ncols + col)/qk; // block index - const int iqs = (col%qk)/qr; // quant index - const int iybs = col - col%qk; // y block start index - - // dequantize - float v0, v1; - dequantize_q4_0(x, ib, iqs, &v0, &v1); - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} - void dequantize_q4_1(__global const struct block_q4_1* x, const int ib, const int iqs, float* v0, float* v1) { const float d = x[ib].d; const float m = x[ib].m; @@ -226,46 +187,6 @@ void dequantize_q4_1(__global const struct block_q4_1* x, const int ib, const in *v0 = vi0*d + m; *v1 = vi1*d + m; } -__kernel void dequantize_mul_mat_vec_q4_1(__global struct block_q4_1* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { - const int block_size = get_local_size(0); - const int row = get_global_id(0) / block_size; - const int tid = get_local_id(0); - - const uint qk = QK4_1; - const uint qr = QR4_1; - - const int y_offset = qr == 1 ? 1 : qk/2; - - tmp[tid] = 0; - - for (int i = 0; i < ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*ncols + col)/qk; // block index - const int iqs = (col%qk)/qr; // quant index - const int iybs = col - col%qk; // y block start index - - // dequantize - float v0, v1; - dequantize_q4_1(x, ib, iqs, &v0, &v1); - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} - void dequantize_q5_0(__global const struct block_q5_0* x, const int ib, const int iqs, float* v0, float* v1) { const float d = vload_half(0, (__global half*) &x[ib].d); @@ -280,46 +201,6 @@ void dequantize_q5_0(__global const struct block_q5_0* x, const int ib, const in *v0 = x0*d; *v1 = x1*d; } -__kernel void dequantize_mul_mat_vec_q5_0(__global struct block_q5_0* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { - const int block_size = get_local_size(0); - const int row = get_global_id(0) / block_size; - const int tid = get_local_id(0); - - const uint qk = QK5_0; - const uint qr = QR5_0; - - const int y_offset = qr == 1 ? 1 : qk/2; - - tmp[tid] = 0; - - for (int i = 0; i < ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*ncols + col)/qk; // block index - const int iqs = (col%qk)/qr; // quant index - const int iybs = col - col%qk; // y block start index - - // dequantize - float v0, v1; - dequantize_q5_0(x, ib, iqs, &v0, &v1); - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} - void dequantize_q5_1(__global const struct block_q5_1* x, const int ib, const int iqs, float* v0, float* v1) { const float d = vload_half(0, (__global half*) &x[ib].d); const float m = vload_half(0, (__global half*) &x[ib].m); @@ -335,46 +216,6 @@ void dequantize_q5_1(__global const struct block_q5_1* x, const int ib, const in *v0 = x0*d + m; *v1 = x1*d + m; } -__kernel void dequantize_mul_mat_vec_q5_1(__global struct block_q5_1* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { - const int block_size = get_local_size(0); - const int row = get_global_id(0) / block_size; - const int tid = get_local_id(0); - - const uint qk = QK5_1; - const uint qr = QR5_1; - - const int y_offset = qr == 1 ? 1 : qk/2; - - tmp[tid] = 0; - - for (int i = 0; i < ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*ncols + col)/qk; // block index - const int iqs = (col%qk)/qr; // quant index - const int iybs = col - col%qk; // y block start index - - // dequantize - float v0, v1; - dequantize_q5_1(x, ib, iqs, &v0, &v1); - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} - void dequantize_q8_0(__global const struct block_q8_0* x, const int ib, const int iqs, float* v0, float* v1) { const float d = x[ib].d; @@ -384,13 +225,20 @@ void dequantize_q8_0(__global const struct block_q8_0* x, const int ib, const in *v0 = vi0*d; *v1 = vi1*d; } -__kernel void dequantize_mul_mat_vec_q8_0(__global struct block_q8_0* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { +void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float* v1){ + *v0 = vload_half(0, (__global half*) &x[ib + 0]); + *v1 = vload_half(0, (__global half*) &x[ib + 1]); +} +); + +std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE( +__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { const int block_size = get_local_size(0); const int row = get_global_id(0) / block_size; const int tid = get_local_id(0); - const uint qk = QK8_0; - const uint qr = QR8_0; + const uint qk = QUANT_K; + const uint qr = QUANT_R; const int y_offset = qr == 1 ? 1 : qk/2; @@ -404,51 +252,7 @@ __kernel void dequantize_mul_mat_vec_q8_0(__global struct block_q8_0* x, __local // dequantize float v0, v1; - dequantize_q8_0(x, ib, iqs, &v0, &v1); - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} - -void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float* v1){ - *v0 = vload_half(0, (__global half*) &x[ib + 0]); - *v1 = vload_half(0, (__global half*) &x[ib + 1]); -} -__kernel void convert_mul_mat_vec_f16(__global half* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { - const int block_size = get_local_size(0); - const int row = get_global_id(0) / block_size; - const int tid = get_local_id(0); - - const uint qk = 32; - const uint qr = 1; - - const int y_offset = qr == 1 ? 1 : qk/2; - - tmp[tid] = 0; - - for (int i = 0; i < ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*ncols + col)/qk; // block index - const int iqs = (col%qk)/qr; // quant index - const int iybs = col - col%qk; // y block start index - - // convert - float v0, v1; - convert_f16(x, ib, iqs, &v0, &v1); + DEQUANT_FUNC(x, ib, iqs, &v0, &v1); // matrix multiplication tmp[tid] += v0 * y[iybs + iqs + 0]; @@ -469,6 +273,41 @@ __kernel void convert_mul_mat_vec_f16(__global half* x, __local float* tmp, __gl } ); +std::array dequant_mul_mat_vec_str_keys = { + "KERNEL_NAME", "X_TYPE", "QUANT_K", "QUANT_R", "DEQUANT_FUNC" +}; + +std::array dequant_mul_mat_vec_str_values = { + "dequantize_mul_mat_vec_q4_0", "struct block_q4_0", "QK4_0", "QR4_0", "dequantize_q4_0", + "dequantize_mul_mat_vec_q4_1", "struct block_q4_1", "QK4_1", "QR4_1", "dequantize_q4_1", + "dequantize_mul_mat_vec_q5_0", "struct block_q5_0", "QK5_0", "QR5_0", "dequantize_q5_0", + "dequantize_mul_mat_vec_q5_1", "struct block_q5_1", "QK5_1", "QR5_1", "dequantize_q5_1", + "dequantize_mul_mat_vec_q8_0", "struct block_q8_0", "QK8_0", "QR8_0", "dequantize_q8_0", + "convert_mul_mat_vec_f16", "half", "32", "1", "convert_f16" +}; + +static std::string& sreplace(std::string& s, const std::string& from, const std::string& to) { + size_t pos = 0; + while ((pos = s.find(from, pos)) != std::string::npos) { + s.replace(pos, from.length(), to); + pos += to.length(); + } + return s; +} + +static std::string generate_kernels() { + std::stringstream src; + src << program_source << '\n'; + for (size_t i = 0; i < dequant_mul_mat_vec_str_values.size(); i += dequant_mul_mat_vec_str_keys.size()) { + std::string kernel = dequant_mul_mat_vec_template; + for (size_t j = 0; j < dequant_mul_mat_vec_str_keys.size(); j++) { + sreplace(kernel, dequant_mul_mat_vec_str_keys[j], dequant_mul_mat_vec_str_values[i + j]); + } + src << kernel << '\n'; + } + return src.str(); +} + #define CL_CHECK(err, name) \ do { \ cl_int err_ = (err); \ @@ -483,6 +322,8 @@ static cl_device_id device; static cl_context context; static cl_command_queue queue; static cl_program program; +static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; +static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0; static cl_kernel convert_fp16_to_fp32_cl; static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl; static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl; @@ -560,7 +401,9 @@ void ggml_cl_init(void) { free(platforms); free(devices); - program = build_program_from_source(context, device, clblast_dequant); + std::string kernel_src = generate_kernels(); + + program = build_program_from_source(context, device, kernel_src.c_str()); // FP16 to FP32 kernel convert_fp16_to_fp32_cl = clCreateKernel(program, "convert_fp16_to_fp32", &err); @@ -593,6 +436,21 @@ void ggml_cl_init(void) { CL_CHECK(err, "clCreateKernel"); } +static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) { + if (req_size <= *cur_size) { + return; + } + + // Reallocate buffer with enough space + if (*cur_size > 0) { + clReleaseMemObject(*buf); + } + cl_int err; + *buf = clCreateBuffer(context, flags, req_size, NULL, &err); + *cur_size = req_size; + CL_CHECK(err, "clCreateBuffer"); +} + static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -1014,7 +872,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CL)) { + (src0->backend == GGML_BACKEND_CL)) { return true; } @@ -1090,3 +948,111 @@ void ggml_cl_transform_tensor(ggml_tensor * tensor) { tensor->data = d_Q; tensor->backend = GGML_BACKEND_CL; } + +void ggml_cl_sgemm_wrapper( + const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, + const int m, const int n, const int k, + const float alpha, const void *host_a, const int lda, + const float *host_b, const int ldb, const float beta, + float *host_c, const int ldc, const int btype) { + cl_int err = 0; + + cl_kernel * kernel = ggml_get_to_fp32_cl((ggml_type)btype); + size_t global = n * k, local, size_qb; + bool dequant; + + switch (btype) { + case GGML_TYPE_F32: + dequant = false; + break; + case GGML_TYPE_Q4_0: + dequant = true; + local = 16; + size_qb = global * (sizeof(float) + local) / 32; + break; + case GGML_TYPE_Q4_1: + dequant = true; + local = 16; + size_qb = global * (sizeof(float) * 2 + local) / 32; + break; + case GGML_TYPE_Q5_0: + dequant = true; + local = 16; + size_qb = global * (sizeof(ggml_fp16_t) + sizeof(uint32_t) + local) / 32; + break; + case GGML_TYPE_Q5_1: + dequant = true; + local = 16; + size_qb = global * (sizeof(ggml_fp16_t) * 2 + sizeof(uint32_t) + local) / 32; + break; + case GGML_TYPE_Q8_0: + dequant = true; + local = 32; + size_qb = global * (sizeof(float) + local) / 32; + break; + default: + fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype); + abort(); + } + + const size_t size_a = m * k * sizeof(float); + const size_t size_b = n * k * sizeof(float); + const size_t size_c = m * n * sizeof(float); + + // Prepare buffers + ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a); + if (dequant) { + ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb); + } + ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b); + ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c); + + cl_event ev_a, ev_qb, ev_b; + + if (dequant) { + err = clSetKernelArg(*kernel, 0, sizeof(cl_mem), &cl_buffer_qb); + err |= clSetKernelArg(*kernel, 1, sizeof(cl_mem), &cl_buffer_b); + CL_CHECK(err, "clSetKernelArg"); + err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb); + CL_CHECK(err, "clEnqueueWriteBuffer qb"); + } else { + err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b); + CL_CHECK(err, "clEnqueueWriteBuffer b"); + } + + err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a); + CL_CHECK(err, "clEnqueueWriteBuffer a"); + if (dequant) { + err = clEnqueueNDRangeKernel(queue, *kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b); + CL_CHECK(err, "clEnqueueNDRangeKernel"); + clReleaseEvent(ev_qb); + } + clWaitForEvents(1, &ev_a); + clWaitForEvents(1, &ev_b); + clReleaseEvent(ev_a); + clReleaseEvent(ev_b); + + cl_event ev_sgemm; + CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order, + (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, &ev_sgemm); + + if (status != CLBlastSuccess) { + fprintf(stderr, "Error: CLBlast SGEMM %d\n", status); + abort(); + } + + cl_event ev_c; + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c); + + // Wait for completion + clWaitForEvents(1, &ev_c); + clReleaseEvent(ev_sgemm); + clReleaseEvent(ev_c); +} \ No newline at end of file diff --git a/ggml-opencl.h b/ggml-opencl.h index 5a1a50093..e0c1d6957 100644 --- a/ggml-opencl.h +++ b/ggml-opencl.h @@ -6,6 +6,17 @@ extern "C" { #endif +enum ggml_blas_order { + GGML_BLAS_ORDER_ROW_MAJOR = 101, + GGML_BLAS_ORDER_COLUMN_MAJOR = 102, +}; + +enum ggml_blas_op { + GGML_BLAS_OP_N = 111, + GGML_BLAS_OP_T = 112, + GGML_BLAS_OP_C = 113, +}; + void ggml_cl_init(void); bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); @@ -17,6 +28,8 @@ void ggml_cl_host_free(void * ptr); void ggml_cl_transform_tensor(struct ggml_tensor * tensor); +void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); + #ifdef __cplusplus } #endif