From 488c1fc778313b4e27976170da277657c794ab5d Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 20 Sep 2023 00:21:09 +0200 Subject: [PATCH] ggml-cuda : add rope f16, restore performance --- ggml-cuda.cu | 129 +++++++++++++++++++++++++++++++-------------------- ggml-cuda.h | 1 + ggml.c | 2 +- llama.cpp | 10 ++-- 4 files changed, 87 insertions(+), 55 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 14b1ecf7d..20d00a487 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -439,7 +439,6 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs - bool copied; }; // this is faster on Windows @@ -4357,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, // rope == RoPE == rotary positional embedding -static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale) { +template +static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4369,7 +4369,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c const int i = row*ncols + col; const int i2 = row/p_delta_rows; - const int p = pos != nullptr ? pos[i2] : 0; + const int p = has_pos ? pos[i2] : 0; const float p0 = p * freq_scale; const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); @@ -4382,8 +4382,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale) { +template +static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (col >= ncols) { @@ -4394,7 +4395,7 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco const int i = row*ncols + col/2; const int i2 = row/p_delta_rows; - const int p = pos != nullptr ? pos[i2] : 0; + const int p = has_pos ? pos[i2] : 0; const float p0 = p * freq_scale; const float theta = p0*powf(theta_scale, col/2); const float sin_theta = sinf(theta); @@ -5371,22 +5372,32 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } -static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, +template +static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + if (pos == nullptr) { + rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } else { + rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } } -static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, +template +static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - rope_neox_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + if (pos == nullptr) { + rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } else { + rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } } static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, @@ -6036,7 +6047,7 @@ inline void ggml_cuda_op_mul_mat_cublas( const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - float * src0_ddq_as_f32; + float * src0_ddq_as_f32 = nullptr; size_t src0_as = 0; if (src0->type != GGML_TYPE_F32) { @@ -6074,8 +6085,9 @@ inline void ggml_cuda_op_rope( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -6093,23 +6105,16 @@ inline void ggml_cuda_op_rope( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); const float theta_scale = powf(freq_base, -2.0f/n_dims); - // const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - int * pos = nullptr; + int32_t * pos = nullptr; if ((mode & 1) == 0) { + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - pos = (int *) src1_extra->data_device[id]; - if (!src1_extra->copied) { - CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream)); - src1_extra->copied = true; - } + int id; + CUDA_CHECK(cudaGetDevice(&id)); + pos = (int32_t *) src1_extra->data_device[id]; } const bool is_neox = mode & 2; @@ -6121,9 +6126,21 @@ inline void ggml_cuda_op_rope( rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else { + GGML_ASSERT(false); + } } else { - rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else { + GGML_ASSERT(false); + } } (void) src1; @@ -6294,7 +6311,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s } } -void ggml_cuda_set_peer_access(const int n_tokens) { +static void ggml_cuda_set_peer_access(const int n_tokens) { static bool peer_access_enabled = false; const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; @@ -6622,27 +6639,27 @@ static void ggml_cuda_op_mul_mat( } } -void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); } -void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); } -void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); } -void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); } -void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); } -void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); } @@ -6663,7 +6680,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te return false; } -void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation @@ -6692,7 +6709,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } -void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); @@ -6726,7 +6743,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } -void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; @@ -6770,11 +6787,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ } } -void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } -void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -6822,29 +6839,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens (void) dst; } -void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_cpy(src0, dst, nullptr); (void) src1; } -void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf); } -void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max); } -void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope); } -void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } -void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; (void) dst; @@ -6967,11 +6984,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { return extra; } -void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { +static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { if (scratch && g_scratch_size == 0) { return; } + tensor->backend = GGML_BACKEND_GPU; + // recursively assign CUDA buffers until a compute tensor is found if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) { const ggml_op src0_op = tensor->src[0]->op; @@ -6983,8 +7002,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc); } - tensor->backend = GGML_BACKEND_GPU; - if (scratch && no_alloc) { return; } @@ -7069,6 +7086,16 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) tensor->extra = extra; } +void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + CUDA_CHECK(cudaMemcpyAsync(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice, main_stream)); +} + void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { ggml_cuda_assign_buffers_impl(tensor, true, false, false); } diff --git a/ggml-cuda.h b/ggml-cuda.h index a72e82069..fda704b66 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -31,6 +31,7 @@ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor); GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset); +GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor); GGML_API void ggml_cuda_set_main_device(int main_device); GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q); diff --git a/ggml.c b/ggml.c index 207561794..764865939 100644 --- a/ggml.c +++ b/ggml.c @@ -6343,7 +6343,7 @@ static struct ggml_tensor * ggml_cpy_impl( } // make a view of the destination - struct ggml_tensor * result = ggml_view_tensor(ctx, b); + struct ggml_tensor * result = b->op == GGML_OP_NONE ? b : ggml_view_tensor(ctx, b); if (strlen(b->name) > 0) { ggml_format_name(result, "%s (copy of %s)", b->name, a->name); } else { diff --git a/llama.cpp b/llama.cpp index 12b8c49d0..61dbf5ca2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1256,10 +1256,10 @@ static bool llama_kv_cache_init( (void) n_gpu_layers; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer + 1) { + if (n_gpu_layers > (int)n_layer + 1) { ggml_cuda_assign_buffers_no_scratch(cache.v); } - if (n_gpu_layers > n_layer + 2) { + if (n_gpu_layers > (int)n_layer + 2) { ggml_cuda_assign_buffers_no_scratch(cache.k); } #endif // GGML_USE_CUBLAS @@ -2619,7 +2619,7 @@ static struct ggml_cgraph * llm_build_llama( const int n_gpu_layers = model.n_gpu_layers; const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : std::max(1, (int)kv_self.cell_max); const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; @@ -2700,6 +2700,7 @@ static struct ggml_cgraph * llm_build_llama( // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); ggml_allocr_alloc(lctx.alloc, KQ_mask); if (!ggml_allocr_is_measure(lctx.alloc)) { float * data = (float *) KQ_mask->data; @@ -2722,6 +2723,7 @@ static struct ggml_cgraph * llm_build_llama( // KQ_pos - contains the positions struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); ggml_allocr_alloc(lctx.alloc, KQ_pos); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) KQ_pos->data; @@ -2734,6 +2736,7 @@ static struct ggml_cgraph * llm_build_llama( if (do_rope_shift) { struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); ggml_allocr_alloc(lctx.alloc, K_shift); if (!ggml_allocr_is_measure(lctx.alloc)) { int * data = (int *) K_shift->data; @@ -4116,6 +4119,7 @@ static int llama_decode_internal( ggml_tensor * node = gf->leafs[i]; if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + ggml_cuda_copy_to_device(node); } }