Debugging crash.
parent 142d79b459
commit 314c29cc32
1 changed file with 27 additions and 7 deletions
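Reading of the diff below (the rest of the file is not shown here, so this is an inference from these hunks, not a statement of the author's intent): the token buffer used to be sized for n_max_tokens only, while the tokenized prefix, suffix, and fill-in-middle sentinel tokens also had to fit in it, and generated_tokens started out as NULL; either is enough to corrupt memory or crash. The change sizes the buffer from the combined prompt length up front, grows it to hold the prompt plus up to n_max_tokens generated tokens before evaluation, and points generated_tokens at the tail of that same allocation. A minimal sketch of that sizing logic, using plain int in place of llama_token and illustrative helper names (token_capacity_for_prompt and grow_token_buffer are not part of llama.cpp):

#include <stdlib.h>
#include <string.h>

/* Upper bound on the number of prompt tokens, mirroring the diff's
 * strlen(prefix) + strlen(suffix) + 3: the apparent assumption is at most
 * one token per input byte, plus three fill-in-middle sentinel tokens. */
static size_t token_capacity_for_prompt(const char* prefix, const char* suffix) {
    return strlen(prefix) + strlen(suffix) + 3;
}

/* Once the real prompt length is known, the same allocation must also hold
 * up to n_max_tokens generated tokens. The caller keeps the old pointer so
 * it can still be freed if realloc fails, exactly as the diff does. */
static int* grow_token_buffer(int* tokens, size_t prompt_len, size_t n_max_tokens) {
    return (int*)realloc(tokens, sizeof(int) * (prompt_len + n_max_tokens));
}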
@@ -1,4 +1,5 @@
 #include <stdlib.h>
+#include <string.h>
 #include <stdio.h>
 #include "../../llama.h"
 
@@ -8,7 +9,8 @@ For a quick summary of what's going on here, see issue #2818.
 */
 
 
-static inline struct llama_context* codellama_create_fim_context(const char* model_path, const char** error_message) {
+static inline struct llama_context*
+codellama_create_fim_context(const char* model_path, const char** error_message) {
     struct llama_context_params params = llama_context_default_params();
     struct llama_model* model = llama_load_model_from_file(model_path, params);
     if (!model) {
@@ -30,7 +32,9 @@ static inline char*
 codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, const char** error_message) {
 
     int num_tokens;
-    llama_token* tokens_end = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
+    size_t combined_len = strlen(prefix) + strlen(suffix) + 3;
+    size_t initial_size = sizeof(llama_token) * combined_len;
+    llama_token* tokens_end = (llama_token*)malloc(initial_size);
     llama_token* tokens = tokens_end;
     if (!tokens) {
         *error_message = "Failed to allocate memory for tokens.";
@@ -58,15 +62,28 @@ codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const ch
     // Append middle token
     *tokens_end++ = llama_token_middle(ctx);
 
+    // Grow to accommodate the prompt and the max amount of generated tokens
+    size_t prompt_len = (size_t)(tokens_end - tokens);
+    size_t min_len = (prompt_len + n_max_tokens);
+    if (min_len > combined_len) {
+        llama_token* new_tokens = (llama_token*)realloc(tokens, sizeof(llama_token) * min_len);
+        if (!new_tokens) {
+            *error_message = "Failed to allocate memory for tokens.";
+            free(tokens);
+            return NULL;
+        }
+        tokens = new_tokens;
+    }
+
     // Evaluate the LM on the prompt.
-    if (llama_eval(ctx, tokens, (int)(tokens_end - tokens), 0, n_threads)) {
+    if (llama_eval(ctx, tokens, prompt_len, 0, n_threads)) {
         *error_message = "Failed to evaluate the prompt.";
         free(tokens);
         return NULL;
     }
 
     // Generate tokens until n_max_tokens or the <EOT> token is generated.
-    llama_token* generated_tokens = NULL;
+    llama_token* generated_tokens = tokens + prompt_len;
     size_t num_generated_tokens = 0;
     int vocab_size = llama_n_vocab(ctx);
     for (size_t i = 0; i < n_max_tokens; i++) {
@@ -111,8 +128,7 @@ codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const ch
     for (size_t i = 0; i < num_generated_tokens; i++) {
         int appended = llama_token_to_piece(ctx, generated_tokens[i], result, result_capacity - result_length);
         if (appended < 0) {
-            // retry the token with a larger buffer
-            i--;
+            i--; // retry the token with a larger buffer
            size_t new_capacity = result_capacity * 2;
             char* new_result = (char*)realloc(result, sizeof(char) * new_capacity);
             if (!new_result) {
@@ -142,24 +158,28 @@ int main(int argc, char** argv) {
     char* model = argv[1];
     char* prefix = argv[2];
     char* suffix = argv[3];
-    size_t n_max_tokens = atoi(argv[4]);
+    size_t n_max_tokens = atoi(argv[4]) > 0 ? atoi(argv[4]) : 64;
     int n_threads = atoi(argv[5]);
     bool spm = false;
     const char* error_message = NULL;
 
+    puts("Loading the model. This could take quite a while...");
     struct llama_context* ctx = codellama_create_fim_context(model, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }
 
+    puts("Model loaded. Generating text...");
    char* result = codellama_fill_in_middle(ctx, prefix, suffix, n_max_tokens, n_threads, spm, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }
 
+    puts("Generated text:");
     printf("%s%s%s\n", prefix, result, suffix);
 
+    free(result);
     llama_free(ctx);
 }
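After the grow step, the prompt and everything generated afterwards share one allocation, which is what allows generated_tokens to point into the buffer instead of being NULL. A sketch of the layout implied by the diff (the exact ordering of prefix, suffix, and sentinel tokens inside the prompt is not visible in these hunks):

    tokens                                 tokens + prompt_len
    |<------------ prompt_len ------------>|<---- up to n_max_tokens ---->|
    [ prompt tokens (prefix, suffix, FIM sentinels) | generated tokens ... ]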