Added makefile, better error messages
parent 90afd6dfad
commit 828a43d2b3

4 changed files with 58 additions and 19 deletions
Makefile (5 changes)

@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam_search
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam_search fill-in-middle

 # Binaries only useful for tests
 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1

@@ -427,6 +427,9 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
 beam_search: examples/beam_search/beam_search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

+fill-in-middle: examples/fill-in-middle/FIM.c ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
 BUILD_TARGETS += metal
 endif
examples/fill-in-middle/CMakeLists.txt (new file, 5 additions)

@@ -0,0 +1,5 @@
+set(TARGET FIM)
+add_executable(${TARGET} FIM.c)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
examples/fill-in-middle/FIM.c

@@ -8,15 +8,26 @@ For a quick summary of what's going on here, see issue #2818.
 */

-static inline struct llama_context* create_codellama_fim_context(const char* model_path) {
+static inline struct llama_context* codellama_create_fim_context(const char* model_path, const char** error_message) {
     struct llama_context_params params = llama_context_default_params();
     struct llama_model* model = llama_load_model_from_file(model_path, params);
     if (!model) {
         *error_message = "Failed to load model.";
         return NULL;
     }

     struct llama_context* context = llama_new_context_with_model(model, params);
     if (!context) {
         *error_message = "Failed to create context.";
         llama_free_model(model);
         return NULL;
     }

     return context;
 }

 static inline char*
-codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, char** error_message) {
+codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const char* suffix, size_t n_max_tokens, int n_threads, bool spm, const char** error_message) {

     int num_tokens;
     llama_token* tokens_end = (llama_token*)malloc(sizeof(llama_token) * n_max_tokens);
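The tokenization of the prefix and suffix falls between this hunk and the next, so it is not visible in the diff. For orientation only, here is a minimal sketch of the conventional CodeLlama PSM infill layout (<PRE> prefix <SUF> suffix <MID>), using the llama_token_middle() call seen elsewhere in this file, the llama_tokenize() signature of this period, and the raw special-token IDs declared in the llama.cpp hunk at the bottom of this commit. The actual ordering, bounds checks, and the SPM variant selected by the spm flag in FIM.c may differ.

    #include "llama.h"

    // Sketch only: build a PSM-ordered infill prompt. IDs 32007/32008 are the
    // special_prefix_id/special_suffix_id values from the llama_vocab hunk
    // below; error handling and the BOS token are omitted for brevity.
    static int sketch_build_fim_prompt(struct llama_context* ctx,
                                       const char* prefix, const char* suffix,
                                       llama_token* tokens, int n_max) {
        int n = 0;
        tokens[n++] = 32007;                                             // <PRE>
        n += llama_tokenize(ctx, prefix, tokens + n, n_max - n, false);  // prefix text
        tokens[n++] = 32008;                                             // <SUF>
        n += llama_tokenize(ctx, suffix, tokens + n, n_max - n, false);  // suffix text
        tokens[n++] = llama_token_middle(ctx);                           // <MID>: generation continues from here
        return n;                                                        // prompt length to feed to llama_eval()
    }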
@@ -45,48 +56,60 @@ codellama_fill_in_middle(struct llama_context* ctx, const char* prefix, const ch
     }

     // Append middle token
-    int num_prompt_tokens = (int)(tokens_end - tokens);
     *tokens_end++ = llama_token_middle(ctx);

     // Evaluate the LM on the prompt.
-    if (llama_eval(ctx, tokens, num_prompt_tokens, 0, n_threads)) {
+    if (llama_eval(ctx, tokens, (int)(tokens_end - tokens), 0, n_threads)) {
         *error_message = "Failed to evaluate the prompt.";
         free(tokens);
         return NULL;
     }

     // Generate tokens until n_max_tokens or the <EOT> token is generated.
-    int num_generated_tokens = 0;
     llama_token* generated_tokens = NULL;
+    size_t num_generated_tokens = 0;
     int vocab_size = llama_n_vocab(ctx);
     for (size_t i = 0; i < n_max_tokens; i++) {
-        // Evaluate the LM for a single token
-        llama_token* current_token = generated_tokens + num_generated_tokens;
-        if (llama_eval(ctx, current_token, 1, num_generated_tokens, n_threads)) {
+        // Evaluate the LM for a single token, obtaining the logits and probabilities.
+        if (llama_eval(ctx, &generated_tokens[num_generated_tokens], 1, (int)num_generated_tokens, n_threads)) {
             *error_message = "Failed to evaluate the prompt.";
             free(tokens);
             break;
         }
         float* logits = llama_get_logits(ctx);

-        if (*current_token == llama_token_eot(ctx)) {
+        // From the logits, select the most likely token.
+        float highest_log_likelihood = -1;
+        llama_token likeliest_token = -1;
+        for (llama_token token_id = 0; token_id < vocab_size; token_id++) {
+            if (logits[token_id] > highest_log_likelihood) {
+                highest_log_likelihood = logits[token_id];
+                likeliest_token = token_id;
+            }
+        }
+
+        // Don't add the token if it's <EOT>.
+        if (likeliest_token == llama_token_eot(ctx)) {
             break;
         }

-        num_generated_tokens++;
+        // Append the token, so it's there for subsequent evaluations.
+        generated_tokens[num_generated_tokens++] = likeliest_token;
     }

     // Allocate memory for the final result
     size_t result_length = 0;
     size_t result_capacity = 4096;
-    char* result = (char*)malloc(sizeof(char) * 4096);
+    char* result = (char*)malloc(sizeof(char) * result_capacity);
     if (!result) {
         *error_message = "Failed to allocate memory for result.";
         free(tokens);
         return NULL;
     }

-    // Translate tokens to string
+    // Translate tokens to string, growing the allocation if it's too small.
     for (size_t i = 0; i < num_generated_tokens; i++) {
-        int appended = llama_token_to_str(ctx, generated_tokens[i], result, result_capacity - result_length);
+        int appended = llama_token_to_piece(ctx, generated_tokens[i], result, result_capacity - result_length);
         if (appended < 0) {
             // retry the token with a larger buffer
             i--;
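The hunk is truncated here, before the lines that actually grow the buffer. Purely to illustrate the grow-and-retry pattern the comment above describes, a minimal self-contained sketch follows. It assumes, as the `if (appended < 0)` check does, that llama_token_to_piece() returns a negative value when the remaining space is too small; the variable names mirror the diff, while the helper name and the doubling strategy are assumptions and not the committed code.

    #include <stdlib.h>
    #include "llama.h"

    // Sketch only: convert generated tokens to text, doubling the buffer and
    // retrying whenever llama_token_to_piece() signals (via a negative return)
    // that the remaining space is too small.
    static char* sketch_tokens_to_text(struct llama_context* ctx,
                                       const llama_token* generated_tokens,
                                       size_t num_generated_tokens) {
        size_t result_length   = 0;
        size_t result_capacity = 4096;
        char*  result = (char*)malloc(result_capacity);
        if (!result) {
            return NULL;
        }
        for (size_t i = 0; i < num_generated_tokens; i++) {
            // Leave one byte of headroom for the terminating '\0'.
            int appended = llama_token_to_piece(ctx, generated_tokens[i],
                                                result + result_length,
                                                (int)(result_capacity - result_length - 1));
            if (appended < 0) {
                // Not enough room: double the buffer and retry the same token.
                result_capacity *= 2;
                char* bigger = (char*)realloc(result, result_capacity);
                if (!bigger) {
                    free(result);
                    return NULL;
                }
                result = bigger;
                i--;
                continue;
            }
            result_length += (size_t)appended;
        }
        result[result_length] = '\0';
        return result;
    }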
@@ -116,20 +139,27 @@ int main(int argc, char** argv) {
         return 1;
     }

-    struct llama_context* ctx = create_codellama_fim_context(argv[1]);

+    char* model = argv[1];
+    char* prefix = argv[2];
+    char* suffix = argv[3];
     size_t n_max_tokens = atoi(argv[4]);
     int n_threads = atoi(argv[5]);
     bool spm = false;
-    char* error_message = NULL;
-    char* result = codellama_fill_in_middle(ctx, argv[2], argv[3], n_max_tokens, n_threads, spm, &error_message);
+    const char* error_message = NULL;
+
+    struct llama_context* ctx = codellama_create_fim_context(model, &error_message);
     if (error_message) {
         fprintf(stderr, "Error: %s\n", error_message);
         return 1;
     }

-    printf("%s\n", result);
+    char* result = codellama_fill_in_middle(ctx, prefix, suffix, n_max_tokens, n_threads, spm, &error_message);
+    if (error_message) {
+        fprintf(stderr, "Error: %s\n", error_message);
+        return 1;
+    }
+
+    printf("%s%s%s\n", prefix, result, suffix);

     llama_free(ctx);
 }
llama.cpp

@@ -960,6 +960,7 @@ struct llama_vocab {
     id linefeed_id = 13;

     // codellama FIM special tokens
+    // TODO: load these from the vocabulary.
     id special_prefix_id = 32007;
     id special_middle_id = 32009;
     id special_suffix_id = 32008;
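FIM.c reaches the middle and end-of-text IDs through llama_token_middle(ctx) and llama_token_eot(ctx); accessors for the prefix and suffix IDs are not part of this diff. As a hypothetical sketch only, such accessors inside llama.cpp could look like the following, with the function names and the ctx->model.vocab member path assumed by analogy rather than taken from the commit.

    // Hypothetical accessors mirroring the llama_token_middle(ctx) call made
    // in FIM.c; names and member path are assumptions, not committed code.
    llama_token llama_token_prefix(const struct llama_context * ctx) {
        return ctx->model.vocab.special_prefix_id;
    }

    llama_token llama_token_suffix(const struct llama_context * ctx) {
        return ctx->model.vocab.special_suffix_id;
    }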