update: work for bot mpt and awqmpt

This commit is contained in:
Trần Đức Nam 2023-12-19 23:25:00 +07:00
parent 8fece75e35
commit 8177ad4e37
5 changed files with 48 additions and 21 deletions

View file

@ -149,6 +149,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.seed = std::stoul(argv[i]); params.seed = std::stoul(argv[i]);
} else if (arg == "-awq" || arg == "--use-awq") {
params.use_awq = true;
} else if (arg == "-t" || arg == "--threads") { } else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -804,6 +806,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" (can be specified more than once for multiple prompts).\n"); printf(" (can be specified more than once for multiple prompts).\n");
printf(" --color colorise output to distinguish prompt and user input from generations\n"); printf(" --color colorise output to distinguish prompt and user input from generations\n");
printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
printf(" -awq SEED, -use-awq Using AWQ quantization model in inferences\n");
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N\n"); printf(" -tb N, --threads-batch N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
@ -1013,6 +1016,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.tensor_split = params.tensor_split; mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap; mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock; mparams.use_mlock = params.use_mlock;
mparams.use_awq = params.use_awq;
if (params.kv_overrides.empty()) { if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL; mparams.kv_overrides = NULL;
} else { } else {
@ -1096,13 +1100,11 @@ void llama_batch_add(
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) { std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params); auto mparams = llama_model_params_from_gpt_params(params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) { if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr); return std::make_tuple(nullptr, nullptr);
} }
auto cparams = llama_context_params_from_gpt_params(params); auto cparams = llama_context_params_from_gpt_params(params);
llama_context * lctx = llama_new_context_with_model(model, cparams); llama_context * lctx = llama_new_context_with_model(model, cparams);

View file

@ -125,6 +125,7 @@ struct gpt_params {
bool infill = false; // use infill mode bool infill = false; // use infill mode
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
bool no_kv_offload = false; // disable KV offloading bool no_kv_offload = false; // disable KV offloading
bool use_awq = false; // use AWQ quantization infer
std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V std::string cache_type_v = "f16"; // KV cache data type for the V

View file

@ -46,7 +46,7 @@ class Model:
self.part_names = self._get_part_names() self.part_names = self._get_part_names()
self.hparams = Model.load_hparams(self.dir_model) self.hparams = Model.load_hparams(self.dir_model)
self.model_arch = self._get_model_architecture() self.model_arch = self._get_model_architecture()
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess) self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
def set_vocab(self): def set_vocab(self):
self._set_vocab_gpt2() self._set_vocab_gpt2()
@ -59,7 +59,7 @@ class Model:
from safetensors import safe_open from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else: else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", weights_only=True))
with ctx as model_part: with ctx as model_part:
for name in model_part.keys(): for name in model_part.keys():
@ -444,7 +444,7 @@ class MPTModel(Model):
# map tensor names # map tensor names
if "scales" in name: if "scales" in name:
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales")) new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
new_name = new_name + ".scales" new_name = new_name.replace("scales", "act.scales")
else: else:
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None: if new_name is None:
@ -1001,6 +1001,7 @@ dir_model = args.model
if args.awq_path: if args.awq_path:
from awqpy.apply_awq import add_scale_weights from awqpy.apply_awq import add_scale_weights
tmp_model_path = args.model / "weighted_model" tmp_model_path = args.model / "weighted_model"
dir_model = tmp_model_path
if tmp_model_path.is_dir(): if tmp_model_path.is_dir():
print(f"{tmp_model_path} exists as a weighted model.") print(f"{tmp_model_path} exists as a weighted model.")
else: else:
@ -1008,7 +1009,6 @@ if args.awq_path:
print("Saving new weighted model ...") print("Saving new weighted model ...")
add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path)) add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
print(f"Saved weighted model at {tmp_model_path}.") print(f"Saved weighted model at {tmp_model_path}.")
dir_model = tmp_model_path
if not dir_model.is_dir(): if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file=sys.stderr) print(f'Error: {args.model} is not a directory', file=sys.stderr)
@ -1029,6 +1029,7 @@ print(f"Loading model: {dir_model.name}")
hparams = Model.load_hparams(dir_model) hparams = Model.load_hparams(dir_model)
with torch.inference_mode(): with torch.inference_mode():
model_class = Model.from_model_architecture(hparams["architectures"][0]) model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian) model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)

