diff --git a/Makefile b/Makefile index f20a46787..9747e951c 100644 --- a/Makefile +++ b/Makefile @@ -39,8 +39,8 @@ endif # # keep standard at C11 and C++11 -CFLAGS = -I. -I./include -I./include/CL -I./otherarch -I./otherarch/tools -Ofast -DNDEBUG -std=c11 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE -CXXFLAGS = -I. -I./common -I./include -I./include/CL -I./otherarch -I./otherarch/tools -Ofast -DNDEBUG -std=c++11 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE +CFLAGS = -I. -I./include -I./include/CL -I./otherarch -I./otherarch/tools -O3 -DNDEBUG -std=c11 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE +CXXFLAGS = -I. -I./common -I./include -I./include/CL -I./otherarch -I./otherarch/tools -O3 -DNDEBUG -std=c++11 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE LDFLAGS = # these are used on windows, to build some libraries with extra old device compatibility @@ -125,17 +125,7 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686)) endif endif endif -ifneq ($(filter ppc64%,$(UNAME_M)),) - POWER9_M := $(shell grep "POWER9" /proc/cpuinfo) - ifneq (,$(findstring POWER9,$(POWER9_M))) - CFLAGS += -mcpu=power9 - CXXFLAGS += -mcpu=power9 - endif - # Require c++23's std::byteswap for big-endian support. - ifeq ($(UNAME_M),ppc64) - CXXFLAGS += -std=c++23 -DGGML_BIG_ENDIAN - endif -endif + ifndef LLAMA_NO_ACCELERATE # Mac M1 - include Accelerate framework. # `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time). @@ -193,10 +183,7 @@ ifdef LLAMA_CUDA_MMQ_Y NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y) else NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64 -endif # LLAMA_CUDA_MMQ_Y -#ifdef LLAMA_CUDA_CUBLAS -# NVCCFLAGS += -DGGML_CUDA_CUBLAS -#endif # LLAMA_CUDA_CUBLAS +endif ifdef LLAMA_CUDA_CCBIN NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) endif @@ -240,7 +227,6 @@ ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-l endif # LLAMA_HIPBLAS - ifdef LLAMA_METAL CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG CXXFLAGS += -DGGML_USE_METAL @@ -254,20 +240,30 @@ endif # LLAMA_METAL ifneq ($(filter aarch64%,$(UNAME_M)),) # Apple M1, M2, etc. # Raspberry Pi 3, 4, Zero 2 (64-bit) - CFLAGS += - CXXFLAGS += + CFLAGS += -mcpu=native + CXXFLAGS += -mcpu=native endif ifneq ($(filter armv6%,$(UNAME_M)),) # Raspberry Pi 1, Zero - CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access + CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access + CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access endif ifneq ($(filter armv7%,$(UNAME_M)),) # Raspberry Pi 2 - CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations + CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations + CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations endif ifneq ($(filter armv8%,$(UNAME_M)),) # Raspberry Pi 3, 4, Zero 2 (32-bit) - CFLAGS += -mfp16-format=ieee -mno-unaligned-access + CFLAGS += -mfp16-format=ieee -mno-unaligned-access + CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access +endif +ifneq ($(filter ppc64%,$(UNAME_M)),) + POWER9_M := $(shell grep "POWER9" /proc/cpuinfo) + ifneq (,$(findstring POWER9,$(POWER9_M))) + CFLAGS += -mcpu=power9 + CXXFLAGS += -mcpu=power9 + endif endif diff --git a/common/sampling.cpp b/common/sampling.cpp index f4e76df31..8e45909f1 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -149,11 +149,12 @@ static void sampler_queue( } } -llama_token llama_sampling_sample( +static llama_token llama_sampling_sample_impl( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, - const int idx) { + const int idx, + bool is_resampling) { // Add a parameter to indicate if we are resampling const llama_sampling_params & params = ctx_sampling->params; const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); @@ -173,8 +174,17 @@ llama_token llama_sampling_sample( llama_token id = 0; + // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); + // Declare original_logits at the beginning of the function scope + std::vector original_logits; + + if (!is_resampling) { + // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this. + original_logits = std::vector(logits, logits + llama_n_vocab(llama_get_model(ctx_main))); + } + // apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { logits[it->first] += it->second; @@ -193,12 +203,14 @@ llama_token llama_sampling_sample( } // apply penalties - if (!prev.empty()) { + const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; + const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); + if (penalty_tokens_used_size) { const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; llama_sample_repetition_penalties(ctx_main, &cur_p, - prev.data() + prev.size() - penalty_last_n, - penalty_last_n, penalty_repeat, penalty_freq, penalty_present); + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { @@ -210,7 +222,8 @@ llama_token llama_sampling_sample( } } - if (ctx_sampling->grammar != NULL) { + // If we are in the resampling phase, apply grammar checks before sampling logic + if (is_resampling && ctx_sampling->grammar != NULL) { llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); } @@ -252,9 +265,40 @@ llama_token llama_sampling_sample( } } + if (ctx_sampling->grammar != NULL && !is_resampling) { + // Create an array with a single token data element for the sampled id + llama_token_data single_token_data = {id, logits[id], 0.0f}; + llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; + + // Apply grammar constraints to the single token + llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar); + + // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY + bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + + // If the token is not valid according to the grammar, perform resampling + if (!is_valid) { + LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str()); + + // Restore logits from the copy + std::copy(original_logits.begin(), original_logits.end(), logits); + + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling + } + } + return id; } +llama_token llama_sampling_sample( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + // Call the implementation function with is_resampling set to false by default + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); +} + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/common/sampling.h b/common/sampling.h index fdfa9eed1..f16ef97e3 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -36,6 +36,9 @@ typedef struct llama_sampling_params { float cfg_scale = 1.f; // how strong is guidance std::unordered_map logit_bias; // logit bias for specific tokens + + std::vector penalty_prompt_tokens; + bool use_penalty_prompt_tokens = false; } llama_sampling_params; // general sampler context diff --git a/examples/server/README.md b/examples/server/README.md index 0751b9612..f1e586a1c 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -148,6 +148,8 @@ node index.js `frequency_penalty`: Repeat alpha frequency penalty (default: 0.0, 0.0 = disabled); + `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens (default: `null` = use the original `prompt`). + `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0). `mirostat_tau`: Set the Mirostat target entropy, parameter tau (default: 5.0). diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 17f4c07a7..c8bcc2a04 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -762,6 +762,42 @@ struct llama_server_context slot->prompt = ""; } + slot->sparams.penalty_prompt_tokens.clear(); + slot->sparams.use_penalty_prompt_tokens = false; + const auto &penalty_prompt = data.find("penalty_prompt"); + if (penalty_prompt != data.end()) + { + if (penalty_prompt->is_string()) + { + const auto penalty_prompt_string = penalty_prompt->get(); + auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false); + slot->sparams.penalty_prompt_tokens.swap(penalty_tokens); + if (slot->params.n_predict > 0) + { + slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict); + } + slot->sparams.use_penalty_prompt_tokens = true; + } + else if (penalty_prompt->is_array()) + { + const auto n_tokens = penalty_prompt->size(); + slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict)); + const int n_vocab = llama_n_vocab(model); + for (const auto &penalty_token : *penalty_prompt) + { + if (penalty_token.is_number_integer()) + { + const auto tok = penalty_token.get(); + if (tok >= 0 && tok < n_vocab) + { + slot->sparams.penalty_prompt_tokens.push_back(tok); + } + } + } + slot->sparams.use_penalty_prompt_tokens = true; + } + } + slot->sparams.logit_bias.clear(); if (json_value(data, "ignore_eos", false)) @@ -993,6 +1029,12 @@ struct llama_server_context slot.generated_text += token_str; slot.has_next_token = true; + if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) + { + // we can change penalty_prompt_tokens because it is always created from scratch each request + slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); + } + // check if there is incomplete UTF-8 character at the end bool incomplete = false; for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) @@ -1184,6 +1226,8 @@ struct llama_server_context {"repeat_penalty", slot.sparams.penalty_repeat}, {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, + {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, + {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c26c52d72..42e00eebf 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6718,8 +6718,7 @@ void * ggml_cuda_host_malloc(size_t size) { void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { - // The allocation error can be bypassed. A null ptr will assigned out of this function. - // This can fixed the OOM error in WSL. + // clear the error cudaGetLastError(); fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size/1024.0/1024.0, cudaGetErrorString(err)); @@ -7926,12 +7925,16 @@ static void ggml_cuda_op_mul_mat( if (id != 0) { row_low[id] = ne01*g_tensor_split[id]; - row_low[id] -= row_low[id] % rounding; + if (row_low[id] < ne01) { + row_low[id] -= row_low[id] % rounding; + } } if (id != g_device_count - 1) { row_high[id] = ne01*g_tensor_split[id + 1]; - row_high[id] -= row_high[id] % rounding; + if (row_high[id] < ne01) { + row_high[id] -= row_high[id] % rounding; + } } } } @@ -9666,12 +9669,14 @@ ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { // host buffer type static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - CUDA_CHECK(cudaFreeHost(buffer->context)); + ggml_cuda_host_free(buffer->context); } static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - void * ptr; - CUDA_CHECK(cudaMallocHost(&ptr, size)); + void * ptr = ggml_cuda_host_malloc(size); + if (ptr == nullptr) { + return nullptr; + } // FIXME: this is a hack to avoid having to implement a new buffer type ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); diff --git a/klite.embd b/klite.embd index 8c71698f9..32392caa2 100644 --- a/klite.embd +++ b/klite.embd @@ -7649,7 +7649,7 @@ Current version: 101 selected_workers = []; localsettings.opmode = 1; } - restart_new_game(); + restart_new_game(true, document.getElementById("keep_memory").checked); hide_popups(); } @@ -7802,7 +7802,7 @@ Current version: 101 horde_poll_nearly_completed = false; } - function restart_new_game(save = true) { + function restart_new_game(save = true, keep_memory = false) { idle_timer = 0; gametext_arr = []; redo_arr = []; @@ -7816,15 +7816,7 @@ Current version: 101 synchro_pending_stream = ""; waiting_for_autosummary = false; last_reply_was_empty = false; - current_memory = ""; - current_anote = ""; - current_wi = []; pending_context_preinjection = ""; - extrastopseq = ""; - anote_strength = 320; - wi_searchdepth = 0; - wi_insertlocation = 0; - current_anotetemplate = "[Author's note: <|>]"; document.getElementById("input_text").value = ""; document.getElementById("cht_inp").value = ""; chat_resize_input(); @@ -7833,9 +7825,20 @@ Current version: 101 localsettings.adventure_is_action = false; prev_hl_chunk = null; last_token_budget = ""; - last_known_filename = "saved_story.json"; groupchat_removals = []; welcome = ""; + last_known_filename = "saved_story.json"; + if (!keep_memory) + { + current_memory = ""; + current_anote = ""; + current_wi = []; + extrastopseq = ""; + anote_strength = 320; + wi_searchdepth = 0; + wi_insertlocation = 0; + current_anotetemplate = "[Author's note: <|>]"; + } render_gametext(save); //necessary to trigger an autosave to wipe out current story in case they exit browser after newgame. } @@ -7849,6 +7852,7 @@ Current version: 101 restart_new_game(); display_settings(); confirm_settings(); + document.getElementById("keep_memory").checked = false; },null); } @@ -12004,7 +12008,18 @@ Current version: 101
Unsaved data will be lost.

