From 887694acfd3f8fe9737e01537cab7090d202a84b Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Thu, 22 Jun 2023 08:18:01 -0600 Subject: [PATCH] Handle rope params in CUDA, Metal Bail out if p_scale != 1.0 n rope operation for the time being --- ggml-cuda.cu | 10 +++++++--- ggml-metal.m | 11 +++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 36a251ecc..a64547cd9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1906,10 +1906,14 @@ inline void ggml_cuda_op_rope( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; + assert(src1->type == GGML_TYPE_F32); + assert(ggml_nelements(src1) == 4); + const int n_past = (int)((float *) src1->data)[0]; + const int n_dims = (int)((float *) src1->data)[1]; + const int mode = (int)((float *) src1->data)[2]; + const float p_scale = ((float *) src1->data)[3]; GGML_ASSERT(mode == 0); + GGML_ASSERT(p_scale == 1.0); const float theta_scale = powf(10000.0, -2.0f/n_dims); const float p = ((mode & 1) == 0 ? n_past + i02 : i02); diff --git a/ggml-metal.m b/ggml-metal.m index a7e104dc7..1798b68e6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -861,10 +861,13 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - const int n_past = ((int32_t *)(src1->data))[0]; + assert(src1->type == GGML_TYPE_F32); + assert(ggml_nelements(src1) == 4); + const int n_past = (int)((float *) src1->data)[0]; + const int n_dims = (int)((float *) src1->data)[1]; + const int mode = (int)((float *) src1->data)[2]; + const float p_scale = ((float *) src1->data)[3]; + GGML_ASSERT(p_scale == 1.0 && "no Metal support for rope p_scale != 1.0"); [encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];