From 314c29cc32c1a57aecc11aeb36917d424e70f99f Mon Sep 17 00:00:00 2001
From: apaz-cli
Date: Sun, 3 Sep 2023 16:32:55 -0500
Subject: [PATCH] Debugging crash.

---
 examples/fill-in-middle/FIM.c | 34 +++++++++++++++++++++++++++-------
 1 file changed, 27 insertions(+), 7 deletions(-)

diff --git a/examples/fill-in-middle/FIM.c b/examples/fill-in-middle/FIM.c
index 307a31e63..3ea6d73c0 100644
--- a/examples/fill-in-middle/FIM.c
+++ b/examples/fill-in-middle/FIM.c
@@ -1,4 +1,5 @@
 #include <stdio.h>
+#include <string.h>
 #include <stdlib.h>
 #include "../../llama.h"
 
@@ -8,7 +9,8 @@
 
 For a quick summary of what's going on here, see issue #2818.
 */
-static inline struct llama_context* codellama_create_fim_context(const char* model_path, const char** error_message) {
+static inline struct llama_context*
+codellama_create_fim_context(const char* model_path, const char** error_message) {
     struct llama_context_params params = llama_context_default_params();
     struct llama_model* model = llama_load_model_from_file(model_path, params);
     if (!model) {
@@ -30,7 +32,9 @@
 static inline char*
 codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, const char** error_message) {
     int num_tokens;
-    llama_token* tokens_end = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
+    size_t combined_len = strlen(prefix) + strlen(suffix) + 3;
+    size_t initial_size = sizeof(llama_token) * combined_len;
+    llama_token* tokens_end = (llama_token*)malloc(initial_size);
     llama_token* tokens = tokens_end;
     if (!tokens) {
         *error_message = "Failed to allocate memory for tokens.";
@@ -58,15 +62,28 @@
     // Append middle token
     *tokens_end++ = llama_token_middle(ctx);
 
+    // Grow to accommodate the prompt and the max amount of generated tokens
+    size_t prompt_len = (size_t)(tokens_end - tokens);
+    size_t min_len = (prompt_len + n_max_tokens);
+    if (min_len > combined_len) {
+        llama_token* new_tokens = (llama_token*)realloc(tokens, sizeof(llama_token) * min_len);
+        if (!new_tokens) {
+            *error_message = "Failed to allocate memory for tokens.";
+            free(tokens);
+            return NULL;
+        }
+        tokens = new_tokens;
+    }
+
     // Evaluate the LM on the prompt.
-    if (llama_eval(ctx, tokens, (int)(tokens_end - tokens), 0, n_threads)) {
+    if (llama_eval(ctx, tokens, (int)prompt_len, 0, n_threads)) {
         *error_message = "Failed to evaluate the prompt.";
         free(tokens);
         return NULL;
     }
 
     // Generate tokens until n_max_tokens or the EOS token is generated.
-    llama_token* generated_tokens = NULL;
+    llama_token* generated_tokens = tokens + prompt_len;
     size_t num_generated_tokens = 0;
     int vocab_size = llama_n_vocab(ctx);
     for (size_t i = 0; i < n_max_tokens; i++) {
@@ -111,8 +128,7 @@
     for (size_t i = 0; i < num_generated_tokens; i++) {
         int appended = llama_token_to_piece(ctx, generated_tokens[i], result, result_capacity - result_length);
         if (appended < 0) {
-            // retry the token with a larger buffer
-            i--;
+            i--; // retry the token with a larger buffer
             size_t new_capacity = result_capacity * 2;
             char* new_result = (char*)realloc(result, sizeof(char) * new_capacity);
             if (!new_result) {
@@ -142,24 +158,28 @@
 int main(int argc, char** argv) {
     char* model = argv[1];
     char* prefix = argv[2];
     char* suffix = argv[3];
-    size_t n_max_tokens = atoi(argv[4]);
+    size_t n_max_tokens = atoi(argv[4]) > 0 ? atoi(argv[4]) : 64;
     int n_threads = atoi(argv[5]);
     bool spm = false;
     const char* error_message = NULL;
 
+    puts("Loading the model. This could take quite a while...");
     struct llama_context* ctx = codellama_create_fim_context(model, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }
 
+    puts("Model loaded. Generating text...");
     char* result = codellama_fill_in_middle(ctx, prefix, suffix, n_max_tokens, n_threads, spm, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }
 
+    puts("Generated text:");
     printf("%s%s%s\n", prefix, result, suffix);
+    free(result);
    llama_free(ctx);
 }
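
A note on the fix, for anyone reviewing (commentary, not part of the commit): the crash came from sizing the token buffer by n_max_tokens, which says nothing about how long the tokenized prefix and suffix actually are. The patch instead sizes it by strlen(prefix) + strlen(suffix) + 3, on the assumption that the tokenizer never emits more tokens than input bytes, with three extra slots for the special FIM tokens, and then grows the buffer once the true prompt length is known so generated tokens can be appended in place. Below is that allocation strategy in isolation as a compilable sketch; toy_tokenize and the negative sentinel ids are hypothetical stand-ins for llama_tokenize and the FIM token ids, and nothing here calls llama.cpp:

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>

    typedef int token;

    /* Hypothetical stand-in for llama_tokenize. It emits one token per
     * input byte, mirroring the worst case the patch assumes: a
     * byte-fallback tokenizer never produces more tokens than bytes. */
    static size_t toy_tokenize(const char* text, token* out) {
        size_t n = strlen(text);
        for (size_t i = 0; i < n; i++) out[i] = (token)(unsigned char)text[i];
        return n;
    }

    int main(void) {
        const char* prefix = "int add(int a, int b) {\n    return ";
        const char* suffix = ";\n}\n";
        size_t n_max_tokens = 64;

        /* 1. Size the buffer from the inputs: at most one token per byte,
         *    plus 3 slots for the prefix/suffix/middle sentinel tokens. */
        size_t combined_len = strlen(prefix) + strlen(suffix) + 3;
        token* tokens = (token*)malloc(sizeof(token) * combined_len);
        if (!tokens) return 1;

        /* 2. Build the prompt; negative ids stand in for the sentinels. */
        token* end = tokens;
        *end++ = -1;                      /* pretend <PRE> */
        end += toy_tokenize(prefix, end);
        *end++ = -2;                      /* pretend <SUF> */
        end += toy_tokenize(suffix, end);
        *end++ = -3;                      /* pretend <MID> */

        /* 3. The real prompt length is now known; grow the buffer once so
         *    up to n_max_tokens generated tokens fit after the prompt. */
        size_t prompt_len = (size_t)(end - tokens);
        size_t min_len = prompt_len + n_max_tokens;
        if (min_len > combined_len) {
            token* grown = (token*)realloc(tokens, sizeof(token) * min_len);
            if (!grown) { free(tokens); return 1; }
            tokens = grown;  /* any old `end` pointer is stale after this */
        }

        printf("prompt tokens: %zu, buffer capacity: %zu\n", prompt_len, min_len);
        free(tokens);
        return 0;
    }

The one-token-per-byte bound holds for byte-fallback tokenizers like CodeLlama's, but it is an assumption the real code might want to assert.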
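Separately, one context line in the @@ -111,8 +128,7 @@ hunk may deserve a second look: llama_token_to_piece is called with `result` as the destination rather than `result + result_length`, so unless the pointer is advanced somewhere outside the visible hunk, each piece overwrites the previous one. For comparison, a sketch of the grow-and-retry append as I would expect it, where piece_fn is a hypothetical stand-in for the llama_token_to_piece contract assumed here (bytes written on success, a negative value when the buffer is too small):

    #include <stdlib.h>

    /* Append pieces to a growable string, retrying the same token with a
     * doubled buffer whenever a piece does not fit. */
    static char* detokenize(const int* toks, size_t n,
                            int (*piece_fn)(int tok, char* dst, int cap)) {
        size_t cap = 16, len = 0;
        char* out = (char*)malloc(cap);
        if (!out) return NULL;
        for (size_t i = 0; i < n; i++) {
            /* Write at the end of what has been built so far. */
            int wrote = piece_fn(toks[i], out + len, (int)(cap - len));
            if (wrote < 0) {
                /* Did not fit: double the buffer and retry this token.
                 * The unsigned wrap of i-- at i == 0 is well defined and
                 * is undone by the loop's i++. */
                char* grown = (char*)realloc(out, cap * 2);
                if (!grown) { free(out); return NULL; }
                out = grown;
                cap *= 2;
                i--;
                continue;
            }
            len += (size_t)wrote;
        }
        if (len == cap) { /* make room for the terminator if exactly full */
            char* grown = (char*)realloc(out, cap + 1);
            if (!grown) { free(out); return NULL; }
            out = grown;
        }
        out[len] = '\0';
        return out;
    }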
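For reproducing the crash: main consumes five positional arguments and never checks argc, so all five must be supplied or the program indexes past the argument vector. Assuming the example builds to a binary named FIM (the model path below is made up), an invocation looks like:

    ./FIM ./models/codellama-7b.gguf "int add(int a, int b) {" "}" 64 8

argv[4] falls back to 64 when it is not a positive number; argv[5] (the thread count) has no such guard.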