changed llama.cpp (build_phi3 now loads the bias for lm_head); fixed EOS token issues

Yutong Dai 2024-09-19 01:13:25 +00:00
parent 30b751ef06
commit 279308c74a
8 changed files with 412 additions and 58 deletions

View file

@ -263,6 +263,9 @@ class Model:
# we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
continue
# added for xgenmm
if name.endswith((".additional_embedding.weight", ".additional_fc.bias", "additional_fc.weight")):
continue
old_dtype = data_torch.dtype
@ -2069,6 +2072,8 @@ class Phi3MiniModel(Model):
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.special_token_ids['eos'] = 32007
print("YD: set special_vocab.special_token_ids['eos'] = 32007")
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
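Note: the hunk above pins the EOS id to 32007, which is Phi-3's <|end|> token. A minimal sketch to confirm the override actually landed in the converted file (assumes the gguf package from gguf-py is installed; the path is a hypothetical placeholder):

# check_eos.py -- verification sketch, not part of this commit
from gguf import GGUFReader

reader = GGUFReader("phi3_mini_4k_instruct_f32.gguf")  # adjust to your output path
field = reader.fields.get("tokenizer.ggml.eos_token_id")
if field is not None:
    # for scalar KV entries, gguf-py keeps the value in the last `parts` array
    print("eos_token_id:", int(field.parts[-1][0]))  # expected: 32007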

View file

@ -14,10 +14,14 @@ conda activate xgenmm-flamingo
# step 3: convert llm to gguf
# https://github.com/ggerganov/llama.cpp/discussions/7927
HF_TOKEN=<PUT YOUR TOKEN HERE>
LLM_PATH=/export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/llm
# LLM_OUTPUT_FILE=/export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf/phi3_.gguf
cd ../../
# HF_TOKEN=<PUT YOUR TOKEN HERE>
# downloads the tokenizer models of the specified models from Huggingface; generates the get_vocab_base_pre() function for convert_hf_to_gguf.py
cd ../..
# python convert_hf_to_gguf_update.py $HF_TOKEN
python convert_hf_to_gguf.py $LLM_PATH
LLM_PATH=/export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/llm
outtype=f32
LLM_OUTPUT_FILE=/export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf/phi3_mini_4k_instruct_$outtype.gguf
echo $LLM_OUTPUT_FILE
python convert_hf_to_gguf.py $LLM_PATH --outfile $LLM_OUTPUT_FILE --outtype $outtype
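Since the llama.cpp changes below expect an output.bias tensor for PHI3, it can help to check that the converter actually wrote one. A sketch using gguf-py (the tensor names follow llama.cpp's LLM_TENSOR_OUTPUT naming; the path is a placeholder):

# list_head_tensors.py -- verification sketch, not part of this commit
from gguf import GGUFReader

reader = GGUFReader("phi3_mini_4k_instruct_f32.gguf")  # adjust to your output path
names = {t.name for t in reader.tensors}
print("output.weight present:", "output.weight" in names)
print("output.bias   present:", "output.bias" in names)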

View file

@ -2,35 +2,15 @@
make xgenmm-cli
# ./xgenmm-cli -m /export/share/tawalgaonkar/llama.cpp/models/llm/xgenmm-phi-3-llm-Q4.gguf \
# ./xgenmm-cli --model /export/share/tawalgaonkar/llama.cpp/models/llm/xgenmm-phi-3-llm-Q4.gguf \
# --mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
# -c 4096 --temp 0.01 --repeat-penalty 1.05 \
# --image /export/home/llama.cpp/examples/xgenmm/imgs/image-1d100e9-1.jpg \
# -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. \nThe assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image>\nWhat is the color of this notebook?<|end|>\n<|assistant|>\n"
# ./xgenmm-cli -m /export/share/tawalgaonkar/llama.cpp/models/llm/xgenmm-phi-3-llm-Q4.gguf \
# --mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
# -c 4096 --temp 0 --num_beams 1 \
# --image /export/home/on-device-mm/notebooks/open-flamingo/imgs/receipt.jpg \
# -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. \nThe assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image> Describe this image.<|end|>\n<|assistant|>\n"
# ./xgenmm-cli -m /export/share/tawalgaonkar/llama.cpp/models/llm/xgenmm-phi-3-llm-Q4.gguf \
# --mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
# -c 4096 --temp 0.01 --repeat-penalty 1.05 \
# --image /export/home/llama.cpp/examples/xgenmm/imgs/image-1d100e9.jpg\
# -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. \nThe assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image>\n How many objects are there in this image?<|end|>\n<|assistant|>\n"
./xgenmm-cli --model /export/share/tawalgaonkar/llama.cpp/models/llm/xgenmm-phi-3-llm-Q4.gguf \
--mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
--image /export/home/llama.cpp/examples/xgenmm/imgs/receipt.jpg\
--prompt "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image>\n Describe this image.<|end|>\n<|assistant|>\n" \
--seed 42 --ctx-size 4096 --predict 1024 \
--temp 0 --verbose-prompt
#
# --image /export/home/llama.cpp/examples/xgenmm/imgs/receipt.jpg\
# --prompt "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image>\n Describe this image.<|end|>\n<|assistant|>\n" \
# --seed 42 --ctx-size 4096 --predict 1024 \
# --temp 0 --verbose-prompt
# ./xgenmm-cli --model /export/share/tawalgaonkar/llama.cpp/models/llm/xgenmm-phi-3-llm-Q4.gguf \
@ -39,3 +19,39 @@ make xgenmm-cli
# --prompt "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image>\n What is the address of this restirant?<|end|>\n<|assistant|>\n" \
# --seed 42 --ctx-size 4096 --predict 1024 \
# --temp 0 --verbose-prompt
# ./xgenmm-cli --model /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf/phi3_mini_4k_instruct_f16.gguf \
# --mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
# --image /export/home/llama.cpp/examples/xgenmm/imgs/receipt.jpg\
# --prompt "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image>\n Describe this image.<|end|>\n<|assistant|>\n" \
# --seed 42 --ctx-size 4096 --predict 1024 \
# --temp 0 --verbose-prompt
# ./xgenmm-cli --model /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf/phi3_mini_4k_instruct_f32.gguf \
# --mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
# --image /export/home/llama.cpp/examples/xgenmm/imgs/receipt.jpg\
# --prompt "<unk><s></s><|endoftext|><|assistant|><pad><|end|><image><image placeholder><|endofchunk|>" \
# --seed 42 --ctx-size 4096 --predict 1024 \
# --temp 0 --verbose-prompt
# Q="What is the address of this resturant?"
# Q="Is this dine in or dine out receipt?"
# Q="What is the total amount paid?"
# Q="What is card holder's name?"
# Q="What is the transaction date?"
# Q="What is the phone number of this resturant?"
Q="Who is the attendant?"
# Q="Who is the cashier?"
# Q="Briefly describe this image."
prompt="<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n<|user|>\n<image>\n $Q<|end|>\n<|assistant|>\n"
echo $prompt
model=/export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf/phi3_mini_4k_instruct_f32.gguf
# model=/export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf/phi3_mini_4k_instruct_f16.gguf
./xgenmm-cli --model $model\
--mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
--image /export/home/llama.cpp/examples/xgenmm/imgs/receipt.jpg\
--prompt "$prompt" \
--seed 42 --ctx-size 4096 --predict 1024 \
--temp 0.8 --verbose-prompt --color --ubatch-size 1280

View file

@ -12,6 +12,8 @@
static bool eval_tokens(struct llama_context *ctx_llama, std::vector<llama_token> tokens, int n_batch, int *n_past)
{
int N = (int)tokens.size();
// printf("token.size(): %d\n", N);
// printf("n_batch: %d\n", n_batch);
for (int i = 0; i < N; i += n_batch)
{
int n_eval = (int)tokens.size() - i;
@ -19,6 +21,7 @@ static bool eval_tokens(struct llama_context *ctx_llama, std::vector<llama_token
{
n_eval = n_batch;
}
// printf("n_eval: %d, n_past: %d\n", n_eval, *n_past);
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0)))
{
LOG_TEE("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
@ -41,12 +44,14 @@ static bool eval_string(struct llama_context *ctx_llama, const char *str, int n_
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true);
printf("prompt: %s", str);
for (auto token : embd_inp){
printf("%6d, ", token);
}
printf("!!prompt to eval!!: %s", str);
printf("----------------------\n");
// for (auto token : embd_inp){
// printf("%6d, ", token);
// }
printf("\n");
return eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
return true;
}
static const char *sample(struct llama_sampling_context *ctx_sampling, struct llama_context *ctx_llama, int *n_past)
@ -56,7 +61,7 @@ static const char *sample(struct llama_sampling_context *ctx_sampling, struct ll
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id))
{
ret = "</s>";
ret = "<|end|>";
}
else
{
@ -247,6 +252,12 @@ static void process_prompt(struct llava_context *ctx_llava, struct llava_image_e
}
}
eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
// image_embed
// struct llava_image_embed
// {
// float *embed;
// int n_image_pos;
// };
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
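Both the "</s>" -> "<|end|>" change in sample() and the EOS override in the converter assume that id 32007 maps to <|end|>. A quick cross-check against the tokenizer (a sketch; it uses the stock microsoft/Phi-3-mini-4k-instruct tokenizer as a stand-in for the model's Phi-3 backbone):

# check_end_token.py -- verification sketch, not part of this commit
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
print(tok.convert_ids_to_tokens(32007))      # expected: <|end|>
print(tok.convert_tokens_to_ids("<|end|>"))  # expected: 32007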

View file

@ -1005,7 +1005,7 @@ bool llava_eval_image_embed(llama_context *ctx_llama, const struct llava_image_e
int *n_past)
{
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
// printf("n_embd: %d\n", n_embd);
for (int i = 0; i < image_embed->n_image_pos; i += n_batch)
{
int n_eval = image_embed->n_image_pos - i;
@ -1013,17 +1013,18 @@ bool llava_eval_image_embed(llama_context *ctx_llama, const struct llava_image_e
{
n_eval = n_batch;
}
// printf("(llava_eval_image_embed) n_eval: %d\n", n_eval);
llama_batch batch = {
int32_t(n_eval),
nullptr,
(image_embed->embed + i * n_embd),
/* n_tokens */ int32_t(n_eval),
/* llama_token */ nullptr,
/* embed */ (image_embed->embed + i * n_embd),
nullptr,
nullptr,
nullptr,
nullptr,
*n_past,
1,
0,
/* all_pos_0 */ *n_past,
/* all_pos_1 */ 1,
/* all_seq_id */ 0,
};
if (llama_decode(ctx_llama, batch))
{
@ -1031,6 +1032,8 @@ bool llava_eval_image_embed(llama_context *ctx_llama, const struct llava_image_e
return false;
}
*n_past += n_eval;
// printf("exit from llava_eval_image_embed\n");
// exit(-1);
}
return true;
}

View file

@ -70,8 +70,6 @@ if __name__ == "__main__":
tokenizer_path=cfg.lm_path,
model_family=cfg.model_family,
**additional_kwargs)
print(model)
exit(1)
model.load_state_dict(ckpt, strict=True)
end = time.time()
print(f"🟢 time used: [{end-start:.3f} s] | Done with instaiating the model.")
@ -95,6 +93,7 @@ if __name__ == "__main__":
# put the tokenizer in the same dir as the lang model
tokenizer.save_pretrained(f"{save_dir}/llm")
print("❗❗❗ Please also download tokenizer.json mannually from https://huggingface.co/Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5/resolve/main/tokenizer.json and put it in the same dir as the lang model.")
end = time.time()
print(f"🟢 time used: [{end-start:.3f} s]")

View file

@ -95,7 +95,7 @@
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <iostream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
@ -107,6 +107,283 @@
//
// helpers
//
void print_tensor_cpp(ggml_tensor *tensor, const char *name = "", int verbosity = 0)
{
if (tensor->ne[2] == 1)
{
printf("---> %s: (%ld, %ld)\n", name, tensor->ne[0], tensor->ne[1]);
}
else if (ggml_is_3d(tensor))
{
printf("---> %s: (%ld, %ld, %ld)\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2]);
}
else
{
printf("---> %s: (%ld, %ld, %ld, %ld)\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
}
if (verbosity == 1)
{
printf("*********************************************************************\n");
if (tensor->ne[2] == 1)
{
const float *mat = (float *)tensor->data;
// check if mat is NULL
if (mat == NULL)
{
printf("mat is NULL\n");
return;
}
int dim0 = tensor->ne[1];
int dim1 = tensor->ne[0];
if (dim0 < 6 && dim1 < 6)
{
for (int i = 0; i < dim0; i++)
{
for (int j = 0; j < dim1; j++)
{
printf("%+.6f ", mat[i * dim1 + j]);
}
printf("\n");
}
printf("\n");
}
else
{
for (int i = 0; i < std::min(dim0, 3); i++)
{
for (int j = 0; j < std::min(dim1, 3); j++)
{
printf("%+.6f ", mat[i * dim1 + j]);
}
printf("... ");
for (int j = dim1 - 3; j < dim1; j++)
{
printf("%+.6f ", mat[i * dim1 + j]);
}
printf("\n");
}
if (dim0 > 3)
{
printf("...................... omit ......................\n");
for (int i = dim0 - 3; i < dim0; i++)
{
for (int j = 0; j < std::min(dim1, 3); j++)
{
printf("%+.6f ", mat[i * dim1 + j]);
}
printf("... ");
for (int j = dim1 - 3; j < dim1; j++)
{
printf("%+.6f ", mat[i * dim1 + j]);
}
printf("\n");
}
}
}
}
else if (ggml_is_3d(tensor))
{
const float *data = (float *)tensor->data;
int dim0 = tensor->ne[2];
int dim1 = tensor->ne[1];
int dim2 = tensor->ne[0];
if (dim0 < 6 && dim1 < 6 && dim2 < 6)
{
for (int i = 0; i < dim0; i++)
{
printf("dim0 = %d\n", i);
for (int j = 0; j < dim1; j++)
{
for (int k = 0; k < dim2; k++)
{
printf("%+.6f ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("\n");
}
printf("\n");
}
printf("\n");
}
else
{
for (int i = 0; i < std::min(dim0, 3); i++)
{
printf("dim0 = %d\n", i);
for (int j = 0; j < std::min(dim1, 3); j++)
{
for (int k = 0; k < std::min(dim2, 3); k++)
{
printf("%+.6f ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("... ");
for (int k = dim2 - 3; k < dim2; k++)
{
printf("%+.6f ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("\n");
}
printf("........................\n");
for (int j = dim1 - 3; j < dim1; j++)
{
for (int k = 0; k < std::min(dim2, 3); k++)
{
printf("%+.6f ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("... ");
for (int k = dim2 - 3; k < dim2; k++)
{
printf("%+.6f ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("\n");
}
printf("---------------------------------------------------\n");
}
printf("\n");
}
}
}
printf("*********************************************************************\n");
printf("\n");
}
void print_tensor_cpp_int(ggml_tensor *tensor, const char *name = "", int verbosity = 0)
{
if (tensor->ne[2] == 1)
{
printf("---> %s: (%ld, %ld)\n", name, tensor->ne[0], tensor->ne[1]);
}
else if (ggml_is_3d(tensor))
{
printf("---> %s: (%ld, %ld, %ld)\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2]);
}
else
{
printf("---> %s: (%ld, %ld, %ld, %ld)\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
}
if (verbosity == 1)
{
printf("*********************************************************************\n");
if (tensor->ne[2] == 1)
{
const int *mat = (int *)tensor->data;
// check if mat is NULL
if (mat == NULL)
{
printf("mat is NULL\n");
return;
}
int dim0 = tensor->ne[1];
int dim1 = tensor->ne[0];
if (dim0 < 6 && dim1 < 6)
{
for (int i = 0; i < dim0; i++)
{
for (int j = 0; j < dim1; j++)
{
printf("%d ", mat[i * dim1 + j]);
}
printf("\n");
}
printf("\n");
}
else
{
for (int i = 0; i < std::min(dim0, 3); i++)
{
for (int j = 0; j < std::min(dim1, 3); j++)
{
std::cout << mat[i * dim1 + j] << " ";
}
printf("... ");
for (int j = dim1 - 3; j < dim1; j++)
{
std::cout << mat[i * dim1 + j] << " ";
}
printf("\n");
}
if (dim0 > 3)
{
printf("...................... omit ......................\n");
for (int i = dim0 - 3; i < dim0; i++)
{
for (int j = 0; j < std::min(dim1, 3); j++)
{
std::cout << mat[i * dim1 + j] << " ";
}
printf("... ");
for (int j = dim1 - 3; j < dim1; j++)
{
std::cout << mat[i * dim1 + j] << " ";
}
printf("\n");
}
}
}
}
else if (ggml_is_3d(tensor))
{
const int *data = (int *)tensor->data;
int dim0 = tensor->ne[2];
int dim1 = tensor->ne[1];
int dim2 = tensor->ne[0];
if (dim0 < 6 && dim1 < 6 && dim2 < 6)
{
for (int i = 0; i < dim0; i++)
{
printf("dim0 = %d\n", i);
for (int j = 0; j < dim1; j++)
{
for (int k = 0; k < dim2; k++)
{
printf("%d ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("\n");
}
printf("\n");
}
printf("\n");
}
else
{
for (int i = 0; i < std::min(dim0, 3); i++)
{
printf("dim0 = %d\n", i);
for (int j = 0; j < std::min(dim1, 3); j++)
{
for (int k = 0; k < std::min(dim2, 3); k++)
{
printf("%d ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("... ");
for (int k = dim2 - 3; k < dim2; k++)
{
printf("%d ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("\n");
}
printf("........................\n");
for (int j = dim1 - 3; j < dim1; j++)
{
for (int k = 0; k < std::min(dim2, 3); k++)
{
printf("%d ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("... ");
for (int k = dim2 - 3; k < dim2; k++)
{
printf("%d ", data[i * dim1 * dim2 + j * dim2 + k]);
}
printf("\n");
}
printf("---------------------------------------------------\n");
}
printf("\n");
}
}
}
printf("*********************************************************************\n");
printf("\n");
}
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
@ -4361,6 +4638,8 @@ struct llama_model_loader {
} else {
ggml_backend_tensor_set(cur, data, 0, n_size);
}
// printf("Loading tensor %s | dtype %d\n", ggml_get_name(cur), cur->type);
// print_tensor_cpp(cur, ggml_get_name(cur), 1);
} else {
GGML_ASSERT(weight->idx < files.size());
const auto & file = files.at(weight->idx);
@ -5726,6 +6005,14 @@ static void llm_load_vocab(
)
) {
vocab.special_eot_id = t.second;
if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0)
{
LLAMA_LOG_WARN(
"%s: control-looking token: '%s' was not control-type; this is probably a bug in the "
"model. its type will be overridden\n",
__func__, t.first.c_str());
vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
}
break;
}
}
@ -6814,13 +7101,14 @@ static bool llm_load_tensors(
case LLM_ARCH_PHI3:
{
const int64_t n_embd_head = n_embd / n_head;
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
model.output_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab});
}
for (int i = 0; i < n_layer; ++i) {
@ -7910,6 +8198,13 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
return -1;
}
// print_tensor_cpp(model.output_b, "output_b", 1);
// print_tensor_cpp(model.tok_embd, "(in llama_model_load) tok_embd", 1);
// print_tensor_cpp(model.output_norm, "(in llama_model_load) output_norm", 1);
// auto layer = model.layers[0];
// print_tensor_cpp(layer.wo, "(in llama_model_load) layer.wo", 1);
// print_tensor_cpp(model.output, "output", 1);
// printf("successfully loaded model\n");
return 0;
}
@ -10969,13 +11264,12 @@ struct llm_build_context {
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
// n_layer = 2;
for (int il = 0; il < n_layer; ++il) {
auto residual = inpL;
@ -11078,6 +11372,9 @@ struct llm_build_context {
cb(cur, "result_norm", -1);
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output_no_bias", -1);
cur = ggml_add(ctx0, cur, model.output_b);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
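In plain numpy terms, the two added lines change the head from logits = h @ W_out.T to logits = h @ W_out.T + b_out, where b_out is the newly loaded output.bias. A sketch with made-up data (3072 and 32064 are the Phi-3-mini embedding and vocab sizes):

# lm_head_bias.py -- illustration sketch, not part of this commit
import numpy as np

n_embd, n_vocab = 3072, 32064
h = np.random.rand(1, n_embd).astype(np.float32)             # result_norm output for one token
W_out = np.random.rand(n_vocab, n_embd).astype(np.float32)   # model.output  ("output.weight")
b_out = np.random.rand(n_vocab).astype(np.float32)           # model.output_b ("output.bias")

logits_no_bias = h @ W_out.T      # result_output_no_bias
logits = logits_no_bias + b_out   # result_output (what the ggml_add contributes)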
@ -14935,6 +15232,7 @@ static int llama_decode_internal(
lctx.is_encoding = false;
const uint32_t n_tokens_all = batch_all.n_tokens;
// printf("n_tokens_all: %d\n", n_tokens_all);
if (n_tokens_all == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
@ -15144,10 +15442,28 @@ static int llama_decode_internal(
}
// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
// if (n_past%100 == 0) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
// }
// ggml_graph_dump_dot(gf, NULL, "phi3.dot");
// dot -Tpng phi3.dot > phi3.png
// const char *fname_cgraph = "phi3";
// // ggml_graph_export(gf, fname_cgraph);
// fprintf(stderr, "%s: exported compute graph to '%s'\n", __func__, fname_cgraph);
// print_tensor_cpp_int(gf->leafs[2], "inp_token[7,1]", 1);
// printf("dtype of inp_token: %d\n", gf->leafs[2]->type);
// print_tensor_cpp(gf->leafs[1], "token_embd.weight", 1);
// print_tensor_cpp(gf->nodes[0], "inp_embed[3072,7]", 1);
// print_tensor_cpp(gf->leafs[9], "attn.out.weight", 1);
// print_tensor_cpp(gf->nodes[28], "kqv_out-0", 1);
// print_tensor_cpp_int(gf->leafs[12], "inp_out_ids", 1);
// print_tensor_cpp(gf->nodes[41], "result_output", 1);
// print_tensor_cpp(gf->nodes[80], "l_out_1", 1);
// print_tensor_cpp(gf->nodes[1280], "l_out_31", 1);
// print_tensor_cpp(gf->nodes[1282], "result_output", 1);
// printf("num_nodes: %d\n", gf->n_nodes);
// print_tensor_cpp(gf->nodes[1286], "result_output", 1);
// print_tensor_cpp(gf->nodes[gf->n_nodes - 2], "gf->nodes[gf->n_nodes - 2] embd?", 1);
// extract logits
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
@ -15156,7 +15472,7 @@ static int llama_decode_internal(
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;
if (n_outputs_new) {
GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
@ -15211,7 +15527,7 @@ static int llama_decode_internal(
// set to total number of outputs in the batch, for use in llama_get_logits_ith
lctx.n_outputs = n_outputs;
// wait for the computation to finish (automatically done when obtaining the model output)
//llama_synchronize(&lctx);

Binary file not shown.