llama : add phi3 128K model support (#7225)
* add phi3 128k support in convert-hf-to-gguf * add phi3 128k support in cuda * address build warnings on llama.cpp * adjust index value in cuda long rope freq factors * add long rope support in ggml cpu backend * make freq factors only depend on ctx size * remove unused rope scaling type 'su' frin gguf converter * fix flint warnings on convert-hf-to-gguf.py * set to the short freq factor when context size is small than trained context size * add one line of comments * metal : support rope freq_factors * ggml : update ggml_rope_ext API to support freq. factors * backends : add dev messages to support rope freq. factors * minor : style * tests : update to use new rope API * backends : fix pragma semicolons * minor : cleanup * llama : move rope factors from KV header to tensors * llama : remove tmp assert * cuda : fix compile warning * convert : read/write n_head_kv * llama : fix uninitialized tensors --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6369bf0433
commit
201cc11afa
15 changed files with 484 additions and 233 deletions
49
ggml.h
49
ggml.h
|
@ -1465,6 +1465,7 @@ extern "C" {
|
|||
// if mode & 4 == 1, ChatGLM style
|
||||
//
|
||||
// b is an int32 vector with size a->ne[2], it contains the positions
|
||||
// c is freq factors (e.g. phi3-128k), (optional)
|
||||
GGML_API struct ggml_tensor * ggml_rope(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
@ -1483,10 +1484,11 @@ extern "C" {
|
|||
int n_ctx);
|
||||
|
||||
// custom RoPE
|
||||
GGML_API struct ggml_tensor * ggml_rope_custom(
|
||||
GGML_API struct ggml_tensor * ggml_rope_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
|
@ -1499,7 +1501,23 @@ extern "C" {
|
|||
float beta_slow);
|
||||
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
|
@ -1512,20 +1530,28 @@ extern "C" {
|
|||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
float beta_slow),
|
||||
"use ggml_rope_ext instead");
|
||||
|
||||
// compute correction dims for YaRN RoPE scaling
|
||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
|
||||
// xPos RoPE, in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down);
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow),
|
||||
"use ggml_rope_ext_inplace instead");
|
||||
|
||||
// compute correction dims for YaRN RoPE scaling
|
||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
|
||||
// rotary position embedding backward, i.e compute dx from dy
|
||||
// a - dy
|
||||
|
@ -1533,6 +1559,7 @@ extern "C" {
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue