Fix q2_k, improve code
This commit is contained in:
parent
6e20827f93
commit
fc8c823f34
1 changed files with 17 additions and 26 deletions
|
@ -15,7 +15,7 @@
|
|||
|
||||
#include "ggml.h"
|
||||
|
||||
#define CL_DMMV_BLOCK_SIZE 32;
|
||||
#define CL_DMMV_BLOCK_SIZE 32
|
||||
|
||||
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
||||
static std::string program_source = MULTILINE_QUOTE(
|
||||
|
@ -349,7 +349,6 @@ void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int i
|
|||
const uint32_t kmask1 = 0x03030303;
|
||||
const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
|
||||
uint32_t aux[3];
|
||||
uint32_t utmp[4];
|
||||
|
||||
int n = iqs/128;
|
||||
|
@ -361,18 +360,7 @@ void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int i
|
|||
__global const uint8_t * hm = x[ib].hmask + l;
|
||||
const int8_t * s = (const int8_t *)utmp + 8*n;
|
||||
|
||||
aux[0] |= x[ib].scales[0];
|
||||
aux[0] |= x[ib].scales[1] << 8;
|
||||
aux[0] |= x[ib].scales[2] << 16;
|
||||
aux[0] |= x[ib].scales[3] << 24;
|
||||
aux[1] |= x[ib].scales[4];
|
||||
aux[1] |= x[ib].scales[5] << 8;
|
||||
aux[1] |= x[ib].scales[6] << 16;
|
||||
aux[1] |= x[ib].scales[7] << 24;
|
||||
aux[2] |= x[ib].scales[8];
|
||||
aux[2] |= x[ib].scales[9] << 8;
|
||||
aux[2] |= x[ib].scales[10] << 16;
|
||||
aux[2] |= x[ib].scales[11] << 24;
|
||||
__global const uint32_t* aux = (__global const uint32_t*) x[ib].scales;
|
||||
|
||||
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
|
||||
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
|
||||
|
@ -470,7 +458,7 @@ void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int i
|
|||
const float d = vload_half(0, &x[ib].d);
|
||||
|
||||
__global const uint8_t * ql = x[ib].ql + 64*ip + il;
|
||||
const uint8_t * qh = x[ib].qh + 32*ip + il;
|
||||
__global const uint8_t * qh = x[ib].qh + 32*ip + il;
|
||||
__global const int8_t * sc = x[ib].scales + is;
|
||||
|
||||
*result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
|
||||
|
@ -514,7 +502,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
|
|||
std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
|
||||
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
|
||||
const int block_size = get_local_size(0);
|
||||
const int row = get_global_id(0) / block_size;
|
||||
const int row = get_group_id(0);
|
||||
const int tid = get_local_id(0);
|
||||
|
||||
const uint qk = QUANT_K;
|
||||
|
@ -556,11 +544,11 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
|
|||
std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
|
||||
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
|
||||
const int block_size = get_local_size(0);
|
||||
const int row = get_global_id(0) / block_size;
|
||||
const int row = get_group_id(0);
|
||||
const int tid = get_local_id(0);
|
||||
|
||||
const int iter_stride = 256;
|
||||
const int vals_per_iter = iter_stride;
|
||||
const int vals_per_iter = iter_stride / block_size;
|
||||
const int num_blocks_per_row = ncols / 256;
|
||||
const int ib0 = row*num_blocks_per_row;
|
||||
|
||||
|
@ -574,7 +562,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
|
|||
|
||||
// dequantize
|
||||
float v;
|
||||
dot_kernel(x, ib, iqs, y + iybs, &v);
|
||||
DOT_KERNEL(x, ib, iqs, y + iybs, &v);
|
||||
tmp[tid] += v;
|
||||
}
|
||||
|
||||
|
@ -653,6 +641,10 @@ std::array<std::string, 2> mul_str_values = {
|
|||
"mul_f32", "float"
|
||||
};
|
||||
|
||||
std::array<std::string, 3> dmmv_k_str_keys = {
|
||||
"KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
|
||||
};
|
||||
|
||||
std::array<std::string, 15> dmmv_k_str_values = {
|
||||
"dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
|
||||
"dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
|
||||
|
@ -690,13 +682,12 @@ std::string generate_kernels() {
|
|||
}
|
||||
src << mul_kernel << '\n';
|
||||
}
|
||||
for (size_t i = 0; i < dmmv_k_str_values.size(); i += 3) {
|
||||
std::string dmmv_kernel = dequant_mul_mat_vec_k_template;
|
||||
//just apply quick template fn name replacement for the K quant DMMVs since sizes are known
|
||||
replace(dmmv_kernel, "KERNEL_NAME", dmmv_k_str_values[i]);
|
||||
replace(dmmv_kernel, "X_TYPE", dmmv_k_str_values[i + 1]);
|
||||
replace(dmmv_kernel, "dot_kernel", dmmv_k_str_values[i + 2]);
|
||||
src << dmmv_kernel << '\n';
|
||||
for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
|
||||
std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
|
||||
for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
|
||||
replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
|
||||
}
|
||||
src << dmmv_k_kernel << '\n';
|
||||
}
|
||||
|
||||
return src.str();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue