Handle rope params in CUDA, Metal
Bail out if p_scale != 1.0 n rope operation for the time being
This commit is contained in:
parent
4bf45a7dbe
commit
887694acfd
2 changed files with 14 additions and 7 deletions
10
ggml-cuda.cu
10
ggml-cuda.cu
|
@ -1906,10 +1906,14 @@ inline void ggml_cuda_op_rope(
|
|||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t i01_diff = i01_high - i01_low;
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const int n_dims = ((int32_t *) src1->data)[1];
|
||||
const int mode = ((int32_t *) src1->data)[2];
|
||||
assert(src1->type == GGML_TYPE_F32);
|
||||
assert(ggml_nelements(src1) == 4);
|
||||
const int n_past = (int)((float *) src1->data)[0];
|
||||
const int n_dims = (int)((float *) src1->data)[1];
|
||||
const int mode = (int)((float *) src1->data)[2];
|
||||
const float p_scale = ((float *) src1->data)[3];
|
||||
GGML_ASSERT(mode == 0);
|
||||
GGML_ASSERT(p_scale == 1.0);
|
||||
|
||||
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
||||
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
|
||||
|
|
11
ggml-metal.m
11
ggml-metal.m
|
@ -861,10 +861,13 @@ void ggml_metal_graph_compute(
|
|||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
const int n_dims = ((int32_t *) src1->data)[1];
|
||||
const int mode = ((int32_t *) src1->data)[2];
|
||||
|
||||
const int n_past = ((int32_t *)(src1->data))[0];
|
||||
assert(src1->type == GGML_TYPE_F32);
|
||||
assert(ggml_nelements(src1) == 4);
|
||||
const int n_past = (int)((float *) src1->data)[0];
|
||||
const int n_dims = (int)((float *) src1->data)[1];
|
||||
const int mode = (int)((float *) src1->data)[2];
|
||||
const float p_scale = ((float *) src1->data)[3];
|
||||
GGML_ASSERT(p_scale == 1.0 && "no Metal support for rope p_scale != 1.0");
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_rope];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue