KV cache v works, perf. bad, # after 64 tokens
parent 9a85d913ee
commit cf5ae8635a

2 changed files with 169 additions and 9 deletions
ggml-cuda.cu (163 lines changed)
@@ -761,7 +761,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
         const int col_x = col_x0 + threadIdx.x;
 
         if (col_x >= ncols_x) {
-            return;
+            break;
         }
 
         // x is transposed and permuted
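Note on the return/break change above (an inference, not stated in the commit): further down, the launch of mul_mat_p021_f16_f32 is reduced to a single block along x, so one warp now loops over all columns of a row. Threads that run out of columns on the last iteration must stay in the kernel, because the __syncthreads()/__shfl_xor_sync() reduction at the end expects every lane of the warp to participate; an early return there would leave the full-mask shuffle waiting on exited lanes. A minimal, hypothetical sketch of the same pattern (assumes a single 32-thread block; not part of the commit):

__global__ void dot_one_row(const float * x, const float * y, float * out, const int n) {
    float tmp = 0.0f;

    for (int i0 = 0; i0 < n; i0 += blockDim.x) {
        const int i = i0 + threadIdx.x;
        if (i >= n) {
            break; // keep the thread alive so it still reaches the reduction below
        }
        tmp += x[i] * y[i];
    }

    // full-warp reduction: the 0xffffffff mask assumes all 32 lanes are still active,
    // which an early return above would violate
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (threadIdx.x == 0) {
        *out = tmp;
    }
}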
@@ -792,6 +792,52 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
     }
 }
 
+static __global__ void mul_mat_vec_nc_f16_f32(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
+
+    const half * x = (half *) vx;
+
+    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
+
+    const int idst = channel*nrows_dst + row_dst;
+
+    float tmp = 0.0f;
+
+    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+        const int col_x = col_x0 + threadIdx.x;
+
+        if (col_x >= ncols_x) {
+            break;
+        }
+
+        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+        const float xi = __half2float(x[ix]);
+
+        const int row_y = col_x;
+
+        const int iy = channel*nrows_y + row_y;
+
+        tmp += xi * y[iy];
+    }
+
+    // sum up partial sums and write back result
+    __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[idst] = tmp;
+    }
+}
+
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
     const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
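Note: per channel, the kernel above computes dst[row] = sum over col of x[row, col] * y[col], with x addressed through explicit element strides so the fp16 source never has to be made contiguous. A hypothetical host-side reference with the same indexing, useful only as a reading aid or correctness check (not part of the commit):

#include <cuda_fp16.h>

// CPU reference of mul_mat_vec_nc_f16_f32: same strides, same output layout,
// no parallelism. Strides are given in half elements, not bytes.
static void mul_mat_vec_nc_f16_f32_ref(
        const half * x, const float * y, float * dst, const int ncols_x, const int nrows_x,
        const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
    for (int channel = 0; channel < nchannels_x; ++channel) {
        for (int row_x = 0; row_x < nrows_x; ++row_x) {
            float tmp = 0.0f;

            for (int col_x = 0; col_x < ncols_x; ++col_x) {
                const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
                const int iy = channel*ncols_x + col_x; // nrows_y == ncols_x, y is contiguous
                tmp += __half2float(x[ix]) * y[iy];
            }

            dst[channel*nrows_x + row_x] = tmp; // nrows_dst == nrows_x, dst is contiguous
        }
    }
}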
@@ -1085,12 +1131,21 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }
 
 static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y, cudaStream_t stream) {
-    const int block_num_x = (ncols_x + WARP_SIZE - 1) / WARP_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, nchannels_x);
+    const dim3 block_nums(1, nrows_x, nchannels_x);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, ncols_y);
 }
 
+static void ggml_mul_mat_vec_nc_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
+    const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
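Note: with block_nums = dim3(1, nrows_x, nchannels_x) and block_dims = dim3(WARP_SIZE, 1, 1), every output element dst[channel*nrows_x + row] gets exactly one warp, which strides over the columns and reduces with the shuffle loop, so no inter-block reduction is needed. A hypothetical back-of-the-envelope check of the launch size, assuming a 7B-class attention shape (head dimension 128, 32 heads; these numbers are assumptions, not from the commit):

#include <cstdio>

int main() {
    const int nrows_x     = 128; // assumed head dimension (n_embd/n_head)
    const int nchannels_x = 32;  // assumed number of heads
    const int warp_size   = 32;  // WARP_SIZE in ggml-cuda.cu

    const int blocks  = 1 * nrows_x * nchannels_x;  // dim3(1, nrows_x, nchannels_x)
    const int threads = blocks * warp_size;         // one warp per output element

    printf("%d blocks, %d threads total\n", blocks, threads); // 4096 blocks, 131072 threads
    return 0;
}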
@@ -2056,10 +2111,108 @@ void ggml_cuda_mul_mat_p021_f16_f32(const ggml_tensor * src0, const ggml_tensor
     }
 }
 
+void ggml_cuda_mul_mat_vec_nc_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_permuted(src0));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+    GGML_ASSERT(ne03 == 1);
+    const size_t src0_size = ggml_nbytes(src0);
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    GGML_ASSERT(ne13 == 1);
+    GGML_ASSERT(ne12 == ne02);
+    const size_t src1_size = ggml_nbytes(src1);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+    GGML_ASSERT(ne2 == ne02);
+    const size_t dst_size = ggml_nbytes(dst);
+
+    const int64_t nb1 = dst->nb[1];
+    const int64_t nb2 = dst->nb[2];
+    const int64_t nb3 = dst->nb[3];
+
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+
+    void * src0_ddq;
+    float * src1_ddf;
+    float * dst_ddf;
+    if (src0->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size));
+        CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main));
+    } else {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+        src0_ddq = src0_extra->data_device[g_main_device];
+        // CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0)));
+    }
+    if (src1->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size));
+        CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main));
+    } else {
+        struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+        src1_ddf = (float *) src1_extra->data_device[g_main_device];
+        // CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1)));
+    }
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size));
+    } else {
+        struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+        dst_ddf = (float *) dst_extra->data_device[g_main_device];
+    }
+
+    const int row_stride_x = nb01 / sizeof(half);
+    const int channel_stride_x = nb02 / sizeof(half);
+
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            for (int64_t i2 = 0; i2 < ne2; i2++) {
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const int64_t i = i3*ne2*ne1 + i2*ne1 + i1;
+                    float * dst_ddf_i = dst_ddf + i*ne0;
+                    float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+                    CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main));
+                }
+            }
+        }
+    }
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+    if (src0->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaFree(src0_ddq));
+    }
+    if (src1->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaFree(src1_ddf));
+    }
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaFree(dst_ddf));
+    }
+}
+
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_p021_f16_f32(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
+    } else if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+        ggml_cuda_mul_mat_vec_nc_f16_f32(src0, src1, dst);
+    } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
@@ -2080,8 +2233,6 @@ void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_te
 void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
 
     GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
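Note: the host wrapper above hands the kernel strides in elements rather than bytes, i.e. row_stride_x = nb01 / sizeof(half) and channel_stride_x = nb02 / sizeof(half). A hypothetical worked example for a non-contiguous F16 V-cache view like the one built in llama.cpp below (context length and head dimension are assumptions for illustration):

#include <cstdio>
#include <cstdint>

int main() {
    const int64_t n_ctx    = 2048;           // assumed context length
    const int64_t head_dim = 128;            // assumed n_embd/n_head
    const int64_t ts       = 2;              // sizeof(half), ggml F16 element size

    const int64_t nb01 = n_ctx*ts;           // bytes between consecutive rows of the view
    const int64_t nb02 = n_ctx*ts*head_dim;  // bytes between consecutive heads (channels)

    const int row_stride_x     = (int)(nb01/ts); // 2048 half elements
    const int channel_stride_x = (int)(nb02/ts); // 262144 half elements

    printf("row_stride_x = %d, channel_stride_x = %d\n", row_stride_x, channel_stride_x);
    return 0;
}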
llama.cpp (15 lines changed)
@@ -895,7 +895,7 @@ static bool kv_cache_init(
 
 #ifdef GGML_USE_CUBLAS
     ggml_cuda_assign_buffers_no_scratch(cache.k);
-    // ggml_cuda_assign_buffers_no_scratch(cache.v);
+    ggml_cuda_assign_buffers_no_scratch(cache.v);
 #endif // GGML_USE_CUBLAS
 
     return true;
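Note: the only change here is that cache.v now gets a VRAM buffer as well (cache.k already had one), which is what allows the V * softmax(KQ) product to stay on the GPU. A rough, hypothetical estimate of the extra VRAM this costs, assuming a 7B-class model with an F16 cache (all sizes are assumptions, not from the commit):

#include <cstdio>
#include <cstdint>

int main() {
    const int64_t n_layer = 32;              // assumed number of layers
    const int64_t n_embd  = 4096;            // assumed embedding size
    const int64_t n_ctx   = 2048;            // assumed context length
    const int64_t elem    = 2;               // bytes per F16 element

    const int64_t v_cache_bytes = n_layer*n_ctx*n_embd*elem;

    printf("V cache: %lld MiB\n", (long long)(v_cache_bytes/(1024*1024))); // 512 MiB
    return 0;
}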
@@ -1391,15 +1391,23 @@ static bool llama_eval_internal(
         // store key and value to memory
         {
             // compute the transposed [N, n_embd] V matrix
-            struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+
+            struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+            offload_func(tmpv);
+            ggml_set_name(tmpv, "tmpv");
+
+            struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N));
+            offload_func(Vcur);
             ggml_set_name(Vcur, "Vcur");
 
             struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
             offload_func(k);
             ggml_set_name(k, "k");
 
             struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
                     (   n_ctx)*ggml_element_size(kv_self.v),
                     (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+            offload_func(v);
             ggml_set_name(v, "v");
 
             // important: storing RoPE-ed version of K in the KV cache!
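Note: the view arithmetic above stores V transposed in the cache: row i of the 2-D view is embedding dimension i, laid out along the token axis with a stride of n_ctx elements, so each newly processed token writes one column starting at offset n_past. A hypothetical worked example of the resulting byte offsets (all sizes assumed for illustration):

#include <cstdio>
#include <cstdint>

int main() {
    const int64_t n_ctx  = 2048;             // assumed context length
    const int64_t n_embd = 4096;             // assumed embedding size
    const int64_t il     = 10;               // assumed layer index
    const int64_t n_past = 64;               // assumed tokens already in the cache
    const int64_t ts     = 2;                // ggml_element_size(kv_self.v) for F16

    // ggml_view_2d(ctx0, kv_self.v, N, n_embd, nb1, offset) as in the diff:
    const int64_t nb1    = n_ctx*ts;                         // stride between embedding rows
    const int64_t offset = il*n_ctx*ts*n_embd + n_past*ts;   // this layer's slab + column n_past

    printf("nb1 = %lld bytes, offset = %lld bytes\n", (long long) nb1, (long long) offset);
    return 0;
}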
@@ -1444,6 +1452,7 @@ static bool llama_eval_internal(
 
         // KQ = soft_max(KQ_masked)
         struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+        offload_func(KQ_soft_max);
         ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
         // split cached V into n_head heads
@@ -1453,7 +1462,7 @@ static bool llama_eval_internal(
                     n_ctx*ggml_element_size(kv_self.v),
                     n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
                     il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
-        // offload_func(V);
+        offload_func(V);
         ggml_set_name(V, "V");
 
 #if 1