Merge 50732c0698
into 8854044561
This commit is contained in:
commit
f0f9a1244c
1 changed file with 20 additions and 6 deletions
|
@ -217,10 +217,21 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
|
||||||
|
|
||||||
float4 s = {0.f, 0.f, 0.f, 0.f};
|
float4 s = {0.f, 0.f, 0.f, 0.f};
|
||||||
float smin = 0;
|
float smin = 0;
|
||||||
|
|
||||||
|
float4 y11 = *reinterpret_cast<const float4*>(y1+0);
|
||||||
|
float4 y12 = *reinterpret_cast<const float4*>(y1+32);
|
||||||
|
float4 y21 = *reinterpret_cast<const float4*>(y2+0);
|
||||||
|
float4 y22 = *reinterpret_cast<const float4*>(y2+32);
|
||||||
|
|
||||||
|
const float* p11 = &y11.x;
|
||||||
|
const float* p12 = &y12.x;
|
||||||
|
const float* p21 = &y21.x;
|
||||||
|
const float* p22 = &y22.x;
|
||||||
|
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
|
s.x += p11[l] * q4[l+0]; s.y += p12[l] * q4[l+ 4];
|
||||||
s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
|
s.z += p21[l] * q4[l+8]; s.w += p22[l] * q4[l+12];
|
||||||
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
|
smin += p11[l] * sc[2] + p12[l] * sc[3] + p21[l] * sc[6] + p22[l] * sc[7];
|
||||||
}
|
}
|
||||||
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
|
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
|
||||||
#else
|
#else
|
||||||
|
@ -563,12 +574,15 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
|
||||||
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define BLOCK_DIM_X 32
|
||||||
|
#define BLOCK_DIM_Y 4
|
||||||
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
GGML_ASSERT(ncols % QK_K == 0);
|
||||||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
const int ny = 2*BLOCK_DIM_Y / K_QUANTS_PER_ITERATION;
|
||||||
|
constexpr int grid_scale = BLOCK_DIM_X/32;
|
||||||
const int block_num_y = (nrows + ny - 1) / ny;
|
const int block_num_y = (nrows + ny - 1) / ny;
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
const dim3 block_nums((block_num_y+grid_scale-1)/grid_scale, 1, 1);
|
||||||
const dim3 block_dims(32, ny, 1);
|
const dim3 block_dims(BLOCK_DIM_X, ny, 1);
|
||||||
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue