Fixed CUDA RoPE
This commit is contained in:
parent
74d4cfa343
commit
8c6bd319db
2 changed files with 8 additions and 5 deletions
|
@ -1783,7 +1783,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
|
|||
|
||||
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
|
||||
if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) {
|
||||
ggml_cuda_assign_buffers(tensor);
|
||||
ggml_cuda_assign_buffers(tensor->src0);
|
||||
}
|
||||
|
||||
const size_t size = ggml_nbytes(tensor);
|
||||
|
@ -1800,8 +1800,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
|
|||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
|
||||
extra->data_device[g_main_device] = src0_extra->data_device;
|
||||
GGML_ASSERT(false);
|
||||
extra->data_device[g_main_device] = src0_extra->data_device[g_main_device];
|
||||
} else {
|
||||
char * data = (char *) g_scratch_buffer;
|
||||
if (data == nullptr) {
|
||||
|
|
|
@ -1366,17 +1366,21 @@ static bool llama_eval_internal(
|
|||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
// offload_func(tmpq);
|
||||
offload_func(tmpq);
|
||||
ggml_set_name(tmpq, "tmpq");
|
||||
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
// offload_func(tmpk);
|
||||
offload_func(tmpk);
|
||||
ggml_set_name(tmpk, "tmpk");
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
offload_func(Kcur);
|
||||
Kcur->backend = GGML_BACKEND_CPU;
|
||||
ggml_set_name(Kcur, "Kcur");
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
offload_func(Qcur);
|
||||
Qcur->backend = GGML_BACKEND_CPU;
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
|
||||
// store key and value to memory
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue