From 9d866118c4169fecf052d9b541f963176cfacc91 Mon Sep 17 00:00:00 2001 From: ningshanwutuobang Date: Sun, 25 Jun 2023 01:21:54 +0800 Subject: [PATCH] refactor the interface and fixed the styles --- Makefile | 10 +- examples/{embd_input => embd-input}/README.md | 2 +- examples/embd-input/embd-input-lib.cpp | 218 ++++++++++++++ .../embd-input-test.cpp} | 15 +- .../embd_input.h => embd-input/embd-input.h} | 0 .../{embd_input => embd-input}/embd_input.py | 58 ++-- examples/{embd_input => embd-input}/llava.py | 8 +- .../{embd_input => embd-input}/panda_gpt.py | 38 ++- examples/embd_input/embd_input_lib.cpp | 283 ------------------ llama.cpp | 141 ++++----- llama.h | 4 +- 11 files changed, 335 insertions(+), 442 deletions(-) rename examples/{embd_input => embd-input}/README.md (98%) create mode 100644 examples/embd-input/embd-input-lib.cpp rename examples/{embd_input/embd_input_test.cpp => embd-input/embd-input-test.cpp} (71%) rename examples/{embd_input/embd_input.h => embd-input/embd-input.h} (100%) rename examples/{embd_input => embd-input}/embd_input.py (64%) rename examples/{embd_input => embd-input}/llava.py (94%) rename examples/{embd_input => embd-input}/panda_gpt.py (81%) delete mode 100644 examples/embd_input/embd_input_lib.cpp diff --git a/Makefile b/Makefile index 41fb6d7c7..2fc442447 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple libembd_input.so embd_input_test +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple libembdinput.so embd-input-test ifdef LLAMA_BUILD_SERVER BUILD_TARGETS += server @@ -302,12 +302,12 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml. server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) -libembd_input.so: examples/embd_input/embd_input.h examples/embd_input/embd_input_lib.cpp examples/embd_input/embd_input_test.cpp build-info.h ggml.o llama.o common.o $(OBJS) - $(CXX) --shared $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) +libembdinput.so: examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) -embd_input_test: libembd_input.so examples/embd_input/embd_input_test.cpp build-info.h ggml.o llama.o common.o $(OBJS) - $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -Wl,-rpath=./ -lembd_input +embd-input-test: libembdinput.so examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. 
-lembdinput train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/examples/embd_input/README.md b/examples/embd-input/README.md similarity index 98% rename from examples/embd_input/README.md rename to examples/embd-input/README.md index c180d541a..eb1095f24 100644 --- a/examples/embd_input/README.md +++ b/examples/embd-input/README.md @@ -1,6 +1,6 @@ ### Examples for input embedding directly -## Requirement +## Requirement build `libembd_input.so` run the following comman in main dir (../../). ``` diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp new file mode 100644 index 000000000..37a5b5208 --- /dev/null +++ b/examples/embd-input/embd-input-lib.cpp @@ -0,0 +1,218 @@ +// Defines sigaction on msys: +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include "embd-input.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static llama_context ** g_ctx; + +extern "C" { + +struct MyModel* create_mymodel(int argc, char ** argv) { + gpt_params params; + + if (gpt_params_parse(argc, argv, params) == false) { + return nullptr; + } + + fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); + + if (params.seed < 0) { + params.seed = time(NULL); + } + fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); + + llama_init_backend(); + + llama_context * ctx; + g_ctx = &ctx; + + // load the model and apply lora adapter, if any + ctx = llama_init_from_gpt_params(params); + if (ctx == NULL) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + return nullptr; + } + + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + } + struct MyModel * ret = new MyModel(); + ret->ctx = ctx; + ret->params = params; + ret->n_past = 0; + // printf("ctx: %d\n", ret->ctx); + return ret; +} + +void free_mymodel(struct MyModel * mymodel) { + llama_context * ctx = mymodel->ctx; + llama_print_timings(ctx); + llama_free(ctx); + delete mymodel; +} + + +bool eval_float(void * model, float * input, int N){ + MyModel * mymodel = (MyModel*)model; + llama_context * ctx = mymodel->ctx; + gpt_params params = mymodel->params; + int n_emb = llama_n_embd(ctx); + int n_past = mymodel->n_past; + int n_batch = N; // params.n_batch; + + for (int i = 0; i < (int) N; i += n_batch) { + int n_eval = (int) N - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return false; + } + n_past += n_eval; + } + mymodel->n_past = n_past; + return true; +} + +bool eval_tokens(void * model, std::vector tokens) { + MyModel * mymodel = (MyModel* )model; + llama_context * ctx; + ctx = mymodel->ctx; + gpt_params params = mymodel->params; + int n_past = mymodel->n_past; + for (int i = 0; i < (int) tokens.size(); i += params.n_batch) { + int n_eval = (int) tokens.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return false; + } + n_past += n_eval; + } + mymodel->n_past = n_past; + return true; +} + +bool eval_id(struct MyModel* mymodel, int id) { + 
std::vector<llama_token> tokens;
+    tokens.push_back(id);
+    return eval_tokens(mymodel, tokens);
+}
+
+bool eval_string(struct MyModel * mymodel,const char* str){
+    llama_context * ctx = mymodel->ctx;
+    std::string str2 = str;
+    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, str2, true);
+    eval_tokens(mymodel, embd_inp);
+    return true;
+}
+
+llama_token sampling_id(struct MyModel* mymodel) {
+    llama_context* ctx = mymodel->ctx;
+    gpt_params params = mymodel->params;
+    // int n_ctx = llama_n_ctx(ctx);
+
+    // out of user input, sample next token
+    const float temp = params.temp;
+    const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+    const float top_p = params.top_p;
+    const float tfs_z = params.tfs_z;
+    const float typical_p = params.typical_p;
+    // const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    // const float repeat_penalty = params.repeat_penalty;
+    // const float alpha_presence = params.presence_penalty;
+    // const float alpha_frequency = params.frequency_penalty;
+    const int mirostat = params.mirostat;
+    const float mirostat_tau = params.mirostat_tau;
+    const float mirostat_eta = params.mirostat_eta;
+    // const bool penalize_nl = params.penalize_nl;
+
+    llama_token id = 0;
+    {
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+
+        // Apply params.logit_bias map
+        for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+            logits[it->first] += it->second;
+        }
+
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+        // TODO: Apply penalties
+        // float nl_logit = logits[llama_token_nl()];
+        // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+        // llama_sample_repetition_penalty(ctx, &candidates_p,
+        //     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        //     last_n_repeat, repeat_penalty);
+        // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+        //     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        //     last_n_repeat, alpha_frequency, alpha_presence);
+        // if (!penalize_nl) {
+        //     logits[llama_token_nl()] = nl_logit;
+        // }
+
+        if (temp <= 0) {
+            // Greedy sampling
+            id = llama_sample_token_greedy(ctx, &candidates_p);
+        } else {
+            if (mirostat == 1) {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                const int mirostat_m = 100;
+                llama_sample_temperature(ctx, &candidates_p, temp);
+                id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+            } else if (mirostat == 2) {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                llama_sample_temperature(ctx, &candidates_p, temp);
+                id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+            } else {
+                // Temperature sampling
+                llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
+                llama_sample_typical(ctx, &candidates_p, typical_p, 1);
+                llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                llama_sample_temperature(ctx, &candidates_p, temp);
+                id = llama_sample_token(ctx, &candidates_p);
+            }
+        }
+    }
+
+    return id;
+}
+
+const char * sampling(struct MyModel * mymodel) {
+    llama_context * ctx = mymodel->ctx;
+    int id = sampling_id(mymodel);
+    std::string ret;
+    if (id == llama_token_eos()) ret = 
""; + else ret = llama_token_to_str(ctx, id); + eval_id(mymodel, id); + return ret.c_str(); +} + +} diff --git a/examples/embd_input/embd_input_test.cpp b/examples/embd-input/embd-input-test.cpp similarity index 71% rename from examples/embd_input/embd_input_test.cpp rename to examples/embd-input/embd-input-test.cpp index e14141497..e5e040f62 100644 --- a/examples/embd_input/embd_input_test.cpp +++ b/examples/embd-input/embd-input-test.cpp @@ -1,4 +1,4 @@ -#include "embd_input.h" +#include "embd-input.h" #include #include #include @@ -7,8 +7,11 @@ int main(int argc, char** argv) { auto mymodel = create_mymodel(argc, argv); int N = 10; + int max_tgt_len = 500; int n_embd = llama_n_embd(mymodel->ctx); - float* data = new float[N*n_embd]; + + // add random float embd to test evaluation + float * data = new float[N*n_embd]; std::default_random_engine e; std::uniform_real_distribution u(0,1); for (int i=0;iparams.prompt.c_str()); const char* tmp; - for (int i=0;i < 500; i++) { - // int id = sampling_id(mymodel); + for (int i=0; i")==0) break; - printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id)); + printf("%s", tmp); fflush(stdout); - // eval_id(mymodel, id); } printf("\n"); free_mymodel(mymodel); diff --git a/examples/embd_input/embd_input.h b/examples/embd-input/embd-input.h similarity index 100% rename from examples/embd_input/embd_input.h rename to examples/embd-input/embd-input.h diff --git a/examples/embd_input/embd_input.py b/examples/embd-input/embd_input.py similarity index 64% rename from examples/embd_input/embd_input.py rename to examples/embd-input/embd_input.py index ce057a89d..be2896614 100644 --- a/examples/embd_input/embd_input.py +++ b/examples/embd-input/embd_input.py @@ -1,8 +1,9 @@ import ctypes from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int import numpy as np +import os -libc = cdll.LoadLibrary("./libembd_input.so") +libc = cdll.LoadLibrary("./libembdinput.so") libc.sampling.restype=c_char_p libc.create_mymodel.restype=c_void_p libc.eval_string.argtypes=[c_void_p, c_char_p] @@ -16,7 +17,9 @@ class MyModel: c_str = [c_char_p(i.encode()) for i in args] args_c = (c_char_p * argc)(*c_str) self.model = c_void_p(libc.create_mymodel(argc, args_c)) -# print("self.model", self.model) + self.max_tgt_len = 512 + self.print_string_eval = True + def __del__(self): libc.free_mymodel(self.model) @@ -25,6 +28,8 @@ class MyModel: def eval_string(self, x): libc.eval_string(self.model, x.encode()) # c_char_p(x.encode())) + if self.print_string_eval: + print(x) def eval_token(self, x): libc.eval_id(self.model, x) @@ -33,49 +38,34 @@ class MyModel: s = libc.sampling(self.model) return s - def generate(self, end=""): - ret = b"" - end = end.encode() - for _ in range(500): - tmp = self.sampling() # .decode() - if (ret+tmp).endswith(end): - break - ret += tmp - return ret.decode() - def stream_generate(self, end=""): ret = b"" end = end.encode() - head = b"" - for _ in range(500): - tmp = self.sampling() # .decode() + for _ in range(self.max_tgt_len): + tmp = self.sampling() ret += tmp - try: - text = (head + tmp).decode() - print(text, end="") - head = b"" - except: - head += text + yield tmp if ret.endswith(end): break - print("") - return ret.decode() + def generate_with_print(self, end=""): + ret = b"" + for i in self.stream_generate(end=end): + ret += i + print(i.decode(errors="replace"), end="", flush=True) + print("") + return ret.decode(errors="replace") + + + def generate(self, end=""): + text = b"".join(self.stream_generate(end=end)) + return 
text.decode(errors="replace") if __name__ == "__main__": model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"]) - # print(model) model.eval_string("""user: what is the color of the flag of UN?""") - # model.eval_token(100) x = np.random.random((5120,10))# , dtype=np.float32) model.eval_float(x) model.eval_string("""assistant:""") - # print(x[0,0], x[0,1],x[1,0]) - # model.eval_float(x) - # print(libc) - - for i in range(500): - tmp = model.sampling().decode() - if tmp == "": - break - print(tmp, end="", flush=True) + for i in model.generate(): + print(i.decode(errors="replace"), end="", flush=True) diff --git a/examples/embd_input/llava.py b/examples/embd-input/llava.py similarity index 94% rename from examples/embd_input/llava.py rename to examples/embd-input/llava.py index a1efaddf6..2f20cb722 100644 --- a/examples/embd_input/llava.py +++ b/examples/embd-input/llava.py @@ -31,7 +31,7 @@ class Llava: self.model.eval_string("user: ") self.model.eval_string(question) self.model.eval_string("\nassistant: ") - return self.model.generate() + return self.model.generate_with_print() def chat_with_image(self, image, question): with torch.no_grad(): @@ -49,7 +49,7 @@ class Llava: self.model.eval_token(32003-1) # im_end self.model.eval_string(question) self.model.eval_string("\nassistant: ") - return self.model.generate() + return self.model.generate_with_print() if __name__=="__main__": @@ -63,8 +63,8 @@ if __name__=="__main__": respose = a.chat_with_image( Image.open("./media/llama1-logo.png").convert('RGB'), "what is the text in the picture?") - print(respose) - print(a.chat("what is the color of it?")) + respose + a.chat("what is the color of it?") diff --git a/examples/embd_input/panda_gpt.py b/examples/embd-input/panda_gpt.py similarity index 81% rename from examples/embd_input/panda_gpt.py rename to examples/embd-input/panda_gpt.py index b1199b95d..0cfac5f32 100644 --- a/examples/embd_input/panda_gpt.py +++ b/examples/embd-input/panda_gpt.py @@ -7,11 +7,13 @@ from torch import nn import torch # use PandaGPT path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "PandaGPT","code","model")) +panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT") +imagebind_ckpt_path = "./models/panda_gpt/" + +sys.path.insert(0, os.path.join(panda_gpt_path,"code","model")) from ImageBind.models import imagebind_model from ImageBind import data -imagebind_ckpt_path = "./models/panda_gpt/" ModalityType = imagebind_model.ModalityType max_tgt_len = 400 @@ -31,25 +33,25 @@ class PandaGPT: "weight": state["llama_proj.weight"], "bias": state["llama_proj.bias"]}) + def eval_inputs(self, inputs): + self.model.eval_string("") + embds = self.extract_multimoal_feature(inputs) + for i in embds: + self.model.eval_float(i.T) + self.model.eval_string(" ") + def chat(self, question): - if self.generated_text == "": - self.model.eval_string("###") - self.model.eval_string(" Human: ") - self.model.eval_string(question) - self.model.eval_string("\n### Assistant:") - ret = self.model.stream_generate(end="###") - self.generated_text += ret - return ret + return self.chat_with_image(None, question) def chat_with_image(self, inputs, question): if self.generated_text == "": self.model.eval_string("###") - self.model.eval_string(" Human: ") - embds = self.extract_multimoal_feature(inputs) - for i in embds: - self.model.eval_float(i.T) - self.model.eval_string(" " + question + "\n### Assistant:") - ret = self.model.stream_generate(end="###") + self.model.eval_string(" Human: 
") + if inputs: + self.eval_inputs(inputs) + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + ret = self.model.generate_with_print(end="###") self.generated_text += ret return ret @@ -88,13 +90,9 @@ class PandaGPT: if __name__=="__main__": - # model form liuhaotian/LLaVA-13b-delta-v1-1 a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"]) - # Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin. - # Also here can use pytorch_model-00003-of-00003.bin directly. a.load_projection("./models/panda_gpt/adapter_model.bin") a.chat_with_image( {"image_paths": ["./media/llama1-logo.png"]}, "what is the text in the picture? 'llama' or 'lambda'?") a.chat("what is the color of it?") - diff --git a/examples/embd_input/embd_input_lib.cpp b/examples/embd_input/embd_input_lib.cpp deleted file mode 100644 index bbdf6d645..000000000 --- a/examples/embd_input/embd_input_lib.cpp +++ /dev/null @@ -1,283 +0,0 @@ -// Defines sigaction on msys: -#ifndef _GNU_SOURCE -#define _GNU_SOURCE -#endif - -#include "embd_input.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) -#include -#include -#elif defined (_WIN32) -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#include -#endif - -static console_state con_st; -static llama_context ** g_ctx; - -static bool is_interacting = false; - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) -void sigint_handler(int signo) { - if (signo == SIGINT) { - if (!is_interacting) { - is_interacting=true; - } else { - console_cleanup(con_st); - printf("\n"); - llama_print_timings(*g_ctx); - _exit(130); - } - } -} -#endif - - -extern "C" { - -struct MyModel* create_mymodel(int argc, char ** argv) { - gpt_params params; - - if (gpt_params_parse(argc, argv, params) == false) { - return nullptr; - } - - - if (params.n_ctx > 2048) { - fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);" - "expect poor results\n", __func__, params.n_ctx); - } - - fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); - - if (params.seed < 0) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); - - std::mt19937 rng(params.seed); - if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); - } - - llama_init_backend(); - - llama_context * ctx; - g_ctx = &ctx; - - // load the model and apply lora adapter, if any - ctx = llama_init_from_gpt_params(params); - if (ctx == NULL) { - fprintf(stderr, "%s: error: unable to load model\n", __func__); - return nullptr; - } - - // print system information - { - fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); - } - struct MyModel* ret= new MyModel(); - ret->ctx = ctx; - ret->params = params; - ret->n_past = 0; - // printf("ctx: %d\n", ret->ctx); - return ret; -} - -void free_mymodel(struct MyModel* mymodel) { - llama_context* ctx = mymodel->ctx; - llama_print_timings(ctx); - llama_free(ctx); - delete mymodel; -} - - -bool eval_float(void* model, float* input, int N){ - MyModel* mymodel = (MyModel* )model; - llama_context* ctx = mymodel->ctx; - gpt_params params = 
mymodel->params; - int n_emb = llama_n_embd(ctx); - int n_past = mymodel->n_past; - // printf("%f,%f\n", *input, *(input+1)); - int n_batch = N; // params.n_batch; - for (int i = 0; i < (int) N; i += n_batch) { - int n_eval = (int) N - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - if (llama_eval_float(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return false; - } - n_past += n_eval; - } - mymodel->n_past = n_past; - return true; -} - - - - - -bool eval_tokens(void* model, std::vector tokens) { - MyModel* mymodel = (MyModel* )model; - // printf("model: %d\n", mymodel); - llama_context* ctx;// = mymodel->ctx; - // printf("ctx2: %d\n", ctx); - // printf("ctx2: %d\n", mymodel->ctx); - ctx = mymodel->ctx; - // printf("ctx2: %d\n", ctx); - gpt_params params = mymodel->params; - // printf("\n%d\n", params); - int n_past = mymodel->n_past; - for (int i = 0; i < (int) tokens.size(); i += params.n_batch) { - int n_eval = (int) tokens.size() - i; - if (n_eval > params.n_batch) { - n_eval = params.n_batch; - } - // printf("%d, %d, %d\n", i, n_eval, n_past); - if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return false; - } - n_past += n_eval; - } - mymodel->n_past = n_past; - return true; -} - -bool eval_id(struct MyModel* mymodel, int id) { - // printf("%d\n", id); - std::vector tokens; - tokens.push_back(id); - // printf("%d\n", tokens.size()); - // printf("%d\n", tokens[0]); - return eval_tokens(mymodel, tokens); -} - - -bool eval_string(struct MyModel* mymodel,const char* str){ - // std::cout << "eval " << std::endl; - // printf("%s", str); - llama_context* ctx = mymodel->ctx; - std::string str2 = str; - // printf("%s", str2.c_str()); - std::cout << str2 << std::endl; - std::vector embd_inp = ::llama_tokenize(ctx, str2, true); - eval_tokens(mymodel, embd_inp); - return true; -} - - - - -llama_token sampling_id(struct MyModel* mymodel) { - llama_context* ctx = mymodel->ctx; - gpt_params params = mymodel->params; - // int n_ctx = llama_n_ctx(ctx); - - - // out of user input, sample next token - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; - const float top_p = params.top_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - // const int32_t repeat_last_n = params.repeat_last_n < 0 ? 
n_ctx : params.repeat_last_n; - // const float repeat_penalty = params.repeat_penalty; - // const float alpha_presence = params.presence_penalty; - // const float alpha_frequency = params.frequency_penalty; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - // const bool penalize_nl = params.penalize_nl; - - llama_token id = 0; - - { - auto logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(ctx); - - // Apply params.logit_bias map - for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { - logits[it->first] += it->second; - } - - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - - // Apply penalties -// float nl_logit = logits[llama_token_nl()]; -// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); -// llama_sample_repetition_penalty(ctx, &candidates_p, -// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, -// last_n_repeat, repeat_penalty); -// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, -// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, -// last_n_repeat, alpha_frequency, alpha_presence); -// if (!penalize_nl) { -// logits[llama_token_nl()] = nl_logit; -// } - - if (temp <= 0) { - // Greedy sampling - id = llama_sample_token_greedy(ctx, &candidates_p); - } else { - if (mirostat == 1) { - static float mirostat_mu = 2.0f * mirostat_tau; - const int mirostat_m = 100; - llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); - } else if (mirostat == 2) { - static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); - } else { - // Temperature sampling - llama_sample_top_k(ctx, &candidates_p, top_k, 1); - llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); - llama_sample_typical(ctx, &candidates_p, typical_p, 1); - llama_sample_top_p(ctx, &candidates_p, top_p, 1); - llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token(ctx, &candidates_p); - } - } - - } - return id; -} - -const char* sampling(struct MyModel* mymodel) { - llama_context* ctx = mymodel->ctx; - int id = sampling_id(mymodel); - - std::string ret; - if (id == llama_token_eos()) ret = ""; - else ret = llama_token_to_str(ctx, id); - eval_id(mymodel, id); - return ret.c_str(); -} - -} diff --git a/llama.cpp b/llama.cpp index 65890663b..9c983da3a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1342,15 +1342,33 @@ static bool llama_model_load( } } -static bool llama_eval_internal_tensor( - llama_context& lctx, - ggml_context* ctx0, - ggml_tensor* inpL, - const int n_tokens, - const int n_past, - const int n_threads, - const char * cgraph_fname, - const int64_t t_start_us) { +// evaluate the transformer +// +// - lctx: llama context +// - tokens: new batch of tokens to process +// - n_tokens number of tokens +// - embd embeddings input +// - n_past: the context size so far +// - n_threads: number of threads to use +// +static bool llama_eval_internal( + llama_context & lctx, + const llama_token * tokens, + const int 
n_tokens, + const float * embd, + const int n_past, + const int n_threads, + const char * cgraph_fname) { + + LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); + + // enforce that the first token is BOS + if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) { + fprintf(stderr, "%s: first token must be BOS\n", __func__); + return false; + } + + const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1359,7 +1377,6 @@ static bool llama_eval_internal_tensor( const auto & kv_self = model.kv_self; - LLAMA_ASSERT(!!kv_self.ctx); const int n_embd = hparams.n_embd; @@ -1371,6 +1388,15 @@ static bool llama_eval_internal_tensor( const int n_gpu_layers = model.n_gpu_layers; auto & mem_per_token = lctx.mem_per_token; + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.addr, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ctx0 = ggml_init(params); // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance @@ -1378,6 +1404,17 @@ static bool llama_eval_internal_tensor( gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_set_name(embd, "embd"); + memcpy(embd->data, tokens, N*ggml_element_size(embd)); + inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); + } else { + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + } const int i_gpu_start = n_layer - n_gpu_layers; (void) i_gpu_start; @@ -1746,53 +1783,6 @@ static bool llama_eval_internal_tensor( return true; } - -// evaluate the transformer -// -// - lctx: llama context -// - tokens: new batch of tokens to process -// - n_past: the context size so far -// - n_threads: number of threads to use -// -static bool llama_eval_internal( - llama_context & lctx, - const llama_token * tokens, - const int n_tokens, - const int n_past, - const int n_threads, - const char * cgraph_fname) { - - // enforce that the first token is BOS - if (n_past == 0 && tokens[0] != llama_token_bos()) { - fprintf(stderr, "%s: first token must be BOS\n", __func__); - return false; - } - - const auto & model = lctx.model; - - const int64_t t_start_us = ggml_time_us(); - - const int N = n_tokens; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.addr, - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - ggml_set_name(embd, "embd"); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); - - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); - return llama_eval_internal_tensor(lctx, ctx0, inpL, N, n_past, n_threads, cgraph_fname, t_start_us); -} - // // tokenizer // @@ -3357,7 +3347,7 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { - if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) { + if (!llama_eval_internal(*ctx, tokens, n_tokens, nullptr, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -3373,32 +3363,13 @@ int llama_eval( } -int llama_eval_float( - 
struct llama_context * ctx, - const float * input, - int n_tokens, - int n_past, - int n_threads) { - const auto & model = ctx->model; - - const int64_t t_start_us = ggml_time_us(); - - const int N = n_tokens; - - auto & buf_compute = ctx->buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.addr, - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - struct ggml_tensor *inpL = - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, N); - memcpy(inpL->data, input, N * model.hparams.n_embd * ggml_element_size(inpL)); - if (!llama_eval_internal_tensor(*ctx, ctx0, inpL, N, n_past, n_threads, nullptr, t_start_us)) { +int llama_eval_embd( + struct llama_context * ctx, + const float * embd, + int n_tokens, + int n_past, + int n_threads) { + if (!llama_eval_internal(*ctx, nullptr, n_tokens, embd, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -3419,7 +3390,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) { const std::vector tmp(n_batch, llama_token_bos()); - if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) { + if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), nullptr, n_ctx, 1, fname)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } diff --git a/llama.h b/llama.h index ad975166b..2183b12fa 100644 --- a/llama.h +++ b/llama.h @@ -200,9 +200,9 @@ extern "C" { int n_threads); // Same as llama_eval, but use float matrix input directly. - LLAMA_API int llama_eval_float( + LLAMA_API int llama_eval_embd( struct llama_context * ctx, - const float * embds, + const float * embd, int n_tokens, int n_past, int n_threads);
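
For downstream users of the renamed library, here is a minimal sketch of how the `embd-input` C API exposed by this patch (and wrapped by `embd_input.py`) is intended to be driven, in the spirit of `embd-input-test.cpp`. The prompt strings, the dummy embedding block, and the 128-piece limit are illustrative placeholders, not part of the patch; link against `libembdinput.so` as in the `embd-input-test` Makefile target above.

```cpp
#include "embd-input.h"

#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    // create_mymodel() parses the usual llama.cpp flags (-m, -c, ...) passed on the command line
    MyModel * mymodel = create_mymodel(argc, argv);
    if (mymodel == nullptr) {
        return 1;
    }

    const int n_embd   = llama_n_embd(mymodel->ctx);
    const int n_tokens = 10;
    std::vector<float> embd(n_tokens * n_embd, 0.1f); // stand-in for real projected embeddings

    // interleave text and raw float embeddings in the same context
    eval_string(mymodel, "user: what does the following embedding describe?\n");
    eval_float (mymodel, embd.data(), n_tokens);
    eval_string(mymodel, "assistant:");

    // sample up to 128 pieces; sampling() also feeds the sampled token back into the context
    for (int i = 0; i < 128; i++) {
        printf("%s", sampling(mymodel));
        fflush(stdout);
    }
    printf("\n");

    free_mymodel(mymodel);
    return 0;
}
```

The Python wrapper in `embd_input.py` drives the same entry points (`create_mymodel`, `eval_string`, `eval_float`, `sampling`) through ctypes, which is what `llava.py` and `panda_gpt.py` build on.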