Debugging crash: size the token buffer from the tokenized prompt and grow it to hold the generated tokens.
parent 142d79b459
commit 314c29cc32
1 changed file with 27 additions and 7 deletions
@@ -1,4 +1,5 @@
 #include <stdlib.h>
 #include <string.h>
 #include <stdio.h>
 #include "../../llama.h"
 
@@ -8,7 +9,8 @@ For a quick summary of what's going on here, see issue #2818.
 */
 
 
-static inline struct llama_context* codellama_create_fim_context(const char* model_path, const char** error_message) {
+static inline struct llama_context*
+codellama_create_fim_context(const char* model_path, const char** error_message) {
     struct llama_context_params params = llama_context_default_params();
     struct llama_model* model = llama_load_model_from_file(model_path, params);
     if (!model) {
@@ -30,7 +32,9 @@ static inline char*
 codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, const char** error_message) {
 
     int num_tokens;
-    llama_token* tokens_end = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
+    size_t combined_len = strlen(prefix) + strlen(suffix) + 3;
+    size_t initial_size = sizeof(llama_token) * combined_len;
+    llama_token* tokens_end = (llama_token*)malloc(initial_size);
     llama_token* tokens = tokens_end;
     if (!tokens) {
         *error_message = "Failed to allocate memory for tokens.";
@@ -58,15 +62,28 @@ codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const ch
     // Append middle token
     *tokens_end++ = llama_token_middle(ctx);
 
+    // Grow to accommodate the prompt and the max amount of generated tokens
+    size_t prompt_len = (size_t)(tokens_end - tokens);
+    size_t min_len = (prompt_len + n_max_tokens);
+    if (min_len > combined_len) {
+        llama_token* new_tokens = (llama_token*)realloc(tokens, sizeof(llama_token) * min_len);
+        if (!new_tokens) {
+            *error_message = "Failed to allocate memory for tokens.";
+            free(tokens);
+            return NULL;
+        }
+        tokens = new_tokens;
+    }
+
     // Evaluate the LM on the prompt.
-    if (llama_eval(ctx, tokens, (int)(tokens_end - tokens), 0, n_threads)) {
+    if (llama_eval(ctx, tokens, prompt_len, 0, n_threads)) {
         *error_message = "Failed to evaluate the prompt.";
         free(tokens);
         return NULL;
     }
 
     // Generate tokens until n_max_tokens or the <EOT> token is generated.
-    llama_token* generated_tokens = NULL;
+    llama_token* generated_tokens = tokens + prompt_len;
     size_t num_generated_tokens = 0;
     int vocab_size = llama_n_vocab(ctx);
     for (size_t i = 0; i < n_max_tokens; i++) {
@@ -111,8 +128,7 @@ codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const ch
     for (size_t i = 0; i < num_generated_tokens; i++) {
         int appended = llama_token_to_piece(ctx, generated_tokens[i], result, result_capacity - result_length);
         if (appended < 0) {
-            // retry the token with a larger buffer
-            i--;
+            i--; // retry the token with a larger buffer
             size_t new_capacity = result_capacity * 2;
             char* new_result = (char*)realloc(result, sizeof(char) * new_capacity);
             if (!new_result) {
@@ -142,24 +158,28 @@ int main(int argc, char** argv) {
     char* model = argv[1];
     char* prefix = argv[2];
     char* suffix = argv[3];
-    size_t n_max_tokens = atoi(argv[4]);
+    size_t n_max_tokens = atoi(argv[4]) > 0 ? atoi(argv[4]) : 64;
     int n_threads = atoi(argv[5]);
     bool spm = false;
     const char* error_message = NULL;
 
     puts("Loading the model. This could take quite a while...");
     struct llama_context* ctx = codellama_create_fim_context(model, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }
 
     puts("Model loaded. Generating text...");
     char* result = codellama_fill_in_middle(ctx, prefix, suffix, n_max_tokens, n_threads, spm, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }
 
     puts("Generated text:");
     printf("%s%s%s\n", prefix, result, suffix);
 
     free(result);
     llama_free(ctx);
 }
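
For context, the crash being chased here most likely comes from the original allocation: the token buffer was sized by n_max_tokens alone, even though the tokenized prefix, the suffix, and the three FIM special tokens all have to fit in it before generation starts. Below is a minimal, self-contained sketch of the size-then-grow pattern the diff adopts; it uses plain C with a stand-in token_t type and a made-up prompt length instead of the real llama.cpp calls.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef int token_t; /* stand-in for llama_token */

int main(void) {
    const char* prefix = "int add(int a, int b) {\n    return ";
    const char* suffix = ";\n}\n";
    size_t n_max_tokens = 64;

    /* Worst case for the prompt: one token per byte of prefix and suffix,
       plus the three FIM special tokens (prefix / suffix / middle). */
    size_t combined_len = strlen(prefix) + strlen(suffix) + 3;
    token_t* tokens = malloc(sizeof(token_t) * combined_len);
    if (!tokens) return 1;

    /* Pretend tokenization consumed part of the buffer. */
    size_t prompt_len = combined_len / 2;

    /* Grow once so the prompt and every generated token share one buffer;
       writing generated tokens past the initial allocation is the suspected overflow. */
    size_t min_len = prompt_len + n_max_tokens;
    if (min_len > combined_len) {
        token_t* grown = realloc(tokens, sizeof(token_t) * min_len);
        if (!grown) { free(tokens); return 1; }
        tokens = grown;
    }

    printf("prompt tokens: %zu, buffer capacity: %zu\n", prompt_len, min_len);
    free(tokens);
    return 0;
}

The invariant to preserve is that the buffer always has room for prompt_len + n_max_tokens entries before the generation loop writes past tokens_end.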
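The change in the detokenization loop only folds the retry comment onto the i--; line, but the pattern it keeps is easy to get wrong: when converting a token to text signals that the output buffer is too small (a negative return, as the diff's llama_token_to_piece check shows), the loop steps the index back, doubles the buffer, and retries the same token. A rough illustration, with a hypothetical piece_to_buf() standing in for llama_token_to_piece():

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical stand-in: copies one piece into buf, or returns a negative
   value when the remaining capacity is too small. */
static int piece_to_buf(const char* piece, char* buf, size_t remaining) {
    size_t need = strlen(piece);
    if (need > remaining) return -(int)need;
    memcpy(buf, piece, need);
    return (int)need;
}

int main(void) {
    const char* pieces[] = { "fill", "-", "in", "-", "the", "-", "middle" };
    size_t num_pieces = sizeof(pieces) / sizeof(pieces[0]);

    size_t capacity = 4; /* deliberately small to force retries */
    size_t length = 0;
    char* result = malloc(capacity);
    if (!result) return 1;

    for (size_t i = 0; i < num_pieces; i++) {
        int appended = piece_to_buf(pieces[i], result + length, capacity - length);
        if (appended < 0) {
            i--; /* retry the same piece with a larger buffer */
            char* grown = realloc(result, capacity * 2);
            if (!grown) { free(result); return 1; }
            result = grown;
            capacity *= 2;
            continue;
        }
        length += (size_t)appended;
    }

    printf("%.*s\n", (int)length, result);
    free(result);
    return 0;
}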