diff --git a/.gitignore b/.gitignore index 62b6b8b1a..d28f4d1b8 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ models-mnt /embedding /gguf /gguf-llama-simple +/gritlm /imatrix /infill /libllama.so diff --git a/Makefile b/Makefile index 4f26c0463..223d37eb4 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ BUILD_TARGETS = \ main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \ - speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o + speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o # Binaries only useful for tests TEST_TARGETS = \ @@ -720,6 +720,10 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +gritlm: examples/gritlm/gritlm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 653abc73a..e762cf8b9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,6 +20,7 @@ else() add_subdirectory(convert-llama2c-to-ggml) add_subdirectory(embedding) add_subdirectory(finetune) + add_subdirectory(gritlm) add_subdirectory(infill) add_subdirectory(llama-bench) add_subdirectory(llava) diff --git a/examples/gritlm/CMakeLists.txt b/examples/gritlm/CMakeLists.txt new file mode 100644 index 000000000..ac4a5ae79 --- /dev/null +++ b/examples/gritlm/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET gritlm) +add_executable(${TARGET} gritlm.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp new file mode 100644 index 000000000..9b75d4f82 --- /dev/null +++ b/examples/gritlm/gritlm.cpp @@ -0,0 +1,243 @@ +#include "common.h" +#include "llama.h" + +#include +#include + +// #define GRIT_DEBUG + +static float dot_product(const std::vector & v1, const std::vector & v2) { + float dot = 0.0f; + for (uint64_t i = 0; i < v1.size(); ++i) { + dot += v1[i] * v2[i]; + } + return dot; +} + +static float norm(const std::vector & v) { + return std::sqrt(dot_product(v, v)); +} + +static float cosine_similarity(const std::vector & v1, const std::vector & v2) { + return dot_product(v1, v2) / (norm(v1) * norm(v2)); +} + +static void normalize(const std::vector & in, float * out) { + float inorm = norm(in); + for (uint64_t i = 0; i < in.size(); i++) { + out[i] = in[i] / inorm; + } +} + +static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { + auto result = std::vector>{}; + + auto mdl = llama_get_model(ctx); + auto batch = llama_batch_init(llama_n_batch(ctx), 0, 1); + + for (uint64_t i = 0; i < sentences.size(); i++) { + llama_batch_clear(batch); + + std::string input_string = instruction + sentences[i]; + std::vector inputs = llama_tokenize(mdl, input_string, true, false); + auto n_toks = (int32_t)inputs.size(); + + // GritLM seems to have embed EOS = "" + // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 + // inputs.push_back(llama_token_eos(mdl)); + + // we want to ignore instruction tokens for mean pooling + std::vector inputs_instruct = llama_tokenize(mdl, instruction, true, false); + auto n_inst = (int32_t)inputs_instruct.size(); + +#ifdef GRIT_DEBUG + // debug tokens - should be matching as referenced in the GritLM sample + std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) { + std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str()); + }); + std::printf("\n"); +#endif + + // add input to batch (this increments n_tokens) + for (int32_t j = 0; j < n_toks; j++) { + llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); + } + + // clear previous kv_cache values (irrelevant for embeddings) + llama_kv_cache_clear(ctx); + + // run model + llama_decode(ctx, batch); + + // get embedding dimensions + uint64_t n_embd = llama_n_embd(mdl); + + // allocate embedding output + std::vector emb_unorm(n_embd, 0.0f); + + // sum up all token embeddings + for (int32_t k = n_inst; k < n_toks; k++) { + float * emb = llama_get_embeddings_ith(ctx, k); + for (uint64_t j = 0; j < n_embd; j++) { + emb_unorm[j] += emb[j]; + } + } + + // divide by number of tokens (mean pooling) + uint64_t n_sent = n_toks - n_inst; + for (uint64_t j = 0; j < n_embd; j++) { + emb_unorm[j] /= n_sent; + } + + auto emb_norm = std::vector(emb_unorm.size()); + normalize(emb_unorm, emb_norm.data()); + result.push_back(emb_norm); + +#ifdef GRIT_DEBUG + // print out emb_norm + std::printf("embedding %ld: ", i); + for (uint64_t j = 0; j < n_embd; j++) { + std::printf("%.5f ", emb_norm[j]); + } + std::printf("\n\n"); +#endif + } + + llama_batch_free(batch); + return result; +} + +static std::string aggregate_pieces(const std::vector & pieces) { + // calculate total length required + size_t length = 0; + for (const auto & str : pieces) { + length += str.size(); + } + + // reserve memory + std::string result; + result.reserve(length); + + // append pieces + for (const auto & str : pieces) { + result += str; + } + + return result; +} + +static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { + std::vector pieces; + + const llama_model * mdl = llama_get_model(ctx); + llama_token eos_token = llama_token_eos(mdl); + llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); + + std::vector inputs = llama_tokenize(mdl, prompt, false, true); + int32_t i_current_token = 0; + + while (true) { + llama_batch_clear(bat); + for (auto i = 0; i < inputs.size(); i++) { + llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == inputs.size() - 1); + } + inputs.clear(); + + llama_decode(ctx, bat); + auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); + + auto candidates = std::vector(llama_n_vocab(mdl)); + for (auto token = 0; token < candidates.size(); token++) { + candidates[token] = llama_token_data{ token, logits[token], 0.0f }; + } + auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false }; + + llama_token token = llama_sample_token_greedy(ctx, &candidates_p); + if (token == eos_token) { + break; + } + + std::string piece = llama_token_to_piece(ctx, token); + if (stream) { + std::printf("%s", piece.c_str()); + } + + pieces.push_back(piece); + inputs.push_back(token); + } + + llama_batch_free(bat); + + return aggregate_pieces(pieces); +} + +static std::string gritlm_instruction(const std::string & instruction) { + return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n"; +} + +int main(int argc, char * argv[]) +{ + gpt_params params; + if (!gpt_params_parse(argc, argv, params)) { + return 1; + } + + llama_model_params mparams = llama_model_params_from_gpt_params(params); + llama_context_params cparams = llama_context_params_from_gpt_params(params); + + llama_backend_init(); + + llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); + + // create new context - set to embedding mode + llama_context * embd_ctx = llama_new_context_with_model(mdl, cparams); + llama_set_embeddings(embd_ctx, true); + + // create new context - default mode is causal + llama_context * causal_ctx = llama_new_context_with_model(mdl, cparams); + + // samples taken from here: https://github.com/ContextualAI/gritlm#basic + // Embedding/Representation + { + std::string instruction = "Given a scientific paper title, retrieve the paper's abstract"; + + std::vector queries = { + "Bitcoin: A Peer-to-Peer Electronic Cash System", + "Generative Representational Instruction Tuning", + }; + + std::vector documents = { + "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.", + "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.", + }; + + // No need to add instruction for retrieval documents + std::vector> d_rep = encode(embd_ctx, documents, gritlm_instruction("")); + std::vector> q_rep = encode(embd_ctx, queries, gritlm_instruction(instruction)); + + float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]); + float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]); + float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]); + float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]); + + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0); + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1); + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0); + std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1); + } + + // Generation + // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction + { + const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; + std::string response = generate(causal_ctx, prompt, true); + } + + llama_free(embd_ctx); + llama_free(causal_ctx); + + llama_free_model(mdl); + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp b/llama.cpp index e9192b4fa..991e1e673 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1684,7 +1684,6 @@ struct llama_cparams { bool embeddings; bool offload_kqv; - enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; @@ -8030,7 +8029,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - if (hparams.causal_attn) { + GGML_ASSERT( + (hparams.causal_attn || cparams.embeddings) && + "non-causal attention with generative models is not supported" + ); + + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + // But if cparams.embeddings is set, the attention will be non-causal nonetheless. + if (!cparams.embeddings) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = batch.n_tokens; @@ -8055,8 +8061,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } else { - // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used) + // for models using the kv cache, the mask needs to match the kv cache size const int64_t n_tokens = batch.n_tokens; + const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); @@ -8075,7 +8082,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f; + data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; } } } @@ -13158,6 +13169,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback) ctx->abort_callback_data = abort_callback_data; } +void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { + ctx->cparams.embeddings = embeddings; +} + struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens, diff --git a/llama.h b/llama.h index 3dc162b07..0fe7b0105 100644 --- a/llama.h +++ b/llama.h @@ -641,6 +641,10 @@ extern "C" { // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + // Set whether to use causal attention or not + // If set to true, the model will only attend to the past tokens + LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); + // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);