Fix dequant_mul_mat kernel

This commit is contained in:
0cc4m 2023-05-14 21:26:07 +02:00
parent 8795403de3
commit 883e587a04

View file

@ -155,9 +155,9 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float*
} }
__kernel void dequantize_mul_mat_vec(__global struct block_q4_0* x, __local float* tmp, __global float* y, __global float* dst, int ncols) { __kernel void dequantize_mul_mat_vec(__global struct block_q4_0* x, __local float* tmp, __global float* y, __global float* dst, int ncols) {
const int row = get_local_id(0);
const int tid = get_global_id(0);
const int block_size = get_local_size(0); const int block_size = get_local_size(0);
const int row = get_global_id(0) / block_size;
const int tid = get_local_id(0);
const uint qk = QK4_0; const uint qk = QK4_0;
const uint qr = QR4_0; const uint qr = QR4_0;
@ -666,7 +666,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
// compute // compute
// dequantize_mul_mat_vec(__global void * vx, __local float* tmp, __global float * y, __global float * dst, __global int ncols, __global int vx_type) { // dequantize_mul_mat_vec(__global void * vx, __local float* tmp, __global float * y, __global float * dst, __global int ncols, __global int vx_type) {
const size_t global = ne01; const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
const size_t local = CL_DMMV_BLOCK_SIZE; const size_t local = CL_DMMV_BLOCK_SIZE;
const cl_int ncols = ne00; const cl_int ncols = ne00;
CL_CHECK(clSetKernelArg(dequantize_mul_mat_vec_cl, 0, sizeof(cl_mem), &d_Q), "clSetKernelArg"); CL_CHECK(clSetKernelArg(dequantize_mul_mat_vec_cl, 0, sizeof(cl_mem), &d_Q), "clSetKernelArg");