Added makefile, better error messages

apaz-cli 2023-08-30 10:59:52 -05:00
parent 90afd6dfad
commit 828a43d2b3
4 changed files with 58 additions and 19 deletions

Makefile

@@ -1,5 +1,5 @@
# Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam_search
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam_search fill-in-middle
# Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1
@@ -427,6 +427,9 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
beam_search: examples/beam_search/beam_search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
fill-in-middle: examples/fill-in-middle/FIM.c ggml.o llama.o common.o $(OBJS)
	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
BUILD_TARGETS += metal
endif

examples/fill-in-middle/CMakeLists.txt (new file)

@@ -0,0 +1,5 @@
set(TARGET FIM)
add_executable(${TARGET} FIM.c)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/fill-in-middle/FIM.c

@@ -8,15 +8,26 @@ For a quick summary of what's going on here, see issue #2818.
*/
static inline struct llama_context* create_codellama_fim_context(const char* model_path) {
static inline struct llama_context* codellama_create_fim_context(const char* model_path, const char** error_message) {
    struct llama_context_params params = llama_context_default_params();
    struct llama_model* model = llama_load_model_from_file(model_path, params);
    if (!model) {
        *error_message = "Failed to load model.";
        return NULL;
    }
    struct llama_context* context = llama_new_context_with_model(model, params);
    if (!context) {
        *error_message = "Failed to create context.";
        llama_free_model(model);
        return NULL;
    }
    return context;
}
static inline char*
codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, char** error_message) {
codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, const char** error_message) {
    int num_tokens;
    llama_token* tokens_end = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
@@ -45,48 +56,60 @@ codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const ch
    }
    // Append middle token
    int num_prompt_tokens = (int)(tokens_end - tokens);
    *tokens_end++ = llama_token_middle(ctx);
    // Evaluate the LM on the prompt.
    if (llama_eval(ctx, tokens, num_prompt_tokens, 0, n_threads)) {
    if (llama_eval(ctx, tokens, (int)(tokens_end - tokens), 0, n_threads)) {
        *error_message = "Failed to evaluate the prompt.";
        free(tokens);
        return NULL;
    }
    // Generate tokens until n_max_tokens or the <EOT> token is generated.
    int num_generated_tokens = 0;
    llama_token* generated_tokens = NULL;
    size_t num_generated_tokens = 0;
    int vocab_size = llama_n_vocab(ctx);
    for (size_t i = 0; i < n_max_tokens; i++) {
        // Evaluate the LM for a single token
        llama_token* current_token = generated_tokens + num_generated_tokens;
        if (llama_eval(ctx, current_token, 1, num_generated_tokens, n_threads)) {
        // Evaluate the LM for a single token, obtaining the logits and probabilities.
        if (llama_eval(ctx, &generated_tokens[num_generated_tokens], 1, (int)num_generated_tokens, n_threads)) {
            *error_message = "Failed to evaluate the prompt.";
            free(tokens);
            break;
        }
        float* logits = llama_get_logits(ctx);
        if (*current_token == llama_token_eot(ctx)) {
        // From the logits, select the most likely token.
        float highest_log_likelihood = -1;
        llama_token likeliest_token = -1;
        for (llama_token token_id = 0; token_id < vocab_size; token_id++) {
            if (logits[token_id] > highest_log_likelihood) {
                highest_log_likelihood = logits[token_id];
                likeliest_token = token_id;
            }
        }
        // Don't add the token if it's <EOT>.
        if (likeliest_token == llama_token_eot(ctx)) {
            break;
        }
        num_generated_tokens++;
        // Append the token, so it's there for subsequent evaluations.
        generated_tokens[num_generated_tokens++] = likeliest_token;
    }
    // Allocate memory for the final result
    size_t result_length = 0;
    size_t result_capacity = 4096;
    char* result = (char*)malloc(sizeof(char) * 4096);
    char* result = (char*)malloc(sizeof(char) * result_capacity);
    if (!result) {
        *error_message = "Failed to allocate memory for result.";
        free(tokens);
        return NULL;
    }
    // Translate tokens to string
    // Translate tokens to string, growing the allocation if it's too small.
    for (size_t i = 0; i < num_generated_tokens; i++) {
        int appended = llama_token_to_str(ctx, generated_tokens[i], result, result_capacity - result_length);
        int appended = llama_token_to_piece(ctx, generated_tokens[i], result, result_capacity - result_length);
        if (appended < 0) {
            // retry the token with a larger buffer
            i--;
@@ -116,20 +139,27 @@ int main(int argc, char** argv) {
        return 1;
    }
    struct llama_context* ctx = create_codellama_fim_context(argv[1]);
    char* model = argv[1];
    char* prefix = argv[2];
    char* suffix = argv[3];
    size_t n_max_tokens = atoi(argv[4]);
    int n_threads = atoi(argv[5]);
    bool spm = false;
    char* error_message = NULL;
    char* result = codellama_fill_in_middle(ctx, argv[2], argv[3], n_max_tokens, n_threads, spm, &error_message);
    const char* error_message = NULL;
    struct llama_context* ctx = codellama_create_fim_context(model, &error_message);
    if (error_message) {
        fprintf(stderr, "Error: %s\n", error_message);
        return 1;
    }
    printf("%s\n", result);
    char* result = codellama_fill_in_middle(ctx, prefix, suffix, n_max_tokens, n_threads, spm, &error_message);
    if (error_message) {
        fprintf(stderr, "Error: %s\n", error_message);
        return 1;
    }
    printf("%s%s%s\n", prefix, result, suffix);
    llama_free(ctx);
}
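
For reference, here is a compact sketch (not part of the commit) of the greedy decode-and-detokenize loop this file implements, written against the same llama_eval-era calls used above (llama_eval, llama_get_logits, llama_token_eot, llama_token_to_piece). It makes two things explicit that the hunks above leave implicit or cut off: the buffer that holds generated tokens is allocated up front, and n_past for each single-token eval counts the prompt plus everything generated so far. The helper name generate_middle_greedy and its exact signature are illustrative only, and error checks on realloc are omitted for brevity.

// Sketch, not from the commit: greedy continuation after the FIM prompt has
// already been tokenized into `tokens` (num_prompt_tokens entries, <MID> last)
// and `ctx` has been created as in codellama_create_fim_context() above.
#include <float.h>
#include <stdlib.h>
#include "llama.h"

static char* generate_middle_greedy(struct llama_context* ctx,
                                    const llama_token* tokens, int num_prompt_tokens,
                                    size_t n_max_tokens, int n_threads,
                                    const char** error_message) {
    // Evaluate the whole prompt once; its logits drive the first sampled token.
    if (llama_eval(ctx, tokens, num_prompt_tokens, 0, n_threads)) {
        *error_message = "Failed to evaluate the prompt.";
        return NULL;
    }

    // Storage for generated tokens (allocated here, unlike the NULL pointer above).
    llama_token* generated = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
    if (!generated) {
        *error_message = "Failed to allocate the token buffer.";
        return NULL;
    }
    size_t num_generated = 0;
    const int vocab_size = llama_n_vocab(ctx);

    for (size_t i = 0; i < n_max_tokens; i++) {
        // Greedy pick: the highest logit of the last evaluated position.
        // Start from -FLT_MAX because logits can be arbitrarily negative.
        const float* logits = llama_get_logits(ctx);
        llama_token best = 0;
        float best_logit = -FLT_MAX;
        for (llama_token id = 0; id < vocab_size; id++) {
            if (logits[id] > best_logit) {
                best_logit = logits[id];
                best = id;
            }
        }
        if (best == llama_token_eot(ctx)) {
            break;  // <EOT> ends the infilled middle.
        }
        generated[num_generated++] = best;

        // Feed the new token back in; n_past = prompt + previously generated tokens.
        int n_past = num_prompt_tokens + (int)num_generated - 1;
        if (llama_eval(ctx, &generated[num_generated - 1], 1, n_past, n_threads)) {
            *error_message = "Failed to evaluate a generated token.";
            break;
        }
    }

    // Detokenize into a growable buffer: a negative return from
    // llama_token_to_piece means it needs more room, so grow and retry.
    size_t len = 0, cap = 4096;
    char* result = (char*)malloc(cap);
    if (!result) {
        *error_message = "Failed to allocate memory for result.";
        free(generated);
        return NULL;
    }
    for (size_t i = 0; i < num_generated; i++) {
        int appended = llama_token_to_piece(ctx, generated[i], result + len, (int)(cap - len));
        if (appended < 0) {
            cap *= 2;
            result = (char*)realloc(result, cap);
            i--;  // retry the same token with a larger buffer
            continue;
        }
        len += (size_t)appended;
    }
    result = (char*)realloc(result, len + 1);  // make room for the terminator
    result[len] = '\0';
    free(generated);
    return result;
}

Evaluating the prompt once and then feeding tokens back one at a time keeps the KV cache consistent, which is why n_past has to include the prompt length rather than restarting from zero.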

llama.cpp

@@ -960,6 +960,7 @@ struct llama_vocab {
    id linefeed_id = 13;
    // codellama FIM special tokens
    // TODO: load these from the vocabulary.
    id special_prefix_id = 32007;
    id special_middle_id = 32009;
    id special_suffix_id = 32008;
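
The three ids above are the <PRE>/<SUF>/<MID> markers the fill-in-middle example depends on. As a pointer to how they are meant to be arranged, the sketch below (not part of the commit) lays out already-tokenized prefix and suffix pieces in the two orderings the spm flag in FIM.c chooses between, following the prompt layouts described for Code Llama infilling (PSM: <PRE> prefix <SUF> suffix <MID>; SPM: <PRE> <SUF> suffix <MID> prefix). It assumes accessors llama_token_prefix() and llama_token_suffix() analogous to the llama_token_middle() and llama_token_eot() calls used above; if this branch only exposes the raw ids, 32007/32008/32009 can be substituted directly.

// Sketch, not from the commit: arrange pre-tokenized prefix/suffix pieces into
// a FIM prompt. `out` must have room for n_prefix + n_suffix + 3 tokens.
// llama_token_prefix()/llama_token_suffix() are assumed accessors mirroring
// llama_token_middle() above; swap in the raw ids if they are not available.
#include <stdbool.h>
#include "llama.h"

static int build_fim_prompt(struct llama_context* ctx,
                            const llama_token* prefix_toks, int n_prefix,
                            const llama_token* suffix_toks, int n_suffix,
                            bool spm, llama_token* out) {
    int n = 0;
    out[n++] = llama_token_prefix(ctx);                  // <PRE>, id 32007 above
    if (!spm) {
        // PSM ordering: <PRE> prefix <SUF> suffix <MID>
        for (int i = 0; i < n_prefix; i++) out[n++] = prefix_toks[i];
        out[n++] = llama_token_suffix(ctx);              // <SUF>, id 32008 above
        for (int i = 0; i < n_suffix; i++) out[n++] = suffix_toks[i];
        out[n++] = llama_token_middle(ctx);              // <MID>, id 32009 above
    } else {
        // SPM ordering: <PRE> <SUF> suffix <MID> prefix
        out[n++] = llama_token_suffix(ctx);
        for (int i = 0; i < n_suffix; i++) out[n++] = suffix_toks[i];
        out[n++] = llama_token_middle(ctx);
        for (int i = 0; i < n_prefix; i++) out[n++] = prefix_toks[i];
    }
    return n;
}

Generation then continues from the <MID> position until the model emits <EOT>, which is the stop condition checked with llama_token_eot() in the loop above.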