-
Keep AI Selected?
+
+
+
+ Keep AI Selected? + +
+
+ Keep Memory and World Info? + +
+
+

diff --git a/llama.cpp b/llama.cpp index d74b89008..a85d835b5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1182,21 +1182,27 @@ static std::string llama_token_to_str(const struct llama_context * ctx, llama_to } static ggml_backend_buffer_type_t llama_default_buffer_type(int n_gpu_layers) { + ggml_backend_buffer_type_t buft = nullptr; + #ifdef GGML_USE_METAL if (n_gpu_layers > 0) { - return ggml_backend_metal_buffer_type(); + buft = ggml_backend_metal_buffer_type(); } #elif defined(GGML_USE_CUBLAS) && defined(LLAMA_GGML_BACKEND_CUDA_TEST) if (n_gpu_layers > 0) { - return ggml_backend_cuda_buffer_type(0); + buft = ggml_backend_cuda_buffer_type(0); } #elif defined(GGML_USE_CUBLAS) - return ggml_backend_cuda_host_buffer_type(); + buft = ggml_backend_cuda_host_buffer_type(); #elif defined(GGML_USE_CPU_HBM) - return ggml_backend_cpu_hbm_buffer_type(); + buft = ggml_backend_cpu_hbm_buffer_type(); #endif - return ggml_backend_cpu_buffer_type(); + if (buft == nullptr) { + buft = ggml_backend_cpu_buffer_type(); + } + + return buft; GGML_UNUSED(n_gpu_layers); }