diff --git a/Makefile b/Makefile
index 02ba3e36d..160fbc020 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam_search
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam_search fill-in-middle
 
 # Binaries only useful for tests
 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1
@@ -427,6 +427,9 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
 beam_search: examples/beam_search/beam_search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+fill-in-middle: examples/fill-in-middle/FIM.c ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
 BUILD_TARGETS += metal
 endif
diff --git a/examples/fill-in-middle/CMakeLists.txt b/examples/fill-in-middle/CMakeLists.txt
new file mode 100644
index 000000000..9150d7c1c
--- /dev/null
+++ b/examples/fill-in-middle/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET FIM)
+add_executable(${TARGET} FIM.c)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/fill-in-middle/FIM.c b/examples/fill-in-middle/FIM.c
index 5283b7f47..307a31e63 100644
--- a/examples/fill-in-middle/FIM.c
+++ b/examples/fill-in-middle/FIM.c
@@ -8,15 +8,26 @@
 For a quick summary of what's going on here, see issue #2818.
 */
 
-static inline struct llama_context* create_codellama_fim_context(const char* model_path) {
+static inline struct llama_context* codellama_create_fim_context(const char* model_path, const char** error_message) {
     struct llama_context_params params = llama_context_default_params();
     struct llama_model* model = llama_load_model_from_file(model_path, params);
+    if (!model) {
+        *error_message = "Failed to load model.";
+        return NULL;
+    }
+
     struct llama_context* context = llama_new_context_with_model(model, params);
+    if (!context) {
+        *error_message = "Failed to create context.";
+        llama_free_model(model);
+        return NULL;
+    }
+
     return context;
 }
 
 static inline char*
-codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, char** error_message) {
+codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, const char** error_message) {
     int num_tokens;
     llama_token* tokens_end = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
 
@@ -45,48 +56,60 @@ codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const ch
     }
 
     // Append middle token
-    int num_prompt_tokens = (int)(tokens_end - tokens);
     *tokens_end++ = llama_token_middle(ctx);
 
     // Evaluate the LM on the prompt.
-    if (llama_eval(ctx, tokens, num_prompt_tokens, 0, n_threads)) {
+    if (llama_eval(ctx, tokens, (int)(tokens_end - tokens), 0, n_threads)) {
         *error_message = "Failed to evaluate the prompt.";
         free(tokens);
         return NULL;
     }
 
     // Generate tokens until n_max_tokens or the <EOT> token is generated.
-    int num_generated_tokens = 0;
-    llama_token* generated_tokens = NULL;
+    llama_token* generated_tokens = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
+    size_t num_generated_tokens = 0;
+    int vocab_size = llama_n_vocab(ctx);
     for (size_t i = 0; i < n_max_tokens; i++) {
-        // Evaluate the LM for a single token
-        llama_token* current_token = generated_tokens + num_generated_tokens;
-        if (llama_eval(ctx, current_token, 1, num_generated_tokens, n_threads)) {
+        // Evaluate the most recent token to obtain the next logits; the prompt evaluation above supplies the first set.
+        if (i > 0 && llama_eval(ctx, &generated_tokens[num_generated_tokens - 1], 1, (int)(tokens_end - tokens) + (int)num_generated_tokens - 1, n_threads)) {
             *error_message = "Failed to evaluate the prompt.";
             free(tokens);
             break;
         }
+        float* logits = llama_get_logits(ctx);
 
-        if (*current_token == llama_token_eot(ctx)) {
+        // From the logits, select the most likely token (greedy sampling).
+        float highest_log_likelihood = logits[0];
+        llama_token likeliest_token = 0;
+        for (llama_token token_id = 1; token_id < vocab_size; token_id++) {
+            if (logits[token_id] > highest_log_likelihood) {
+                highest_log_likelihood = logits[token_id];
+                likeliest_token = token_id;
+            }
+        }
+
+        // Don't add the token if it's <EOT>.
+        if (likeliest_token == llama_token_eot(ctx)) {
             break;
         }
 
-        num_generated_tokens++;
+        // Append the token, so it's there for subsequent evaluations.
+        generated_tokens[num_generated_tokens++] = likeliest_token;
     }
 
     // Allocate memory for the final result
     size_t result_length = 0;
     size_t result_capacity = 4096;
-    char* result = (char*)malloc(sizeof(char) * 4096);
+    char* result = (char*)malloc(sizeof(char) * result_capacity);
     if (!result) {
         *error_message = "Failed to allocate memory for result.";
         free(tokens);
         return NULL;
     }
 
-    // Translate tokens to string
+    // Translate tokens to string, growing the allocation if it's too small.
     for (size_t i = 0; i < num_generated_tokens; i++) {
-        int appended = llama_token_to_str(ctx, generated_tokens[i], result, result_capacity - result_length);
+        int appended = llama_token_to_piece(ctx, generated_tokens[i], result + result_length, result_capacity - result_length);
         if (appended < 0) {
             // retry the token with a larger buffer
             i--;
@@ -116,20 +139,27 @@ int main(int argc, char** argv) {
         return 1;
     }
 
-    struct llama_context* ctx = create_codellama_fim_context(argv[1]);
-
+    char* model = argv[1];
+    char* prefix = argv[2];
+    char* suffix = argv[3];
     size_t n_max_tokens = atoi(argv[4]);
     int n_threads = atoi(argv[5]);
     bool spm = false;
 
-    char* error_message = NULL;
-    char* result = codellama_fill_in_middle(ctx, argv[2], argv[3], n_max_tokens, n_threads, spm, &error_message);
+    const char* error_message = NULL;
+    struct llama_context* ctx = codellama_create_fim_context(model, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }
 
-    printf("%s\n", result);
+    char* result = codellama_fill_in_middle(ctx, prefix, suffix, n_max_tokens, n_threads, spm, &error_message);
+    if (error_message) {
+        fprintf(stderr, "Error: %s\n", error_message);
+        return 1;
+    }
+
+    printf("%s%s%s\n", prefix, result, suffix);
 
     llama_free(ctx);
 }
diff --git a/llama.cpp b/llama.cpp
index 35960e7f6..b1dae5304 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -960,6 +960,7 @@ struct llama_vocab {
     id linefeed_id = 13;
 
     // codellama FIM special tokens
+    // TODO: load these from the vocabulary.
     id special_prefix_id = 32007;
     id special_middle_id = 32009;
     id special_suffix_id = 32008;