ggml-metal: fix yarn rope

get the correct n_orig_ctx, and ignore the GLM only n_ctx
This commit is contained in:
Xiao-Yong Jin 2023-11-03 12:31:31 -05:00
parent abb77e7319
commit d40ab6a116

View file

@ -1403,7 +1403,8 @@ void ggml_metal_graph_compute(
const int n_past = ((int32_t *) dst->op_params)[0]; const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; const int mode = ((int32_t *) dst->op_params)[2];
const int n_orig_ctx = ((int32_t *) dst->op_params)[3]; // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));