From 8c6bd319dbaedc71095b22487cc3f39478ff9aed Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Wed, 7 Jun 2023 11:47:52 +0200 Subject: [PATCH] Fixed CUDA RoPE --- ggml-cuda.cu | 5 ++--- llama.cpp | 8 ++++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3b9a5ddfb..dc0a48bfb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1783,7 +1783,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) { - ggml_cuda_assign_buffers(tensor); + ggml_cuda_assign_buffers(tensor->src0); } const size_t size = ggml_nbytes(tensor); @@ -1800,8 +1800,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { CUDA_CHECK(cudaSetDevice(g_main_device)); if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) { struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra; - extra->data_device[g_main_device] = src0_extra->data_device; - GGML_ASSERT(false); + extra->data_device[g_main_device] = src0_extra->data_device[g_main_device]; } else { char * data = (char *) g_scratch_buffer; if (data == nullptr) { diff --git a/llama.cpp b/llama.cpp index c7a333642..9249a48d5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1366,17 +1366,21 @@ static bool llama_eval_internal( { // compute Q and K and RoPE them struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - // offload_func(tmpq); + offload_func(tmpq); ggml_set_name(tmpq, "tmpq"); struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - // offload_func(tmpk); + offload_func(tmpk); ggml_set_name(tmpk, "tmpk"); struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0); + offload_func(Kcur); + Kcur->backend = GGML_BACKEND_CPU; ggml_set_name(Kcur, "Kcur"); struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0); + offload_func(Qcur); + Qcur->backend = GGML_BACKEND_CPU; ggml_set_name(Qcur, "Qcur"); // store key and value to memory