This commit is contained in:
Luca 2024-06-26 15:38:32 +02:00 committed by GitHub
commit f0f9a1244c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -217,10 +217,21 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
float4 y11 = *reinterpret_cast<const float4*>(y1+0);
float4 y12 = *reinterpret_cast<const float4*>(y1+32);
float4 y21 = *reinterpret_cast<const float4*>(y2+0);
float4 y22 = *reinterpret_cast<const float4*>(y2+32);
const float* p11 = &y11.x;
const float* p12 = &y12.x;
const float* p21 = &y21.x;
const float* p22 = &y22.x;
for (int l = 0; l < 4; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
s.x += p11[l] * q4[l+0]; s.y += p12[l] * q4[l+ 4];
s.z += p21[l] * q4[l+8]; s.w += p22[l] * q4[l+12];
smin += p11[l] * sc[2] + p12[l] * sc[3] + p21[l] * sc[6] + p22[l] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#else
@ -563,12 +574,15 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
// Thread-block shape parameters used by the q4_K mat-vec launcher below.
// BLOCK_DIM_X is the block width in threads; BLOCK_DIM_Y scales the block height.
#define BLOCK_DIM_X 32
#define BLOCK_DIM_Y 4

// Host-side launcher for the dequantize_mul_mat_vec_q4_k kernel.
//
// Launch geometry, exactly as computed below:
//   block : BLOCK_DIM_X x ny threads, ny = 2*BLOCK_DIM_Y / K_QUANTS_PER_ITERATION
//   grid  : ceil(ceil(nrows / ny) / grid_scale) blocks along x, grid_scale = BLOCK_DIM_X/32
// No dynamic shared memory is requested and the kernel is enqueued on `stream`.
//
// Parameters (forwarded unchanged to the kernel):
//   vx     - opaque pointer to the first operand
//            (presumably the q4_K-quantized matrix; confirm against the kernel)
//   y      - float input (presumably the dense vector; confirm against the kernel)
//   dst    - float output buffer
//   ncols  - column count; must be a multiple of QK_K (asserted)
//   nrows  - row count; determines how many blocks are launched
//   stream - CUDA stream the kernel is launched on
//
// NOTE(review): the grid sizing suggests each threadIdx.y slice handles one row,
// so a block covers ny rows; verify against the kernel's row indexing.
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);

    // grid_scale is an integer quotient: anything below 32 would make it 0 and
    // turn the grid-size computation into a division by zero.
    static_assert(BLOCK_DIM_X >= 32 && BLOCK_DIM_X % 32 == 0, "BLOCK_DIM_X must be a positive multiple of 32");

    const int ny = 2*BLOCK_DIM_Y / K_QUANTS_PER_ITERATION;
    constexpr int grid_scale = BLOCK_DIM_X/32;

    // Ceil-divide so a trailing partial group of rows still gets a block.
    const int block_num_y = (nrows + ny - 1) / ny;

    const dim3 block_nums((block_num_y+grid_scale-1)/grid_scale, 1, 1);
    const dim3 block_dims(BLOCK_DIM_X, ny, 1);
    dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}