KV cache v works, perf. bad, # after 64 tokens

This commit is contained in:
JohannesGaessler 2023-06-12 10:07:27 +02:00
parent 9a85d913ee
commit cf5ae8635a
2 changed files with 169 additions and 9 deletions

View file

@ -761,7 +761,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
const int col_x = col_x0 + threadIdx.x;
if (col_x >= ncols_x) {
return;
break;
}
// x is transposed and permuted
@ -792,6 +792,52 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
}
}
static __global__ void mul_mat_vec_nc_f16_f32(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
const half * x = (half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
const int idst = channel*nrows_dst + row_dst;
float tmp = 0.0f;
for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
const int col_x = col_x0 + threadIdx.x;
if (col_x >= ncols_x) {
break;
}
const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
const float xi = __half2float(x[ix]);
const int row_y = col_x;
const int iy = channel*nrows_y + row_y;
tmp += xi * y[iy];
}
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (threadIdx.x == 0) {
dst[idst] = tmp;
}
}
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
@ -1085,12 +1131,21 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int ncols_y, cudaStream_t stream) {
const int block_num_x = (ncols_x + WARP_SIZE - 1) / WARP_SIZE;
const dim3 block_nums(block_num_x, nrows_x, nchannels_x);
const dim3 block_nums(1, nrows_x, nchannels_x);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, ncols_y);
}
static void ggml_mul_mat_vec_nc_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
const dim3 block_nums(1, nrows_x, nchannels_x);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
(vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
}
static void ggml_cpy_f32_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@ -2056,10 +2111,108 @@ void ggml_cuda_mul_mat_p021_f16_f32(const ggml_tensor * src0, const ggml_tensor
}
}
void ggml_cuda_mul_mat_vec_nc_f16_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
// GGML_ASSERT(src0->backend == GGML_BACKEND_GPU && src1->backend == GGML_BACKEND_GPU);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
GGML_ASSERT(ne03 == 1);
const size_t src0_size = ggml_nbytes(src0);
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(ne12 == ne02);
const size_t src1_size = ggml_nbytes(src1);
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t ne3 = dst->ne[3];
GGML_ASSERT(ne2 == ne02);
const size_t dst_size = ggml_nbytes(dst);
const int64_t nb1 = dst->nb[1];
const int64_t nb2 = dst->nb[2];
const int64_t nb3 = dst->nb[3];
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
void * src0_ddq;
float * src1_ddf;
float * dst_ddf;
if (src0->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaMalloc(&src0_ddq, src0_size));
CUDA_CHECK(cudaMemcpyAsync(src0_ddq, src0->data, src0_size, cudaMemcpyHostToDevice, cudaStream_main));
} else {
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
src0_ddq = src0_extra->data_device[g_main_device];
// CUDA_CHECK(cudaMemset(src0_ddq, 0, ggml_nbytes(src0)));
}
if (src1->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaMalloc(&src1_ddf, src1_size));
CUDA_CHECK(cudaMemcpyAsync(src1_ddf, src1->data, src1_size, cudaMemcpyHostToDevice, cudaStream_main));
} else {
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
src1_ddf = (float *) src1_extra->data_device[g_main_device];
// CUDA_CHECK(cudaMemset(src1_ddf, 0, ggml_nbytes(src1)));
}
if (dst->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaMalloc(&dst_ddf, dst_size));
} else {
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
dst_ddf = (float *) dst_extra->data_device[g_main_device];
}
const int row_stride_x = nb01 / sizeof(half);
const int channel_stride_x = nb02 / sizeof(half);
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
CUDA_CHECK(cudaDeviceSynchronize());
if (dst->backend == GGML_BACKEND_CPU) {
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i = i3*ne2*ne1 + i2*ne1 + i1;
float * dst_ddf_i = dst_ddf + i*ne0;
float * dhf_dst_i = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, ne0*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main));
}
}
}
}
CUDA_CHECK(cudaDeviceSynchronize());
if (src0->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaFree(src0_ddq));
}
if (src1->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaFree(src1_ddf));
}
if (src1->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaFree(dst_ddf));
}
}
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
ggml_cuda_mul_mat_p021_f16_f32(src0, src1, dst);
} else if (src0->type == GGML_TYPE_F32) {
} else if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
ggml_cuda_mul_mat_vec_nc_f16_f32(src0, src1, dst);
}else if (src0->type == GGML_TYPE_F32) {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
@ -2080,8 +2233,6 @@ void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_te
void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

View file

@ -895,7 +895,7 @@ static bool kv_cache_init(
#ifdef GGML_USE_CUBLAS
ggml_cuda_assign_buffers_no_scratch(cache.k);
// ggml_cuda_assign_buffers_no_scratch(cache.v);
ggml_cuda_assign_buffers_no_scratch(cache.v);
#endif // GGML_USE_CUBLAS
return true;
@ -1391,15 +1391,23 @@ static bool llama_eval_internal(
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func(tmpv);
ggml_set_name(tmpv, "tmpv");
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N));
offload_func(Vcur);
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
offload_func(k);
ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
offload_func(v);
ggml_set_name(v, "v");
// important: storing RoPE-ed version of K in the KV cache!
@ -1444,6 +1452,7 @@ static bool llama_eval_internal(
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
offload_func(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
// split cached V into n_head heads
@ -1453,7 +1462,7 @@ static bool llama_eval_internal(
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
// offload_func(V);
offload_func(V);
ggml_set_name(V, "V");
#if 1