This commit is contained in:
Julia 2024-06-18 09:43:00 +01:00 committed by GitHub
commit 3819a8e7fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 237 additions and 13 deletions

View file

@ -31,17 +31,17 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", },
{ "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", },
{ "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
{ "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
{ "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.3 bpw quantization" , },
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
{ "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", },
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", },
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
{ "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", },
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", },
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", },
{ "Q5_K", LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", },
{ "Q5_K", LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", },
{ "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
@ -49,8 +49,9 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", },
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
{ "CUSTOM", LLAMA_FTYPE_CUSTOM, "[:filename] Custom quant config (quant.cfg if not specified", },
// Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
};
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
@ -58,12 +59,33 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count";
static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out, std::string & custom_cfg_filename_out) {
std::string ftype_str;
for (auto ch : ftype_str_in) {
ftype_str.push_back(std::toupper(ch));
}
if (ftype_str.find("CUSTOM:") == 0) {
// custom quant mix
ftype = LLAMA_FTYPE_CUSTOM;
ftype_str_out = "CUSTOM";
if (ftype_str.length() > 7) {
// extract config filename (take from ftype_str_in to get original casing)
std::string custom_cfg = ftype_str_in.substr(7);
custom_cfg_filename_out = custom_cfg;
} else {
return false;
}
return true;
} else if (ftype_str.find("CUSTOM") == 0) {
// custom quant mix with default config
ftype = LLAMA_FTYPE_CUSTOM;
ftype_str_out = "CUSTOM";
custom_cfg_filename_out = "quant.cfg";
return true;
}
for (auto & it : QUANT_OPTIONS) {
if (it.name == ftype_str) {
ftype = it.ftype;
@ -224,13 +246,119 @@ static ggml_type parse_ggml_type(const char * arg) {
for (int j = 0; j < GGML_TYPE_COUNT; ++j) {
auto type = ggml_type(j);
const auto * name = ggml_type_name(type);
if (name && strcmp(arg, name) == 0) {
if (name && strcasecmp(arg, name) == 0) {
result = type; break;
}
}
return result;
}
static bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
const char* sep = strchr(data, '=');
if (sep == nullptr || sep - data >= 128) {
fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
return false;
}
llama_model_kv_override kvo;
std::strncpy(kvo.key, data, sep - data);
kvo.key[sep - data] = 0;
sep++;
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
} else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false;
} else {
fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
return false;
}
} else {
fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
return false;
}
overrides.emplace_back(std::move(kvo));
return true;
}
static bool read_custom_quant_config(const std::string& filename, llama_model_quantize_ftype_override& override) {
std::ifstream file(filename);
std::string line;
std::vector<std::string> names;
std::vector<ggml_type> types;
printf("reading custom quantization mix from %s:\n", filename.c_str());
if (!file.is_open()) {
fprintf(stderr, "%s: failed to open file: '%s'\n", __func__, filename.c_str());
return false;
}
while (getline(file, line)) {
// skip empty lines and comments
if (line.empty() || line[0] == '#') continue;
// default file type
if (line.find("ftype=") == 0) {
std::string ftype_str = line.substr(6);
std::string ftype_name;
std::string custom_quant_config_filename;
llama_ftype ftype;
if(!try_parse_ftype(ftype_str, ftype, ftype_name, custom_quant_config_filename)) {
fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, ftype_str.c_str());
file.close();
return false;
}
override.default_ftype = static_cast<llama_ftype>(ftype);
printf(" default ftype = %i (%s)\n", ftype, ftype_name.c_str());
continue;
}
// tensor overrides
size_t pos = line.find('=');
if (pos != std::string::npos) {
std::string tensor_name = line.substr(0, pos);
std::string type_name = line.substr(pos + 1);
ggml_type type = parse_ggml_type(type_name.c_str());
if(type < 0 || type >= GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, type_name.c_str());
file.close();
return false;
}
names.push_back(tensor_name);
types.push_back(static_cast<ggml_type>(type));
printf(" %s = %i (%s)\n", tensor_name.c_str(), type, type_name.c_str());
}
}
printf("\n");
// allocate memory for names and types
override.names = new const char*[names.size()];
override.types = new ggml_type[types.size()];
override.count = names.size();
for (size_t i = 0; i < names.size(); ++i) {
override.names[i] = strdup(names[i].c_str());
override.types[i] = types[i];
}
file.close();
return true;
}
int main(int argc, char ** argv) {
if (argc < 3) {
usage(argv[0]);
@ -349,10 +477,11 @@ int main(int argc, char ** argv) {
const std::string fname_inp = argv[arg_idx];
arg_idx++;
std::string fname_out;
std::string custom_quant_config_filename;
std::string ftype_str;
std::string suffix = ".gguf";
if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str, custom_quant_config_filename)) {
std::string fpath;
const size_t pos = fname_inp.find_last_of("/\\");
if (pos != std::string::npos) {
@ -379,13 +508,23 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: missing ftype\n", __func__);
return 1;
}
if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str, custom_quant_config_filename)) {
fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
return 1;
}
if (ftype_str == "COPY") {
params.only_copy = true;
}
if (ftype_str == "CUSTOM") {
params.override_ftype = new llama_model_quantize_ftype_override;
if(!read_custom_quant_config(custom_quant_config_filename, *params.override_ftype)) {
return 1;
}
}
arg_idx++;
}