View file

@ -1178,6 +1178,7 @@ struct llama_hparams {
float f_clamp_kqv; float f_clamp_kqv;
float f_max_alibi_bias; float f_max_alibi_bias;
bool use_awq;
bool operator!=(const llama_hparams & other) const { bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true; if (this->vocab_only != other.vocab_only) return true;
@ -3379,7 +3380,6 @@ static void llm_load_tensors(
case LLM_ARCH_MPT: case LLM_ARCH_MPT:
{ {
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output // output
{ {
ggml_backend_type backend_norm; ggml_backend_type backend_norm;
@ -3423,18 +3423,31 @@ static void llm_load_tensors(
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_act = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, backend); if (model.hparams.use_awq) {
layer.ffn_act = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, backend);
}
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) { if (backend == GGML_BACKEND_GPU) {
vram_weights += if (model.hparams.use_awq) {
vram_weights +=
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm) +
ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wqkv) +
ggml_nbytes(layer.wo) + ggml_nbytes(layer.wo) +
ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm) +
ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down) +
ggml_nbytes(layer.ffn_act) + ggml_nbytes(layer.ffn_act) +
ggml_nbytes(layer.ffn_up); ggml_nbytes(layer.ffn_up);
}
else {
vram_weights +=
ggml_nbytes(layer.attn_norm) +
ggml_nbytes(layer.wqkv) +
ggml_nbytes(layer.wo) +
ggml_nbytes(layer.ffn_norm) +
ggml_nbytes(layer.ffn_down) +
ggml_nbytes(layer.ffn_up);
}
} }
} }
} break; } break;
@ -3634,7 +3647,7 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con
llama_model_loader ml(fname, params.use_mmap, params.kv_overrides); llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
model.hparams.vocab_only = params.vocab_only; model.hparams.vocab_only = params.vocab_only;
model.hparams.use_awq = params.use_awq;
llm_load_arch (ml, model); llm_load_arch (ml, model);
llm_load_hparams(ml, model); llm_load_hparams(ml, model);
llm_load_vocab (ml, model); llm_load_vocab (ml, model);
@ -5119,13 +5132,23 @@ struct llm_build_context {
NULL, NULL,
LLM_NORM, cb, il); LLM_NORM, cb, il);
cb(cur, "ffn_norm", il); cb(cur, "ffn_norm", il);
if (hparams.use_awq) {
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
NULL, NULL,
model.layers[il].ffn_down, NULL,
model.layers[il].ffn_act,
LLM_FFN_GELU_ACT, LLM_FFN_SEQ, cb, il);
cur = llm_build_ffn(ctx0, cur, }
model.layers[il].ffn_up, NULL, else {
NULL, NULL, cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_act, NULL, NULL,
LLM_FFN_GELU_ACT, LLM_FFN_SEQ, cb, il); model.layers[il].ffn_down, NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
}
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -8841,6 +8864,7 @@ struct llama_model_params llama_model_default_params() {
/*.progress_callback_user_data =*/ nullptr, /*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr, /*.kv_overrides =*/ nullptr,
/*.vocab_only =*/ false, /*.vocab_only =*/ false,
/*.use_awq =*/ false,
/*.use_mmap =*/ true, /*.use_mmap =*/ true,
/*.use_mlock =*/ false, /*.use_mlock =*/ false,
}; };
@ -8936,9 +8960,7 @@ struct llama_model * llama_load_model_from_file(
const char * path_model, const char * path_model,
struct llama_model_params params) { struct llama_model_params params) {
ggml_time_init(); ggml_time_init();
llama_model * model = new llama_model; llama_model * model = new llama_model;
unsigned cur_percentage = 0; unsigned cur_percentage = 0;
if (params.progress_callback == NULL) { if (params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage; params.progress_callback_user_data = &cur_percentage;

View file

@ -192,6 +192,7 @@ extern "C" {
bool vocab_only; // only load the vocabulary, no weights bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM bool use_mlock; // force system to keep model in RAM
bool use_awq; // whether to use awq quantization
}; };
struct llama_context_params { struct llama_context_params {