llama : cont k-shift refactoring + normalize type names

ggml-ci
Georgi Gerganov 2024-02-24 10:28:44 +02:00
parent dd392191ca
commit 89b2a43cac
6 changed files with 199 additions and 202 deletions

common/common.cpp

@@ -295,9 +295,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { invalid_param = true; break; }
} else if (arg == "--rope-scale") {
if (++i >= argc) {
@@ -630,11 +630,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
}
std::string arg_next = argv[i];
if (arg_next == "none") {
params.split_mode = LLAMA_SPLIT_NONE;
params.split_mode = LLAMA_SPLIT_MODE_NONE;
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_LAYER;
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
} else if (arg_next == "row") {
params.split_mode = LLAMA_SPLIT_ROW;
params.split_mode = LLAMA_SPLIT_MODE_ROW;
} else {
invalid_param = true;
break;

common/common.h

@@ -61,7 +61,7 @@ struct gpt_params {
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
@@ -75,7 +75,7 @@ struct gpt_params {
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
// // sampling parameters

examples/llama-bench/llama-bench.cpp

@@ -157,9 +157,9 @@ static const char * output_format_str(output_formats format) {
static const char * split_mode_str(llama_split_mode mode) {
switch (mode) {
case LLAMA_SPLIT_NONE: return "none";
case LLAMA_SPLIT_LAYER: return "layer";
case LLAMA_SPLIT_ROW: return "row";
case LLAMA_SPLIT_MODE_NONE: return "none";
case LLAMA_SPLIT_MODE_LAYER: return "layer";
case LLAMA_SPLIT_MODE_ROW: return "row";
default: GGML_ASSERT(!"invalid split mode");
}
}
@@ -193,7 +193,7 @@ static const cmd_params cmd_params_defaults = {
/* type_v */ {GGML_TYPE_F16},
/* n_threads */ {get_num_physical_cores()},
/* n_gpu_layers */ {99},
/* split_mode */ {LLAMA_SPLIT_LAYER},
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* mul_mat_q */ {true},
@@ -358,11 +358,11 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
for (const auto & m : p) {
llama_split_mode mode;
if (m == "none") {
mode = LLAMA_SPLIT_NONE;
mode = LLAMA_SPLIT_MODE_NONE;
} else if (m == "layer") {
mode = LLAMA_SPLIT_LAYER;
mode = LLAMA_SPLIT_MODE_LAYER;
} else if (m == "row") {
mode = LLAMA_SPLIT_ROW;
mode = LLAMA_SPLIT_MODE_ROW;
} else {
invalid_param = true;
break;

examples/server/server.cpp

@@ -2082,9 +2082,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
/**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
else { invalid_param = true; break; }
}
else if (arg == "--rope-freq-base")
@@ -2208,15 +2208,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
std::string arg_next = argv[i];
if (arg_next == "none")
{
params.split_mode = LLAMA_SPLIT_NONE;
params.split_mode = LLAMA_SPLIT_MODE_NONE;
}
else if (arg_next == "layer")
{
params.split_mode = LLAMA_SPLIT_LAYER;
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
}
else if (arg_next == "row")
{
params.split_mode = LLAMA_SPLIT_ROW;
params.split_mode = LLAMA_SPLIT_MODE_ROW;
}
else {
invalid_param = true;

llama.cpp

@@ -850,9 +850,9 @@ struct LLM_TN {
//
static std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_NONE, "none" },
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
{ LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
};
static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
@@ -862,7 +862,7 @@ static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
}
}
return LLAMA_ROPE_SCALING_UNSPECIFIED;
return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
}
static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
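As a quick check on the renamed constants: the lookup above returns the enum value for a known GGUF string and now falls back to LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED. A minimal sketch of the expected behavior (hypothetical assertions, not part of this commit):

    GGML_ASSERT(llama_rope_scaling_type_from_string("yarn")    == LLAMA_ROPE_SCALING_TYPE_YARN);
    GGML_ASSERT(llama_rope_scaling_type_from_string("unknown") == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);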
@@ -1581,7 +1581,8 @@ struct llama_hparams {
bool causal_attn = true;
bool need_kq_pos = false;
uint32_t pooling_type = LLAMA_POOLING_NONE;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true;
@@ -2311,7 +2312,7 @@ namespace GGUFMeta {
}
};
struct ArrayInfo{
struct ArrayInfo {
const gguf_type gt;
const size_t length;
const void * data;
@@ -2330,7 +2331,7 @@ namespace GGUFMeta {
};
template<typename T>
class GKV: public GKV_Base<T> {
class GKV : public GKV_Base<T> {
GKV() = delete;
public:
@@ -2353,39 +2354,39 @@ namespace GGUFMeta {
return "unknown";
}
static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) {
if (!override) { return false; }
if (override->tag == expected_type) {
static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
if (!ovrd) { return false; }
if (ovrd->tag == expected_type) {
LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
__func__, override_type_to_str(override->tag), override->key);
switch (override->tag) {
__func__, override_type_to_str(ovrd->tag), ovrd->key);
switch (ovrd->tag) {
case LLAMA_KV_OVERRIDE_BOOL: {
LLAMA_LOG_INFO("%s\n", override->bool_value ? "true" : "false");
LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false");
} break;
case LLAMA_KV_OVERRIDE_INT: {
LLAMA_LOG_INFO("%" PRId64 "\n", override->int_value);
LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value);
} break;
case LLAMA_KV_OVERRIDE_FLOAT: {
LLAMA_LOG_INFO("%.6f\n", override->float_value);
LLAMA_LOG_INFO("%.6f\n", ovrd->float_value);
} break;
default:
// Shouldn't be possible to end up here, but just in case...
throw std::runtime_error(
format("Unsupported attempt to override %s type for metadata key %s\n",
override_type_to_str(override->tag), override->key));
override_type_to_str(ovrd->tag), ovrd->key));
}
return true;
}
LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
__func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag));
__func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
return false;
}
template<typename OT>
static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) {
target = override->bool_value;
try_override(OT & target, const struct llama_model_kv_override * ovrd) {
if (validate_override(LLAMA_KV_OVERRIDE_BOOL, ovrd)) {
target = ovrd->bool_value;
return true;
}
return false;
@@ -2393,9 +2394,9 @@ namespace GGUFMeta {
template<typename OT>
static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) {
target = override->int_value;
try_override(OT & target, const struct llama_model_kv_override * ovrd) {
if (validate_override(LLAMA_KV_OVERRIDE_INT, ovrd)) {
target = ovrd->int_value;
return true;
}
return false;
@@ -2403,9 +2404,9 @@ namespace GGUFMeta {
template<typename OT>
static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) {
target = override->float_value;
try_override(T & target, const struct llama_model_kv_override * ovrd) {
if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, ovrd)) {
target = ovrd->float_value;
return true;
}
return false;
@@ -2413,17 +2414,17 @@ namespace GGUFMeta {
template<typename OT>
static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override *override) {
try_override(T & target, const struct llama_model_kv_override * ovrd) {
(void)target;
(void)override;
if (!override) { return false; }
(void)ovrd;
if (!ovrd) { return false; }
// Currently, we should never end up here so it would be a bug if we do.
throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n",
override ? override->key : "NULL"));
ovrd ? ovrd->key : "NULL"));
}
static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
if (try_override<T>(target, override)) {
static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
if (try_override<T>(target, ovrd)) {
return true;
}
if (k < 0) { return false; }
@@ -2431,12 +2432,12 @@ namespace GGUFMeta {
return true;
}
static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
return set(ctx, gguf_find_key(ctx, key), target, override);
static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
return set(ctx, gguf_find_key(ctx, key), target, ovrd);
}
static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
return set(ctx, key.c_str(), target, override);
static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
return set(ctx, key.c_str(), target, ovrd);
}
};
}
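For context on what validate_override()/try_override() consume, here is a hedged sketch of building one override on the caller side. The tag and value-field names follow this commit's llama.h (LLAMA_KV_OVERRIDE_INT, int_value); the key string, the value, and the fixed-size key buffer are assumptions:

    #include <cstdio>
    #include "llama.h"

    llama_model_kv_override kvo;
    std::snprintf(kvo.key, sizeof(kvo.key), "%s", "llama.expert_used_count"); // illustrative key
    kvo.tag       = LLAMA_KV_OVERRIDE_INT;
    kvo.int_value = 2;
    // handed to the loader via llama_model_params::kv_overrides,
    // as an array terminated by an entry with an empty key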
@@ -2846,6 +2847,15 @@ struct llama_model_loader {
}
};
template<>
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
uint32_t tmp;
const bool found = get_key(kid, tmp, required);
result = (enum llama_pooling_type) tmp;
return found;
}
//
// load LLaMA models
//
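The new get_key specialization exists because GGUF stores the pooling type as a plain u32, while hparams now carries a strongly typed enum llama_pooling_type; reading through a temporary and casting bridges the two. A hedged usage sketch, assuming the LLM_KV_POOLING_TYPE key defined elsewhere in this file:

    enum llama_pooling_type pt = LLAMA_POOLING_TYPE_NONE;
    ml.get_key(LLM_KV_POOLING_TYPE, pt, false); // reads the u32, casts via the specialization above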
@@ -2924,6 +2934,7 @@ static const char * llama_model_type_name(e_model type) {
default: return "?B";
}
}
static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
switch (type) {
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
@@ -2933,7 +2944,6 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
}
}
static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
model.arch = ml.get_arch();
if (model.arch == LLM_ARCH_UNKNOWN) {
@@ -2997,7 +3007,7 @@ static void llm_load_hparams(
std::string rope_scaling("linear");
ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);
// rope_freq_scale (inverse of the kv) is optional
float ropescale = 0.0f;
@@ -3273,6 +3283,8 @@ static void llm_load_hparams(
if (hparams.f_max_alibi_bias > 0.0f) {
hparams.need_kq_pos = true;
}
hparams.rope_type = llama_rope_type(&model);
}
// TODO: This should probably be in llama.h
@@ -3575,6 +3587,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
@@ -3641,7 +3655,7 @@ static bool llm_load_tensors(
model.buft_layer[i] = llama_default_buffer_type_cpu(true);
}
if (split_mode == LLAMA_SPLIT_LAYER) {
if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
// calculate the split points
int device_count = llama_get_device_count();
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
@@ -3680,10 +3694,10 @@ static bool llm_load_tensors(
}
} else {
ggml_backend_buffer_type_t split_buft;
if (split_mode == LLAMA_SPLIT_ROW) {
if (split_mode == LLAMA_SPLIT_MODE_ROW) {
split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
} else {
// LLAMA_SPLIT_NONE or LLAMA_SPLIT_LAYER in backends where it is not supported
// LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
split_buft = llama_default_buffer_type_offload(main_gpu);
}
// assign the repeating layers
@@ -4596,13 +4610,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
enum llm_rope_type : int {
LLM_ROPE_NONE = -1,
LLM_ROPE = 0,
LLM_ROPE_NEOX = 2,
LLM_ROPE_GLM = 4,
};
enum llm_ffn_op_type {
LLM_FFN_SILU,
LLM_FFN_GELU,
@@ -4648,47 +4655,6 @@ static struct ggml_tensor * llm_build_inp_embd(
return inpL;
}
// Persimmon: n_rot = n_embd_head_k/2
// Other: n_rot = n_embd_head_k
static void llm_build_k_shift(
struct ggml_context * ctx,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
struct ggml_tensor * K_shift,
llm_rope_type rope_type,
int64_t n_ctx,
float freq_base,
float freq_scale,
const llm_build_cb & cb) {
const int64_t n_layer = hparams.n_layer;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int32_t n_rot = hparams.n_rot;
const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
const float ext_factor = cparams.yarn_ext_factor;
const float attn_factor = cparams.yarn_attn_factor;
const float beta_fast = cparams.yarn_beta_fast;
const float beta_slow = cparams.yarn_beta_slow;
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx,
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_head_kv, n_ctx,
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
0),
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
ggml_build_forward_expand(graph, tmp);
}
}
static void llm_build_kv_store(
struct ggml_context * ctx,
const llama_hparams & hparams,
@@ -4982,38 +4948,6 @@ static struct ggml_tensor * llm_build_kv(
return cur;
}
static llm_rope_type llm_get_rope_type(llm_arch arch) {
switch (arch) {
case LLM_ARCH_LLAMA: return LLM_ROPE;
case LLM_ARCH_FALCON: return LLM_ROPE_NEOX;
case LLM_ARCH_BAICHUAN: return LLM_ROPE;
case LLM_ARCH_GPT2: return LLM_ROPE_NONE;
case LLM_ARCH_GPTJ: return LLM_ROPE_NONE;
case LLM_ARCH_GPTNEOX: return LLM_ROPE_NONE;
case LLM_ARCH_MPT: return LLM_ROPE_NONE;
case LLM_ARCH_STARCODER: return LLM_ROPE;
case LLM_ARCH_PERSIMMON: return LLM_ROPE_NEOX;
case LLM_ARCH_REFACT: return LLM_ROPE_NONE;
case LLM_ARCH_BERT: return LLM_ROPE_NEOX;
case LLM_ARCH_NOMIC_BERT: return LLM_ROPE_NEOX;
case LLM_ARCH_BLOOM: return LLM_ROPE_NONE;
case LLM_ARCH_STABLELM: return LLM_ROPE_NEOX;
case LLM_ARCH_QWEN: return LLM_ROPE_NEOX;
case LLM_ARCH_QWEN2: return LLM_ROPE_NEOX;
case LLM_ARCH_PHI2: return LLM_ROPE_NEOX;
case LLM_ARCH_PLAMO: return LLM_ROPE;
case LLM_ARCH_CODESHELL: return LLM_ROPE;
case LLM_ARCH_ORION: return LLM_ROPE;
case LLM_ARCH_INTERNLM2: return LLM_ROPE;
case LLM_ARCH_MINICPM: return LLM_ROPE;
case LLM_ARCH_GEMMA: return LLM_ROPE;
case LLM_ARCH_UNKNOWN:
default:
GGML_ASSERT(false && "unknown architecture");
return LLM_ROPE_NONE;
}
}
struct llm_build_context {
const llama_model & model;
const llama_context & lctx;
@@ -5024,6 +4958,7 @@ struct llm_build_context {
const int64_t n_embd;
const int64_t n_layer;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_head;
const int64_t n_head_kv;
@@ -5048,9 +4983,8 @@ struct llm_build_context {
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
const uint32_t pooling_type;
const llm_rope_type rope_type;
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
const llm_build_cb & cb;
@@ -5072,6 +5006,7 @@ struct llm_build_context {
kv_self (lctx.kv_self),
n_embd (hparams.n_embd),
n_layer (hparams.n_layer),
n_rot (hparams.n_rot),
n_ctx (cparams.n_ctx),
n_head (hparams.n_head),
n_head_kv (hparams.n_head_kv),
@@ -5093,8 +5028,8 @@ struct llm_build_context {
n_kv (worst_case ? n_ctx : kv_self.n),
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_NONE),
rope_type (llm_get_rope_type(model.arch)),
pooling_type (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
rope_type (hparams.rope_type),
cb (cb),
buf_compute_meta (lctx.buf_compute_meta) {
// all initializations should be done in init()
@@ -5120,7 +5055,20 @@ struct llm_build_context {
struct ggml_cgraph * build_k_shift() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, rope_type, n_ctx, freq_base, freq_scale, cb);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_head_kv, n_ctx,
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
0),
lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
ggml_build_forward_expand(gf, tmp);
}
return gf;
}
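The loop inlined above re-rotates only the first n_rot dimensions of each cached K head by the per-cell deltas stored in lctx.inp_K_shift. This is valid because RoPE is a pure rotation in the position argument, so a key cached at position p can be moved to position p + delta with one extra rotation (the standard RoPE identity, stated here for reference and not part of this diff):

    K_{p+\delta} = R(\delta)\,K_p, \qquad
    R(p) = \operatorname{blockdiag}\big(\mathrm{rot}(p\,\theta_0), \ldots, \mathrm{rot}(p\,\theta_{d/2-1})\big), \qquad
    \theta_i = \mathrm{base}^{-2i/d}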
@@ -6063,12 +6011,12 @@ struct llm_build_context {
cur = inpL;
// pooling layer
if (pooling_type == LLAMA_POOLING_MEAN) {
if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
} else if (pooling_type == LLAMA_POOLING_CLS) {
} else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
cur = ggml_get_rows(ctx0, cur, inp_cls);
} else {
GGML_ASSERT(pooling_type == LLAMA_POOLING_NONE && "Invalid pooling type");
GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
}
cb(cur, "result_embd", -1);
@@ -7521,6 +7469,7 @@ struct llm_build_context {
static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
llama_batch dummy;
dummy.n_tokens = 0;
llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
@@ -7735,7 +7684,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_MEAN) {
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -7763,7 +7712,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_CLS) {
if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
const int64_t n_tokens = batch.n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -7784,7 +7733,7 @@ static void llama_graph_compute(
ggml_cgraph * gf,
int n_threads) {
#ifdef GGML_USE_MPI
const int64_t n_layer = hparams.n_layer;
const int64_t n_layer = lctx.hparams.n_layer;
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
#endif
@@ -7902,9 +7851,7 @@ static int llama_decode_internal(
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
if (kv_self.has_shift) {
llama_kv_cache_apply_k_shift(&lctx);
}
llama_kv_cache_apply(&lctx);
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
@@ -8042,9 +7989,9 @@ static int llama_decode_internal(
return 0;
}
void llama_kv_cache_apply_k_shift(struct llama_context * ctx) {
struct llama_context & lctx = *ctx;
static void llama_kv_cache_apply_internal(struct llama_context & lctx) {
// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
llama_set_k_shift(lctx);
{
@@ -8054,7 +8001,7 @@ void llama_kv_cache_apply_k_shift(struct llama_context * ctx) {
}
{
auto & kv_self = ctx->kv_self;
auto & kv_self = lctx.kv_self;
kv_self.has_shift = false;
@@ -8062,6 +8009,7 @@ void llama_kv_cache_apply_k_shift(struct llama_context * ctx) {
kv_self.cells[i].delta = 0;
}
}
}
}
//
@@ -11338,7 +11286,7 @@ static int llama_apply_lora_from_file_internal(
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.n_gpu_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_LAYER,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
/*.progress_callback =*/ nullptr,
@@ -11364,7 +11312,7 @@ struct llama_context_params llama_context_default_params() {
/*.n_batch =*/ 512,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,
@@ -11552,16 +11500,16 @@ struct llama_context * llama_new_context_with_model(
cparams.cb_eval_user_data = params.cb_eval_user_data;
auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train;
}
if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
}
if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
if (params.seed == LLAMA_DEFAULT_SEED) {
@@ -11595,8 +11543,8 @@ struct llama_context * llama_new_context_with_model(
}
#elif defined(GGML_USE_CUBLAS)
if (model->n_gpu_layers > 0) {
// with split_mode LLAMA_SPLIT_NONE or LLAMA_SPLIT_ROW, only the main GPU backend is used
if (model->split_mode == LLAMA_SPLIT_NONE || model->split_mode == LLAMA_SPLIT_ROW) {
// with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
@@ -11605,7 +11553,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(backend);
} else {
// LLAMA_SPLIT_LAYER requires a backend for each GPU
// LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
ggml_backend_t backend = ggml_backend_cuda_init(device);
if (backend == nullptr) {
@@ -11807,6 +11755,38 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.type;
}
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
switch (model->arch) {
case LLM_ARCH_LLAMA: return LLAMA_ROPE_TYPE;
case LLM_ARCH_FALCON: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_BAICHUAN: return LLAMA_ROPE_TYPE;
case LLM_ARCH_GPT2: return LLAMA_ROPE_TYPE_NONE;
case LLM_ARCH_GPTJ: return LLAMA_ROPE_TYPE_NONE;
case LLM_ARCH_GPTNEOX: return LLAMA_ROPE_TYPE_NONE;
case LLM_ARCH_MPT: return LLAMA_ROPE_TYPE_NONE;
case LLM_ARCH_STARCODER: return LLAMA_ROPE_TYPE;
case LLM_ARCH_PERSIMMON: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_REFACT: return LLAMA_ROPE_TYPE_NONE;
case LLM_ARCH_BERT: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_NOMIC_BERT: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_BLOOM: return LLAMA_ROPE_TYPE_NONE;
case LLM_ARCH_STABLELM: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_PHI2: return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_PLAMO: return LLAMA_ROPE_TYPE;
case LLM_ARCH_CODESHELL: return LLAMA_ROPE_TYPE;
case LLM_ARCH_ORION: return LLAMA_ROPE_TYPE;
case LLM_ARCH_INTERNLM2: return LLAMA_ROPE_TYPE;
case LLM_ARCH_MINICPM: return LLAMA_ROPE_TYPE;
case LLM_ARCH_GEMMA: return LLAMA_ROPE_TYPE;
case LLM_ARCH_UNKNOWN:
default:
GGML_ASSERT(false && "unknown architecture");
return LLAMA_ROPE_TYPE_NONE;
}
}
int32_t llama_n_vocab(const struct llama_model * model) {
return model->vocab.id_to_token.size();
}
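llama_rope_type() promotes the former internal llm_get_rope_type() (removed above) into the public API, letting callers discover how a model encodes positions. A minimal usage sketch (model path illustrative):

    #include "llama.h"

    llama_model * model = llama_load_model_from_file("model.gguf", llama_model_default_params());
    if (llama_rope_type(model) == LLAMA_ROPE_TYPE_NONE) {
        // positions are not RoPE-encoded: a K-shift needs no re-rotation for this model
    }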
@@ -12065,6 +12045,10 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
}
void llama_kv_cache_apply(struct llama_context * ctx) {
llama_kv_cache_apply_internal(*ctx);
}
// Returns the *maximum* size of the state
size_t llama_get_state_size(const struct llama_context * ctx) {

llama.h

@@ -64,6 +64,13 @@ extern "C" {
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
};
enum llama_rope_type {
LLAMA_ROPE_TYPE_NONE = -1,
LLAMA_ROPE_TYPE = 0,
LLAMA_ROPE_TYPE_NEOX = 2,
LLAMA_ROPE_TYPE_GLM = 4,
};
enum llama_token_type {
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
@@ -107,23 +114,23 @@ extern "C" {
};
enum llama_rope_scaling_type {
LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_NONE = 0,
LLAMA_ROPE_SCALING_LINEAR = 1,
LLAMA_ROPE_SCALING_YARN = 2,
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1,
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
LLAMA_ROPE_SCALING_TYPE_YARN = 2,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
};
enum llama_pooling_type {
LLAMA_POOLING_NONE = 0,
LLAMA_POOLING_MEAN = 1,
LLAMA_POOLING_CLS = 2,
LLAMA_POOLING_TYPE_NONE = 0,
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
};
enum llama_split_mode {
LLAMA_SPLIT_NONE = 0, // single GPU
LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_ROW = 2, // split rows across GPUs
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};
typedef struct llama_token_data {
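A hedged sketch of selecting the renamed constants through the default-params helpers shown earlier in this diff (the concrete values are illustrative):

    llama_context_params cparams = llama_context_default_params();
    cparams.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; // int32_t field, see llama_context_default_params()
    cparams.yarn_orig_ctx     = 4096;                         // illustrative original context length

    llama_model_params mparams = llama_model_default_params();
    mparams.split_mode = LLAMA_SPLIT_MODE_ROW; // row split; LLAMA_SPLIT_MODE_LAYER is the default
    mparams.main_gpu   = 0;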
@@ -358,6 +365,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -512,7 +520,9 @@ extern "C" {
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_apply()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_add(
@@ -523,7 +533,9 @@ extern "C" {
llama_pos delta);
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_apply()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div(
@@ -533,7 +545,8 @@ extern "C" {
llama_pos p1,
int d);
LLAMA_API void llama_kv_cache_apply_k_shift(struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts) to the KV data
LLAMA_API void llama_kv_cache_apply(struct llama_context * ctx);
//
// State / sessions
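A minimal sketch of the lazy-vs-explicit behavior documented above: shifting a sequence marks the cache (has_shift), and the K-shift is then applied either by the next llama_decode() or immediately via llama_kv_cache_apply(). Window sizes are illustrative:

    const int n_keep    = 32; // illustrative
    const int n_discard = 64; // illustrative

    // drop a window of old tokens, then slide the remainder left
    llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, -1, -n_discard);
    llama_kv_cache_apply  (ctx); // or let the next llama_decode() pick it up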