View file

@ -4096,6 +4096,9 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
// Custom quantization scheme
case LLAMA_FTYPE_CUSTOM: return "CUSTOM";
default: return "unknown, may not work";
}
}
@ -15383,11 +15386,35 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
return new_size;
}
static bool match_string(const std::string& str, const std::string& pattern, uint32_t string_index = 0, uint32_t pattern_index = 0) {
// if both index pointers reach the end of str and pattern respectively
if (string_index == str.size() && pattern_index == pattern.size()) {
return true;
}
// if pattern character is '*', it can match with any sequence of characters.
if (pattern_index < pattern.size() && pattern[pattern_index] == '*') {
// move pattern index by 1 and match rest, or keep string index same and move pattern index
return match_string(str, pattern, string_index, pattern_index + 1) || (string_index < str.size() && match_string(str, pattern, string_index + 1, pattern_index));
}
// if current characters match or pattern character is '?'
if (string_index < str.size() && pattern_index < pattern.size() && (str[string_index] == pattern[pattern_index] || pattern[pattern_index] == '?')) {
return match_string(str, pattern, string_index + 1, pattern_index + 1);
}
return false;
}
static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
ggml_type default_type;
llama_ftype ftype = params->ftype;
switch (params->ftype) {
llama_ftype ftype =
params->override_ftype
? params->override_ftype->default_ftype
: params->ftype;
switch (ftype) {
case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
@ -15478,7 +15505,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// copy the KV pairs from the input file
gguf_set_kv (ctx_out, ml.meta);
gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
gguf_set_val_u32(ctx_out, "general.file_type", ftype);
gguf_set_val_u32(ctx_out, "general.file_type", params->ftype);
// Remove split metadata
gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
@ -15666,6 +15693,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
new_type = params->output_tensor_type;
}
// look up tensor name in type override map, if not found use default
// type as determined by the ftype.
if(params->override_ftype) {
for (uint32_t i = 0; i < params->override_ftype->count; ++i) {
if (match_string(tensor->name, params->override_ftype->names[i])) {
// printf("\n -----> %s, %s\n", params->override_ftype->names[i], tensor->name);
new_type = params->override_ftype->types[i];
break;
}
}
}
// If we've decided to quantize to the same type the tensor is already
// in then there's nothing to do.
quantize = tensor->type != new_type;
@ -16131,7 +16170,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
/*.pure =*/ false,
/*.keep_split =*/ false,
/*.imatrix =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.override_ftype =*/ nullptr
};
return result;

11
llama.h
View file

@ -157,6 +157,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_CUSTOM = 33, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
@ -324,6 +325,13 @@ extern "C" {
void * abort_callback_data;
};
typedef struct llama_model_quantize_ftype_override {
enum llama_ftype default_ftype; // default type if not overriden
uint32_t count; // number of overrides
const char ** names; // tensor names
enum ggml_type * types; // tensor type override
} llama_model_quantize_custom_ftype;
// model quantization parameters
typedef struct llama_model_quantize_params {
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
@ -332,11 +340,12 @@ extern "C" {
enum ggml_type token_embedding_type; // itoken embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool only_copy; // only copy tensors - ftype, override_ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
struct llama_model_quantize_ftype_override * override_ftype; // custom quantization scheme
} llama_model_quantize_params;
// grammar types

36
quant.cfg Normal file
View file

@ -0,0 +1,36 @@
# Defines the default ftype (the quantization mix code,
# that you pass to quantize if you're not using custom mix).
# tensors that are not overriden below will be quantized
# according to this mix.
#
# Must be one of
# Q4_0, Q4_1, Q5_0, Q5_1, IQ2_XXS, IQ2_XS, IQ2_S, IQ2_M,
# IQ1_S, IQ1_M, Q2_K, Q2_K_S, IQ3_XXS, IQ3_S, IQ3_M, Q3_K,
# IQ3_XS, Q3_K_S, Q3_K_M, Q3_K_L, IQ4_NL, IQ4_XS, Q4_K,
# Q4_K_S, Q4_K_M, Q5_K, Q5_K_S, Q5_K_M, Q6_K, Q8_0, F16
ftype=Q6_K
# Defines overrides for tensors with names matching a given
# string. Filters are processed in order given, the first
# matching will be used.
#
# Wildcards are allowed:
# ? single character
# * multiple characters
#
# Type must be one of
# F16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1, Q2_K, Q3_K,
# Q4_K, Q5_K, Q6_K, Q8_K, IQ2_XXS, IQ2_XS, IQ3_XXS,
# IQ1_S, IQ4_NL, IQ3_S, IQ2_S, IQ4_XS, IQ1_M
blk.10.ffn_up.weight=Q5_K
blk.1?.ffn_up.weight=Q4_K
blk.23.*=Q2_K
blk.24.*=Q2_K
blk.25.*=Q2_K
blk.2?.ffn_up.weight=Q4_K
*_gate*=Q4_K
*.attn*=IQ4_XS
*_down*=IQ3_S
output.weight=Q5_K