Add CUDA and hopefully Metal support for p_scale
This commit is contained in:
parent
887694acfd
commit
e92795f2f4
3 changed files with 23 additions and 23 deletions
|
@ -1913,10 +1913,9 @@ inline void ggml_cuda_op_rope(
|
|||
const int mode = (int)((float *) src1->data)[2];
|
||||
const float p_scale = ((float *) src1->data)[3];
|
||||
GGML_ASSERT(mode == 0);
|
||||
GGML_ASSERT(p_scale == 1.0);
|
||||
|
||||
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
||||
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
|
||||
const float p = p_scale * ((mode & 1) == 0 ? n_past + i02 : i02);
|
||||
|
||||
// compute
|
||||
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
|
||||
|
|
40
ggml-metal.m
40
ggml-metal.m
|
@ -867,30 +867,30 @@ void ggml_metal_graph_compute(
|
|||
const int n_dims = (int)((float *) src1->data)[1];
|
||||
const int mode = (int)((float *) src1->data)[2];
|
||||
const float p_scale = ((float *) src1->data)[3];
|
||||
GGML_ASSERT(p_scale == 1.0 && "no Metal support for rope p_scale != 1.0");
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_rope];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
|
||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
|
||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
|
||||
[encoder setBytes:&p_scale length:sizeof( float) atIndex:21];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
|
|
|
@ -615,6 +615,7 @@ kernel void kernel_rope(
|
|||
constant int & n_past,
|
||||
constant int & n_dims,
|
||||
constant int & mode,
|
||||
constant float & p_scale,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
const int64_t i3 = tpig[2];
|
||||
const int64_t i2 = tpig[1];
|
||||
|
@ -625,7 +626,7 @@ kernel void kernel_rope(
|
|||
|
||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
|
||||
float theta = (float)p;
|
||||
float theta = p_scale * (float)p;
|
||||
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue