From e92795f2f4d02c4dfb660b932477f6eacec17c2a Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Thu, 22 Jun 2023 14:06:13 -0600 Subject: [PATCH] Add CUDA and hopefully Metal support for p_scale --- ggml-cuda.cu | 3 +-- ggml-metal.m | 40 ++++++++++++++++++++-------------------- ggml-metal.metal | 3 ++- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a64547cd9..f2897c3ad 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1913,10 +1913,9 @@ inline void ggml_cuda_op_rope( const int mode = (int)((float *) src1->data)[2]; const float p_scale = ((float *) src1->data)[3]; GGML_ASSERT(mode == 0); - GGML_ASSERT(p_scale == 1.0); const float theta_scale = powf(10000.0, -2.0f/n_dims); - const float p = ((mode & 1) == 0 ? n_past + i02 : i02); + const float p = p_scale * ((mode & 1) == 0 ? n_past + i02 : i02); // compute rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); diff --git a/ggml-metal.m b/ggml-metal.m index 1798b68e6..2cb146b67 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -867,30 +867,30 @@ void ggml_metal_graph_compute( const int n_dims = (int)((float *) src1->data)[1]; const int mode = (int)((float *) src1->data)[2]; const float p_scale = ((float *) src1->data)[3]; - GGML_ASSERT(p_scale == 1.0 && "no Metal support for rope p_scale != 1.0"); [encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; - [encoder setBytes:&mode length:sizeof( int) atIndex:20]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; + [encoder setBytes:&mode length:sizeof( int) atIndex:20]; + [encoder setBytes:&p_scale length:sizeof( float) atIndex:21]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index d1e49222d..f50bfd811 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -615,6 +615,7 @@ kernel void kernel_rope( constant int & n_past, constant int & n_dims, constant int & mode, + constant float & p_scale, uint3 tpig[[thread_position_in_grid]]) { const int64_t i3 = tpig[2]; const int64_t i2 = tpig[1]; @@ -625,7 +626,7 @@ kernel void kernel_rope( const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); - float theta = (float)p; + float theta = p_scale * (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) {