llama : add phi3 128K model support (#7225)
* add phi3 128k support in convert-hf-to-gguf * add phi3 128k support in cuda * address build warnings on llama.cpp * adjust index value in cuda long rope freq factors * add long rope support in ggml cpu backend * make freq factors only depend on ctx size * remove unused rope scaling type 'su' frin gguf converter * fix flint warnings on convert-hf-to-gguf.py * set to the short freq factor when context size is small than trained context size * add one line of comments * metal : support rope freq_factors * ggml : update ggml_rope_ext API to support freq. factors * backends : add dev messages to support rope freq. factors * minor : style * tests : update to use new rope API * backends : fix pragma semicolons * minor : cleanup * llama : move rope factors from KV header to tensors * llama : remove tmp assert * cuda : fix compile warning * convert : read/write n_head_kv * llama : fix uninitialized tensors --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6369bf0433
commit
201cc11afa
15 changed files with 484 additions and 233 deletions
121
ggml-metal.m
121
ggml-metal.m
|
@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
||||
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
||||
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
||||
|
||||
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
||||
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
||||
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
||||
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
|
||||
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
||||
|
||||
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
||||
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
||||
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
||||
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
|
||||
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
||||
|
||||
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
||||
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
||||
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
||||
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
||||
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
||||
|
||||
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
||||
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
||||
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
||||
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
||||
|
||||
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
||||
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
||||
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
||||
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
||||
|
||||
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
||||
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
||||
|
@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
const int n_as = src0->ne[2];
|
||||
|
||||
// src2 = ids
|
||||
const int64_t ne20 = src2->ne[0];
|
||||
const int64_t ne21 = src2->ne[1];
|
||||
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
||||
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
||||
|
||||
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2->nb[1];
|
||||
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
||||
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
||||
|
||||
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
||||
|
||||
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
||||
|
@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
|
@ -2252,6 +2258,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
||||
|
||||
if (!is_neox) {
|
||||
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src0->type) {
|
||||
|
@ -2263,33 +2278,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
||||
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
||||
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
||||
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
||||
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
||||
if (id_src2 != nil) {
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
|
||||
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
|
||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
|
||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
|
||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
|
||||
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
|
||||
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
|
||||
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
|
@ -2535,11 +2555,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
||||
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
||||
|
||||
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
||||
|
||||
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
||||
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
||||
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue