Implementation of a sequence repetition penalty
parent 875fb42871
commit 11fa3dfd69
14 changed files with 2192 additions and 3 deletions
CMakeLists.txt
@@ -96,6 +96,8 @@ option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging"
 option(LLAMA_MPI "llama: use MPI" OFF)
 option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
 
+option(LLAMA_SEQREP_SAMPLER "llama: build with support for seqrep sampler" ON)
+
 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_SERVER "llama: build server example" ON)
13 Makefile
@@ -2,7 +2,7 @@
 BUILD_TARGETS = \
 	main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
 	simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
-	speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
+	speculative infill benchmark-matmult parallel finetune export-lora simple-inference tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = \
@@ -559,6 +559,14 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
 train.o: common/train.cpp common/train.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+ifndef LLAMA_NO_SEQREP_SAMPLER
+COMMON_H_DEFS += common/seqrep-sampler.h
+COMMON_DEPS += seqrep-sampler.o
+
+seqrep-sampler.o: common/seqrep-sampler.cpp $(COMMON_H_DEPS)
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+endif
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
 
@@ -581,6 +589,9 @@ infill: examples/infill/infill.cpp ggml.o llama.o $(C
 simple: examples/simple/simple.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+simple-inference: examples/simple-inference/simple-inference.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 batched: examples/batched/batched.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
build.zig
@@ -111,6 +111,8 @@ pub fn build(b: *std.build.Builder) !void {
     var make = try Maker.init(b);
     make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
 
+    try make.addFlag("-DLLAMA_NO_SEQREP_SAMPLER");
+
     const ggml = make.obj("ggml", "ggml.c");
     const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
     const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
common/CMakeLists.txt
@@ -54,6 +54,12 @@ add_library(${TARGET} STATIC
     train.cpp
     )
 
+if (LLAMA_SEQREP_SAMPLER)
+    target_sources(${TARGET} PRIVATE seqrep-sampler.h seqrep-sampler.cpp)
+else()
+    add_compile_definitions(LLAMA_NO_SEQREP_SAMPLER)
+endif()
+
 if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()
common/common.cpp
@@ -1,6 +1,10 @@
 #include "common.h"
 #include "llama.h"
 
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+#include "seqrep-sampler.h"
+#endif
+
 #include <algorithm>
 #include <cassert>
 #include <cmath>
@@ -335,6 +339,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.penalty_present = std::stof(argv[i]);
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+        } else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            if (std::strcmp(argv[i], "help") == 0) {
+                seqrep_sampler_help();
+                exit(0);
+            }
+            llama_sampler_seqrep_params sr_params;
+            seqrep_sampler_params_init(&sr_params);
+            if (!seqrep_sampler_params_parse(argv[i], &sr_params)) {
+                seqrep_sampler_help();
+                exit(1);
+            }
+            sparams.seqrep_params.push_back(sr_params);
+#endif
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -764,6 +786,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
     printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
     printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    printf(" -seqrep CFG, --seqrep-penalty CFG\n");
+    printf("                       add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
+#endif
     printf(" --mirostat N use Mirostat sampling.\n");
     printf("              Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     printf("              (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
common/sampling.cpp
@@ -103,7 +103,8 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        const int idx) {
+        const int idx,
+        const std::vector<llama_token> & all_last_tokens) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -155,6 +156,13 @@ llama_token llama_sampling_sample(
             prev.data() + prev.size() - penalty_last_n,
             penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
 
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    for (auto & sr_params : params.seqrep_params) {
+        if ((sr_params.flags & LLAMA_SEQREP_REWIND_MODE) != 0) continue;
+        llama_sample_seqrep_penalty(ctx_main, &cur_p, all_last_tokens, &sr_params);
+    }
+#endif
+
     if (!penalize_nl) {
         for (size_t idx = 0; idx < cur_p.size; idx++) {
             if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
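The new all_last_tokens argument lets the seqrep samplers see the full generated history instead of only the sampler's internal prev ring buffer. As a rough caller-side illustration (not part of this commit; prompt_tokens and generated are placeholder names for bookkeeping the caller would maintain itself):

    // Hypothetical sketch of calling the extended llama_sampling_sample signature.
    std::vector<llama_token> all_tokens = prompt_tokens;                       // full history: prompt...
    all_tokens.insert(all_tokens.end(), generated.begin(), generated.end());   // ...plus generated tokens

    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr, 0, all_tokens);
    llama_sampling_accept(ctx_sampling, ctx, id, true);
    generated.push_back(id);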
common/sampling.h
@@ -4,6 +4,10 @@
 
 #include "grammar-parser.h"
 
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+#include "seqrep-sampler.h"
+#endif
+
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -35,6 +39,11 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    std::vector<llama_sampler_seqrep_params> seqrep_params;
+#endif
+
 } llama_sampling_params;
 
 // general sampler context
@@ -101,7 +110,8 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
-        int idx = 0);
+        int idx = 0,
+        const std::vector<llama_token> & all_last_tokens = {});
 
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
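Because llama_sampling_params now carries a vector of seqrep configurations, several copies of the sampler can be active at once; each -seqrep flag on the command line appends one. A minimal sketch of doing the same from code instead of the CLI (sparams is assumed to be a llama_sampling_params instance; the min_length value is a hypothetical tuning choice, not a default from this commit):

    #ifndef LLAMA_NO_SEQREP_SAMPLER
    llama_sampler_seqrep_params sr;
    seqrep_sampler_params_init(&sr);      // start from the sampler's defaults
    sr.min_length = 3;                    // hypothetical: only penalize repeated sequences of 3+ tokens
    sparams.seqrep_params.push_back(sr);  // one copy of the sampler; push more for additional passes
    #endif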
1014 common/seqrep-sampler.cpp Normal file
File diff suppressed because it is too large.
196 common/seqrep-sampler.h Normal file
@@ -0,0 +1,196 @@
#pragma once

#include <stddef.h>

#include <vector>
#include <regex>

#include "llama.h"

enum llama_sampler_seqrep_flags {
    // Tolerance charges can't be used consecutively.
    LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE = (1 << 0),

    // Tolerance charges can't be used before the first actual match.
    LLAMA_SEQREP_TOLERANCE_NO_FIRST = (1 << 1),

    // When applying the length penalty, use the length of the longest observed
    // sequence matching the token rather than the total length of
    // sequences matching the token. In other words, if we find a sequence
    // of length 3 and a sequence of length 4 continued by token 69 then
    // with this flag on we penalize based on length 4, with it off we
    // penalize based on length 7 (3 + 4).
    LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN = (1 << 2),

    // Apply an absolute penalty rather than dividing the logit by the penalty.
    LLAMA_SEQREP_ABSOLUTE_PENALTY = (1 << 3),

    // Rewind to cut off the head of sequences rather than the end.
    // Ignored when min_length < 2.
    // Since it wouldn't make sense to rewind and then let sampling pick
    // the same token again, penalty values and mid_word_scale have no effect.
    LLAMA_SEQREP_REWIND_MODE = (1 << 4),

    // When rewinding, skip past whitespace and punctuation. For example,
    // if the matched sequence was "<NL>'hello" then we will rewind to the
    // token starting with 'h' and ban it.
    LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT = (1 << 5),

    // Rewind to the shortest matching sequence of at least min_length rather than the longest.
    LLAMA_SEQREP_REWIND_USE_SHORTEST_MATCH = (1 << 6),

    // Rewinding requires a word boundary. Only has an effect when rewind_seek_word_boundary isn't 0.
    LLAMA_SEQREP_REWIND_REQUIRE_WBOUND = (1 << 7),

    // Persisted bans are only applied if at a word bound.
    LLAMA_SEQREP_REWIND_PERSIST_REQUIRE_WBOUND = (1 << 8),
};

typedef struct llama_sampler_seqrep_params {
    // The minimum length of a matching sequence of tokens. When this is < 2 then
    // the sampler works in single token mode and tolerance is ignored.
    size_t min_length;

    // Maximum length for a matching sequence of tokens.
    size_t max_length;

    // Starting offset for matching against the end of the sequence. This can be used
    // to only match against sequences in the initial prompt, for example. Matching
    // starts at the offset and moves toward the beginning of the list.
    // Use 0 for penultimate token when min_length > 1, otherwise 0 for last token.
    size_t start_offset;

    // Window of last tokens to consider, starting from the end. < 0 means
    // the whole list.
    int last_n;

    // Flags based on llama_sampler_seqrep_flags enum values ORed together.
    int flags;

    // Tolerance for non-matching tokens in a sequence.
    float tolerance;

    // Flat penalty applied to the token that can continue a repeated sequence.
    float presence_penalty;

    // Scaling penalty applied to the token that can continue a repeated sequence.
    // The penalty is multiplied by the total length of sequences that are continued by this token
    // unless the PENALIZE_LENGTH_MAX_SEEN flag is set.
    float length_penalty;

    // Scale for penalizing tokens from repeated sequences that aren't at/don't form a word boundary.
    float mid_word_scale;

    // Tolerance credit per real match. I.e. 0.5 means +1 tolerance per 2 matched tokens.
    float tolerance_match_credit;

    // Caps tolerance at the specified value. Only meaningful when tolerance_match_credit > 0.
    float tolerance_cap;

    // Ensure the sequence is at least the specified length in rewind mode after
    // whitespace skipping and other modifications.
    size_t rewind_min_length;

    // When rewinding, try to find a word boundary within the specified distance,
    // starting with tokens earlier than the rewind point.
    size_t rewind_seek_word_boundary;

    // A position is limited to the specified number of rewinds. When the limit is exceeded,
    // future rewinds cannot target it or earlier tokens.
    size_t rewind_max_visits;

    // Tokens banned by rewind remain banned for an additional number of positions equal to the value.
    // I.e. setting this to 1 would mean the token is banned for 2 positions.
    size_t rewind_persist_bans;

    // Number of tokens from the sequence to ban when rewinding.
    size_t rewind_ban_length;

    std::vector<std::wregex> include_re;
    std::vector<std::wregex> exclude_re;
} llama_sampler_seqrep_params;

enum seqrep_check_word_flags {
    SEQREP_CW_START_IS_WBOUND  = 1 << 0,
    SEQREP_CW_END_IS_WBOUND    = 1 << 1,
    SEQREP_CW_ALL_WS_PUNCT     = 1 << 2,
    SEQREP_CW_START_IS_INVALID = 1 << 3, // Start of token is invalid/incomplete UTF8
    SEQREP_CW_END_IS_INVALID   = 1 << 4  // End of token is invalid/incomplete UTF8
};

struct seqrep_logit_info {
    const int n_vocab;
    std::vector<llama_token_data> token_data;

    seqrep_logit_info(llama_context * ctx, const size_t k, const int32_t ith);

    const std::vector<llama_token_data> & get_token_data(void);

    llama_token_data get_token_id(const llama_token token_id) const;

    void rebuild(llama_context *ctx, const size_t k, const int32_t ith);

    void populate_logits(float * logits);

    // Yoinked from beam search code.
    // Return top k token_data by logit.
    std::vector<llama_token_data> top_k(const float * const logits, const size_t k);
};

struct seqrep_rewind_slot {
    size_t count;
    std::vector<llama_token> tokens;
    struct llama_sampling_context * ctx_sampling = nullptr;
};

struct seqrep_rewind_state {
    const size_t n_vocab;
    const size_t n_ctx;
    const size_t k;

    std::vector<seqrep_logit_info> logit_slots;
    std::vector<seqrep_rewind_slot> rewind_slots;

    seqrep_rewind_state(
            const size_t n_vocab,
            const size_t n_ctx,
            const size_t k = 2000);

    struct seqrep_rewind_slot & get_rewind_slot(const size_t idx);

    void set_logits_slot(llama_context * ctx, const size_t idx, const int32_t ith = 0);

    void populate_logits(llama_context * ctx, const size_t idx, const int32_t ith = 0);
};

// Sequence repetition penalty with semi-fuzzy matching. Note: Handles the last_n window itself.
size_t llama_sample_seqrep_penalty(
        struct llama_context * ctx,
        llama_token_data_array * candidates,
        const std::vector<llama_token> & last_tokens,
        const llama_sampler_seqrep_params * params);

int llama_seqrep_check_word(
        const struct llama_context * ctx,
        const llama_token token,
        std::vector<char> & buf);

size_t llama_seqrep_handle_rewind(
        struct llama_context * ctx,
        struct seqrep_rewind_state & rewind_state,
        const std::vector<llama_token> & generated_tokens,
        const size_t n_generated,
        const std::vector<llama_token> & prompt_tokens,
        const std::vector<llama_sampler_seqrep_params> & params_list,
        size_t * high_water_mark,
        const int32_t ith = 0);

void seqrep_sampler_help();
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params);
void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params);
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params);
struct llama_sampler_seqrep_params llama_seqrep_merge_params(
        const std::vector<llama_sampler_seqrep_params> & params_list,
        const int and_flags,
        const int not_flags);
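For orientation, here is a minimal sketch (not from this commit) of driving the non-rewind entry point declared above directly on a candidates array, the same way the sampling loop and the tests later in this diff use it. The specific field values are hypothetical tuning choices, not defaults from this commit:

    #ifndef LLAMA_NO_SEQREP_SAMPLER
    #include "common/seqrep-sampler.h"

    // Hypothetical helper: applies one non-rewind seqrep penalty pass in place.
    static void apply_seqrep_example(
            struct llama_context * ctx,
            llama_token_data_array * candidates,
            const std::vector<llama_token> & last_tokens) {
        llama_sampler_seqrep_params sr_params;
        seqrep_sampler_params_init(&sr_params);   // start from the sampler's defaults
        sr_params.last_n         = 256;           // hypothetical window of recent tokens to scan
        sr_params.min_length     = 3;             // sequence mode: match repeated sequences of 3+ tokens
        sr_params.length_penalty = 1.2f;          // hypothetical scaling by matched length
        sr_params.flags         |= LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN;

        // The sampler handles its own last_n window, so the full history can be passed.
        llama_sample_seqrep_penalty(ctx, candidates, last_tokens, &sr_params);
    }
    #endif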
examples/CMakeLists.txt
@@ -30,6 +30,7 @@ else()
     add_subdirectory(quantize-stats)
     add_subdirectory(save-load-state)
     add_subdirectory(simple)
+    add_subdirectory(simple-inference)
     add_subdirectory(speculative)
     add_subdirectory(train-text-from-scratch)
     if (LLAMA_METAL)
examples/main/main.cpp
@@ -415,6 +415,11 @@ int main(int argc, char ** argv) {
         }
     }
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    for (auto & sr_params : sparams.seqrep_params) {
+        seqrep_sampler_params_dump(&sr_params);
+    }
+#endif
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
8 examples/simple-inference/CMakeLists.txt Normal file
@@ -0,0 +1,8 @@
set(TARGET simple-inference)
add_executable(${TARGET} simple-inference.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()
819 examples/simple-inference/simple-inference.cpp Normal file
@@ -0,0 +1,819 @@
// Defines sigaction on msys:
|
||||
#ifndef _GNU_SOURCE
|
||||
#define _GNU_SOURCE
|
||||
#endif
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include "console.h"
|
||||
#include "llama.h"
|
||||
#include "build-info.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <ctime>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#elif defined (_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
#define SI_DUMP_SEQUENCES_INTERVAL 40
|
||||
|
||||
static std::atomic<bool> interrupted {false};
|
||||
static std::atomic<bool> done {false};
|
||||
|
||||
typedef struct tokens_chunk {
|
||||
bool is_input;
|
||||
size_t consumed;
|
||||
std::vector<llama_token> tokens;
|
||||
|
||||
tokens_chunk(const bool is_input = false, const size_t consumed = 0, const std::vector<llama_token> & tokens = {})
|
||||
: is_input(is_input)
|
||||
, consumed(consumed)
|
||||
, tokens(tokens)
|
||||
{}
|
||||
} tokens_chunk;
|
||||
|
||||
enum seq_state {
|
||||
SEQ_GENERATING,
|
||||
SEQ_SHARE_PROMPT,
|
||||
SEQ_INPUT,
|
||||
SEQ_DONE,
|
||||
};
|
||||
|
||||
typedef struct seq_ctx {
|
||||
llama_seq_id seq_id;
|
||||
int32_t batch_idx;
|
||||
enum seq_state state;
|
||||
size_t n_remain;
|
||||
size_t n_generated;
|
||||
llama_sampling_context *ctx_sampling;
|
||||
|
||||
llama_token last_sampled;
|
||||
std::vector<tokens_chunk> chunks;
|
||||
// std::vector<llama_token> pending_input;
|
||||
// std::vector<llama_token> output;
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
size_t high_water_mark;
|
||||
struct seqrep_rewind_state rewind_state;
|
||||
size_t rewind_count;
|
||||
size_t rewind_tokens;
|
||||
#endif
|
||||
} seq_ctx;
|
||||
|
||||
|
||||
typedef struct gen_ctx {
|
||||
llama_context * ctx = nullptr;
|
||||
llama_model * model = nullptr;
|
||||
llama_sampling_context * ctx_sampling = nullptr;
|
||||
|
||||
llama_batch batch;
|
||||
gpt_params params;
|
||||
llama_sampling_params & sparams = params.sparams;
|
||||
|
||||
|
||||
int n_ctx;
|
||||
int n_vocab;
|
||||
|
||||
std::vector<llama_token> scratch;
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
size_t prompt_size = 0;
|
||||
|
||||
llama_seq_id focused_sequence = 0;
|
||||
|
||||
size_t decode_count = 0;
|
||||
int64_t decode_time_total = 0, decode_time_last = 0;
|
||||
|
||||
std::vector<seq_ctx> ctxs_seq;
|
||||
|
||||
private:
|
||||
bool init_params(const int argc, char ** argv);
|
||||
bool init_model();
|
||||
bool init_prompt();
|
||||
bool init_handlers();
|
||||
bool init_sampling();
|
||||
bool init_batches();
|
||||
|
||||
public:
|
||||
gen_ctx(const int argc, char ** argv);
|
||||
~gen_ctx();
|
||||
void dump_batches(const size_t prompt_start = 0);
|
||||
void dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset = 0);
|
||||
void dump_batch(const size_t seq);
|
||||
void handle_seq(seq_ctx & sctx);
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
void handle_seq_seqrep(seq_ctx & sctx);
|
||||
#endif
|
||||
bool feed_prompt(
|
||||
const std::vector<llama_token> & tokens,
|
||||
llama_pos pos = 0,
|
||||
llama_seq_id seq = 0);
|
||||
bool go();
|
||||
} gen_ctx;
|
||||
|
||||
|
||||
static void concat_chunks(const std::vector<tokens_chunk> & chunks, std::vector<llama_token> & dst, const size_t start_offset) {
|
||||
size_t offset = 0;
|
||||
|
||||
for (const tokens_chunk & chunk : chunks) {
|
||||
if (offset + chunk.tokens.size() <= start_offset) {
|
||||
offset += chunk.tokens.size();
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_t chunk_offset = start_offset - offset;
|
||||
const size_t chunk_size = chunk.tokens.size() - chunk_offset;
|
||||
const llama_token * tp = chunk.tokens.data() + chunk_offset;
|
||||
|
||||
for (size_t i = 0; i < chunk_size; i++, tp++) {
|
||||
dst.push_back(*tp);
|
||||
}
|
||||
offset += chunk_size;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void write_logfile(
|
||||
const llama_context * ctx, const gpt_params & params, const llama_model * model,
|
||||
const std::vector<llama_token> input_tokens, const std::string output, const std::vector<llama_token> output_tokens) {
|
||||
|
||||
if (params.logdir.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const std::string timestamp = get_sortable_timestamp();
|
||||
|
||||
const bool success = create_directory_with_parents(params.logdir);
|
||||
if (!success) {
|
||||
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
|
||||
__func__, params.logdir.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
const std::string logfile_path = params.logdir + timestamp + ".yml";
|
||||
FILE * logfile = fopen(logfile_path.c_str(), "w");
|
||||
|
||||
if (logfile == NULL) {
|
||||
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(logfile, "binary: simple-inference\n");
|
||||
char model_desc[128];
|
||||
llama_model_desc(model, model_desc, sizeof(model_desc));
|
||||
dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc);
|
||||
|
||||
fprintf(logfile, "\n");
|
||||
fprintf(logfile, "######################\n");
|
||||
fprintf(logfile, "# Generation Results #\n");
|
||||
fprintf(logfile, "######################\n");
|
||||
fprintf(logfile, "\n");
|
||||
|
||||
dump_string_yaml_multiline(logfile, "output", output.c_str());
|
||||
dump_vector_int_yaml(logfile, "output_tokens", output_tokens);
|
||||
|
||||
llama_dump_timing_info_yaml(logfile, ctx);
|
||||
fclose(logfile);
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
if (interrupted) {
|
||||
done.store(true);
|
||||
} else {
|
||||
interrupted.store(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
static bool check_unsupported(const gpt_params * params) {
|
||||
std::string nope;
|
||||
const llama_sampling_params & sparams = params->sparams;
|
||||
|
||||
if (params->embedding)
|
||||
nope = "embedding";
|
||||
else if (!sparams.grammar.empty())
|
||||
nope = "grammar"; // Currently broken most likely
|
||||
else if (sparams.cfg_scale != 1.0f)
|
||||
nope = "cfg_scale";
|
||||
else if (!sparams.cfg_negative_prompt.empty())
|
||||
nope = "cfg_negative_prompt";
|
||||
else if (!params->path_prompt_cache.empty())
|
||||
nope = "prompt cache";
|
||||
else if (params->escape)
|
||||
nope = "prompt escaping";
|
||||
else if (params->interactive || params->interactive_first || params->instruct)
|
||||
nope = "interactive mode";
|
||||
else if (!params->input_prefix.empty() || !params->input_suffix.empty() || params->input_prefix_bos)
|
||||
nope = "input prefix or suffix";
|
||||
else if (params->hellaswag)
|
||||
nope = "hellaswag";
|
||||
else if (params->n_keep != 0)
|
||||
nope = "keep";
|
||||
else if (!params->antiprompt.empty())
|
||||
nope = "reverse prompt";
|
||||
if (!nope.empty()) {
|
||||
LOG_TEE("%s: error: We don't support %s here.\n", __func__, nope.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool gen_ctx::init_params(const int argc, char ** argv) {
|
||||
if (gpt_params_parse(argc, argv, params) == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!check_unsupported(¶ms)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (params.rope_freq_base != 10000.0) {
|
||||
LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
|
||||
}
|
||||
|
||||
if (params.rope_freq_scale != 1.0) {
|
||||
LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
|
||||
}
|
||||
|
||||
if (params.n_ctx < 8) {
|
||||
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
|
||||
params.n_ctx = 8;
|
||||
}
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
if (params.random_prompt) {
|
||||
params.prompt = gpt_random_prompt(rng);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool gen_ctx::init_model() {
|
||||
LOG("%s: llama backend init\n", __func__);
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_TEE("%s: error: unable to load model\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
LOG_TEE("\n");
|
||||
LOG_TEE("system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
}
|
||||
|
||||
n_ctx = llama_n_ctx(ctx);
|
||||
n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool gen_ctx::init_prompt() {
|
||||
const bool add_bos = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
|
||||
LOG("add_bos: %d\n", add_bos);
|
||||
|
||||
if (!params.prompt.empty()) {
|
||||
LOG("tokenize the prompt\n");
|
||||
prompt_tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
|
||||
}
|
||||
|
||||
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
||||
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, prompt_tokens).c_str());
|
||||
|
||||
// Should not run without any tokens
|
||||
if (prompt_tokens.empty()) {
|
||||
prompt_tokens.push_back(llama_token_bos(model));
|
||||
LOG("input was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, prompt_tokens).c_str());
|
||||
}
|
||||
|
||||
LOG("n_ctx: %d\n", n_ctx);
|
||||
|
||||
if ((int) prompt_tokens.size() > n_ctx - 4) {
|
||||
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
|
||||
return false;
|
||||
}
|
||||
prompt_size = prompt_tokens.size();
|
||||
|
||||
if (params.verbose_prompt) {
|
||||
LOG_TEE("\n");
|
||||
LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||
LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
|
||||
LOG_TEE("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool gen_ctx::init_handlers() {
|
||||
// save choice to use color for later
|
||||
// (note for later: this is a slightly awkward choice)
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = sigint_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
bool gen_ctx::init_sampling() {
|
||||
LOG_TEE("sampling: %s\n", llama_sampling_print(sparams).c_str());
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
for (auto & sr_params : sparams.seqrep_params) {
|
||||
seqrep_sampler_params_dump(&sr_params);
|
||||
}
|
||||
#endif
|
||||
ctx_sampling = llama_sampling_init(sparams);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool gen_ctx::init_batches() {
|
||||
batch = llama_batch_init(std::max(int32_t(prompt_size), params.n_batch), 0, 1);
|
||||
int n_remain = params.n_predict;
|
||||
|
||||
if (n_remain < 0 || n_remain + int(prompt_size) > n_ctx) {
|
||||
n_remain = n_ctx - prompt_size;
|
||||
}
|
||||
|
||||
ctxs_seq.reserve(params.n_parallel);
|
||||
for (int32_t i = 0; i < params.n_parallel; i++) {
|
||||
seq_ctx && bs = {
|
||||
llama_seq_id(i),
|
||||
-1,
|
||||
i == 0 ? SEQ_INPUT : SEQ_SHARE_PROMPT,
|
||||
size_t(n_remain),
|
||||
0,
|
||||
llama_sampling_init(params.sparams),
|
||||
-1,
|
||||
{},
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
prompt_size + 1,
|
||||
seqrep_rewind_state(n_vocab, n_ctx, 2000),
|
||||
0,
|
||||
0,
|
||||
#endif
|
||||
};
|
||||
GGML_ASSERT(prompt_size > 0);
|
||||
bs.chunks.emplace_back(true, 0, prompt_tokens);
|
||||
if (i > 0) {
|
||||
bs.chunks.emplace_back(false, 0, std::vector<llama_token>());
|
||||
}
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
seqrep_rewind_slot & rw_slot = bs.rewind_state.get_rewind_slot(0);
|
||||
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
|
||||
// llama_sampling_cp(bs.ctx_sampling, rw_slot.ctx_sampling);
|
||||
// bs.rewind_state.set_logits_slot(ctx, 0, (prompt_size - 1) % params.n_batch);
|
||||
#endif
|
||||
ctxs_seq.push_back(bs);
|
||||
// if (i > 0) llama_kv_cache_seq_cp(ctx, 0, i, 0, prompt_size);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
gen_ctx::gen_ctx(const int argc, char ** argv) {
|
||||
bool result = true;
|
||||
|
||||
result = result && init_params(argc, argv);
|
||||
result = result && init_model();
|
||||
result = result && init_prompt();
|
||||
result = result && init_handlers();
|
||||
result = result && init_sampling();
|
||||
result = result && init_batches();
|
||||
if (!result) {
|
||||
throw std::runtime_error("Initialization failed");
|
||||
}
|
||||
}
|
||||
|
||||
gen_ctx::~gen_ctx() {
|
||||
for (auto & sctx : ctxs_seq) {
|
||||
llama_sampling_free(sctx.ctx_sampling);
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
for (auto & rs : sctx.rewind_state.rewind_slots) {
|
||||
if (rs.ctx_sampling != nullptr) {
|
||||
llama_sampling_free(rs.ctx_sampling);
|
||||
rs.ctx_sampling = nullptr;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
llama_sampling_free(ctx_sampling);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
|
||||
bool gen_ctx::feed_prompt(const std::vector<llama_token> & tokens, llama_pos pos, llama_seq_id seq) {
|
||||
int32_t tokens_remain = tokens.size();
|
||||
const llama_token * tokens_curr = tokens.data();
|
||||
|
||||
console::set_display(console::prompt);
|
||||
while (tokens_remain > 0 && !interrupted) {
|
||||
const int32_t chunk_size = std::min(int32_t(tokens_remain), params.n_batch);
|
||||
llama_batch_clear(batch);
|
||||
for (int i = 0; i < chunk_size; i++) {
|
||||
llama_batch_add(batch, tokens_curr[i], pos + i, {seq}, false);
|
||||
}
|
||||
pos += batch.n_tokens;
|
||||
tokens_remain -= batch.n_tokens;
|
||||
batch.logits[batch.n_tokens - 1] = tokens_remain < 1;
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
console::set_display(console::reset);
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
decode_count++;
|
||||
|
||||
// display text
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
const std::string token_str = llama_token_to_piece(ctx, tokens_curr[i]);
|
||||
fputs(token_str.c_str(), stdout);
|
||||
}
|
||||
fflush(stdout);
|
||||
|
||||
tokens_curr += batch.n_tokens;
|
||||
}
|
||||
console::set_display(console::reset);
|
||||
return true;
|
||||
}
|
||||
|
||||
void gen_ctx::dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset) {
|
||||
size_t offset = 0;
|
||||
bool prompt_mode = false;
|
||||
console::set_display(console::reset);
|
||||
|
||||
for (const tokens_chunk & chunk : chunks) {
|
||||
if (offset + chunk.tokens.size() < start_offset) {
|
||||
offset += chunk.tokens.size();
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_t chunk_offset = start_offset - offset;
|
||||
const size_t chunk_size = chunk.tokens.size() - chunk_offset;
|
||||
const llama_token * tp = chunk.tokens.data() + chunk_offset;
|
||||
|
||||
if (chunk.is_input != prompt_mode) {
|
||||
prompt_mode = chunk.is_input;
|
||||
console::set_display(prompt_mode ? console::prompt : console::reset);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < chunk_size; i++, tp++) {
|
||||
const std::string token_str = llama_token_to_piece(ctx, *tp);
|
||||
fputs(token_str.c_str(), stdout);
|
||||
}
|
||||
}
|
||||
if (prompt_mode) {
|
||||
console::set_display(console::reset);
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
void gen_ctx::dump_batches(const size_t prompt_start) {
|
||||
|
||||
bool first = true;
|
||||
|
||||
for (seq_ctx & sctx : ctxs_seq) {
|
||||
if (sctx.seq_id == focused_sequence) continue;
|
||||
printf("\n\n%s Result #%d (size: %zu",
|
||||
!first ? "====================" : "####################",
|
||||
sctx.seq_id + 1, prompt_size + sctx.n_generated);
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
|
||||
#endif
|
||||
puts("):");
|
||||
dump_chunks(sctx.chunks, prompt_start);
|
||||
first = false;
|
||||
}
|
||||
seq_ctx & sctx = ctxs_seq[focused_sequence];
|
||||
printf("\n\n%s Result #%d (size: %zu",
|
||||
!first ? "====================" : "####################",
|
||||
sctx.seq_id + 1, prompt_size + sctx.n_generated);
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
|
||||
#endif
|
||||
puts("):");
|
||||
dump_chunks(sctx.chunks, prompt_start);
|
||||
}
|
||||
|
||||
void gen_ctx::handle_seq(seq_ctx & sctx) {
|
||||
switch (sctx.state) {
|
||||
case SEQ_DONE:
|
||||
case SEQ_SHARE_PROMPT: break;
|
||||
|
||||
case SEQ_GENERATING: {
|
||||
GGML_ASSERT(sctx.batch_idx >= 0);
|
||||
scratch.resize(prompt_size);
|
||||
concat_chunks(sctx.chunks, scratch, prompt_size);
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
handle_seq_seqrep(sctx);
|
||||
#endif
|
||||
sctx.last_sampled = llama_sampling_sample(ctx_sampling, ctx, NULL, sctx.batch_idx, scratch);
|
||||
llama_sampling_accept(sctx.ctx_sampling, ctx, sctx.last_sampled, true);
|
||||
if (sctx.seq_id == focused_sequence) {
|
||||
const std::string token_str = llama_token_to_piece(ctx, sctx.last_sampled);
|
||||
fputs(token_str.c_str(), stdout);
|
||||
fflush(stdout);
|
||||
}
|
||||
sctx.n_generated++;
|
||||
sctx.n_remain--;
|
||||
if (sctx.chunks.empty() || sctx.chunks.back().is_input) {
|
||||
sctx.chunks.emplace_back(0, false, std::vector<llama_token>());
|
||||
}
|
||||
sctx.chunks.back().tokens.push_back(sctx.last_sampled);
|
||||
if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
|
||||
sctx.state = SEQ_DONE;
|
||||
sctx.batch_idx = -1;
|
||||
// LOG_TEE(" [end of text]\n");
|
||||
// break;
|
||||
} else {
|
||||
sctx.batch_idx = batch.n_tokens;
|
||||
llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_generated, {sctx.seq_id}, true);
|
||||
}
|
||||
} break;
|
||||
|
||||
case SEQ_INPUT: {
|
||||
sctx.last_sampled = -1;
|
||||
GGML_ASSERT(!sctx.chunks.empty());
|
||||
tokens_chunk & chunk = sctx.chunks.back();
|
||||
GGML_ASSERT(chunk.is_input);
|
||||
|
||||
const size_t remain = chunk.tokens.size() - chunk.consumed;
|
||||
const size_t to_consume = std::min(size_t(params.n_batch), remain);
|
||||
for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
|
||||
llama_batch_add(batch, chunk.tokens[i], llama_pos(i), {sctx.seq_id}, false);
|
||||
}
|
||||
chunk.consumed += to_consume;
|
||||
if (chunk.consumed == chunk.tokens.size()) {
|
||||
sctx.batch_idx = batch.n_tokens - 1;
|
||||
batch.logits[sctx.batch_idx] = true;
|
||||
} else {
|
||||
sctx.batch_idx = -1;
|
||||
}
|
||||
} break;
|
||||
|
||||
default:
|
||||
throw std::runtime_error("Unexpected state in handle_seq");
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
void gen_ctx::handle_seq_seqrep(seq_ctx & sctx) {
|
||||
if (sctx.n_generated > 0) {
|
||||
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_generated);
|
||||
if (rw_slot.ctx_sampling == nullptr) {
|
||||
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
|
||||
}
|
||||
llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
|
||||
sctx.rewind_state.set_logits_slot(ctx, sctx.n_generated, sctx.batch_idx);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
std::vector<llama_token> seq_last_tokens;
|
||||
seq_last_tokens.reserve(sctx.n_generated);
|
||||
concat_chunks(sctx.chunks, seq_last_tokens, prompt_size);
|
||||
|
||||
size_t rewind_distance =
|
||||
llama_seqrep_handle_rewind(
|
||||
ctx, sctx.rewind_state, seq_last_tokens, sctx.n_generated, prompt_tokens,
|
||||
sparams.seqrep_params, &sctx.high_water_mark, sctx.batch_idx);
|
||||
if (rewind_distance < 1) {
|
||||
return;
|
||||
}
|
||||
// if (sctx.seq_id != 0) printf("<%d:%zu>", sctx.seq_id + 1, rewind_distance);
|
||||
GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
|
||||
const size_t slot_idx = sctx.n_generated - rewind_distance;
|
||||
const llama_token nl_id = llama_token_nl(model);
|
||||
|
||||
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(slot_idx);
|
||||
llama_sampling_cp(rw_slot.ctx_sampling, sctx.ctx_sampling);
|
||||
|
||||
if (sctx.seq_id == focused_sequence) {
|
||||
console::set_display(console::error);
|
||||
fputs("\u3010", stdout);
|
||||
// printf("%zu,%zu,%zu", rewind_distance, sctx.n_generated, sctx.generated_tokens.size());
|
||||
for (size_t i = seq_last_tokens.size() - rewind_distance; i < seq_last_tokens.size(); i++) {
|
||||
if (seq_last_tokens[i] == nl_id) {
|
||||
fputs("\\n", stdout);
|
||||
continue;
|
||||
}
|
||||
const std::string token_str = llama_token_to_piece(ctx, seq_last_tokens[i]);
|
||||
// fputs("|", stdout);
|
||||
fputs(token_str.c_str(), stdout);
|
||||
}
|
||||
fputs("\u3011", stdout);
|
||||
console::set_display(console::reset);
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
sctx.n_remain += rewind_distance;
|
||||
sctx.n_generated -= rewind_distance;
|
||||
sctx.rewind_count++;
|
||||
sctx.rewind_tokens += rewind_distance;
|
||||
llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_generated + 1, -1);
|
||||
while (!sctx.chunks.empty() && rewind_distance > 0) {
|
||||
tokens_chunk & last_chunk = sctx.chunks.back();
|
||||
GGML_ASSERT(!last_chunk.is_input);
|
||||
|
||||
if (last_chunk.tokens.size() >= rewind_distance) {
|
||||
last_chunk.tokens.resize(last_chunk.tokens.size() - rewind_distance);
|
||||
rewind_distance = 0;
|
||||
break;
|
||||
}
|
||||
rewind_distance -= last_chunk.tokens.size();
|
||||
sctx.chunks.pop_back();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
bool gen_ctx::go() {
|
||||
if (ctxs_seq.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decode_count == 0) {
|
||||
scratch.reserve(n_ctx);
|
||||
scratch.resize(prompt_size);
|
||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), scratch.begin());
|
||||
// FIXME: Hacky.
|
||||
if (!feed_prompt(prompt_tokens)) {
|
||||
throw std::runtime_error("Prompt processing failed");
|
||||
}
|
||||
for (auto & sctx : ctxs_seq) {
|
||||
sctx.batch_idx = batch.n_tokens - 1;
|
||||
sctx.state = SEQ_GENERATING;
|
||||
if (sctx.seq_id == 0) {
|
||||
sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
|
||||
} else {
|
||||
llama_kv_cache_seq_cp(ctx, 0, sctx.seq_id, 0, prompt_size);
|
||||
}
|
||||
#ifndef LLAMA_NO_SEQREP_SAMPLER
|
||||
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(0);
|
||||
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
|
||||
llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
|
||||
sctx.rewind_state.set_logits_slot(ctx, 0, sctx.batch_idx);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_clear(batch);
|
||||
for (auto & sctx : ctxs_seq) {
|
||||
handle_seq(sctx);
|
||||
}
|
||||
if (batch.n_tokens == 0) return false;
|
||||
|
||||
decode_time_last = ggml_time_us();
|
||||
const int decode_result = llama_decode(ctx, batch);
|
||||
decode_time_last = std::max(int64_t(0), ggml_time_us() - decode_time_last);
|
||||
decode_time_total += decode_time_last;
|
||||
|
||||
if (decode_result != 0) {
|
||||
LOG_TEE("%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
|
||||
decode_result == 1 ? "couldn't find slot" : "unknown error");
|
||||
return false;
|
||||
}
|
||||
decode_count++;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool handle_commands(gen_ctx & gctx) {
|
||||
std::string line;
|
||||
line.reserve(1024);
|
||||
|
||||
|
||||
puts("");
|
||||
fflush(stdout);
|
||||
while (1) {
|
||||
printf("> ");
|
||||
console::readline(line, false);
|
||||
console::set_display(console::reset);
|
||||
while (!line.empty() && std::isspace(line.back())) {
|
||||
line.pop_back();
|
||||
}
|
||||
if (line.empty()) break;
|
||||
if (line.size() < 2 || line.front() != '/') {
|
||||
printf("\n* Bad command\n");
|
||||
continue;
|
||||
}
|
||||
size_t sep_idx = line.find(' ');
|
||||
std::string command, rest;
|
||||
if (sep_idx != std::string::npos) {
|
||||
command = line.substr(1, sep_idx - 1);
|
||||
rest = line.substr(sep_idx + 1);
|
||||
} else {
|
||||
command = line.substr(1);
|
||||
}
|
||||
|
||||
if (command == "quit") return false;
|
||||
|
||||
// Focus
|
||||
if (isdigit(command[0])) {
|
||||
const int target = std::atoi(command.c_str());
|
||||
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
|
||||
printf("\n* Focus: Bad seq id\n");
|
||||
} else {
|
||||
gctx.focused_sequence = llama_seq_id(target - 1);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (command == "kill") {
|
||||
const int target = std::atoi(rest.c_str());
|
||||
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
|
||||
printf("\n* Kill: Bad seq id\n");
|
||||
} else if (target - 1 == gctx.focused_sequence) {
|
||||
printf("\n* Kill: Can't kill focus\n");
|
||||
} else {
|
||||
gctx.ctxs_seq[target - 1].state = SEQ_DONE;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
printf("\n* Bad command\n");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gen_ctx gctx(argc, argv);
|
||||
|
||||
while (gctx.go() && !done) {
|
||||
bool need_dump = gctx.params.n_parallel > 1 && gctx.decode_count % SI_DUMP_SEQUENCES_INTERVAL == 0;
|
||||
if (interrupted) {
|
||||
if (!handle_commands(gctx)) break;
|
||||
if (done) break;
|
||||
interrupted = false;
|
||||
need_dump = true;
|
||||
}
|
||||
if (need_dump) {
|
||||
printf("\n-- Last decode[%zu]: %.3f, avg: %.3f",
|
||||
gctx.decode_count, double(gctx.decode_time_last) / 1000000,
|
||||
(double(gctx.decode_time_total) / 1000000) / double(gctx.decode_count));
|
||||
gctx.dump_batches((gctx.prompt_size > 20) ? (gctx.prompt_size - 10) : 0);
|
||||
}
|
||||
}
|
||||
gctx.focused_sequence = gctx.ctxs_seq.size() - 1;
|
||||
gctx.dump_batches();
|
||||
puts("");
|
||||
console::cleanup();
|
||||
|
||||
llama_print_timings(gctx.ctx);
|
||||
}
|
tests/test-sampling.cpp
@@ -1,10 +1,14 @@
 #include "ggml.h"
 #include "llama.h"
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+#include "common/seqrep-sampler.h"
+#endif
+
 #ifdef NDEBUG
 #undef NDEBUG
 #endif
 
 #include <cstring>
 #include <cmath>
 #include <numeric>
 #include <cassert>
@@ -128,6 +132,79 @@ static void test_repetition_penalties(
     }
 }
 
+// FIXME: This should probably just be moved to a separate test executable.
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+// NOTE: Compares expected_probs at id position, not sorted position like the other
+// test functions.
+static void test_seqrep_penalty(
+        const std::vector<float> & probs,
+        const std::vector<llama_token> & last_tokens,
+        const std::vector<float> & expected_probs,
+        const llama_sampler_seqrep_params * params) {
+    assert(probs.size() == expected_probs.size());
+
+    size_t n_vocab = probs.size();
+    std::vector<llama_token_data> candidates;
+    candidates.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+        float logit = log(probs[token_id]);
+        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+    llama_sample_seqrep_penalty(nullptr, &candidates_p, last_tokens, params);
+    llama_sample_softmax(nullptr, &candidates_p);
+    DUMP(&candidates_p);
+
+    assert(candidates_p.size == expected_probs.size());
+    for (size_t i = 0; i < candidates_p.size; i++) {
+        assert(fabs(candidates_p.data[i].p - expected_probs[candidates_p.data[i].id]) < 1e-3);
+    }
+}
+
+static void run_seqrep_tests(void) {
+    llama_sampler_seqrep_params params;
+
+    // Compatible with frequency/presence penalty
+    memset(&params, 0, sizeof(llama_sampler_seqrep_params));
+    params.last_n = 1024;
+    params.min_length = 1;
+    params.mid_word_scale = 1.0f;
+    params.presence_penalty = 5.0f;
+    params.length_penalty = 5.0f;
+    params.flags |= LLAMA_SEQREP_ABSOLUTE_PENALTY;
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.000011f, 0.249997f, 0.249997f, 0.249997f, 0.249997f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.000023f, 0.000023f, 0.000023f, 0.499966f, 0.499966f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.000000f, 0.000023f, 0.000023f, 0.499977f, 0.499977f}, &params);
+
+    // Compatible with repetition penalty
+    memset(&params, 0, sizeof(llama_sampler_seqrep_params));
+    params.last_n = 1024;
+    params.min_length = 1;
+    params.mid_word_scale = 1.0f;
+    params.presence_penalty = 50.0f;
+    params.length_penalty = 1.0f;
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0, 0.25f, 0.25f, 0.25f, 0.25f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0, 0, 0, 0.5f, 0.5f}, &params);
+    test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5f, 0.5f}, &params);
+
+    // Seqrep mode
+    // memset(&params, 0, sizeof(llama_sampler_seqrep_params));
+    // params.last_n = 1024;
+    // params.min_length = 3;
+    // params.mid_word_scale = 1.0f;
+    // params.presence_penalty = 50.0f;
+    // params.length_penalty = 1.0f;
+    // test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
+    // test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.20f, 0.20f, 0.20f, 0.20f, 0.20f}, &params);
+    // params.tolerance = 1.0f;
+    // test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
+}
+#endif
+
+
 int main(void) {
     ggml_time_init();
 
@@ -154,6 +231,10 @@ int main(void) {
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
     test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
+#ifndef LLAMA_NO_SEQREP_SAMPLER
+    run_seqrep_tests();
+#endif
+
     printf("OK\n");
 
     return 0;
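The expected vectors in run_seqrep_tests are consistent with the single-token case (min_length = 1) under LLAMA_SEQREP_ABSOLUTE_PENALTY subtracting presence_penalty + length_penalty * count from the logit of each previously seen token. That formula is inferred from the test values, not taken from the commit; a quick back-of-the-envelope check of the first expected vector:

    // Hypothetical verification of {0.000011f, 0.249997f, ...} from the first test case above.
    #include <cmath>
    #include <cstdio>

    int main() {
        // Five tokens start at p = 0.2 each; token 0 was seen once (count = 1).
        const double penalized = 0.2 * std::exp(-(5.0 + 5.0 * 1.0)); // presence 5 + length 5 * count 1
        const double norm      = penalized + 4.0 * 0.2;              // renormalize over all 5 tokens
        std::printf("%.6f %.6f\n", penalized / norm, 0.2 / norm);    // prints ~0.000011 ~0.249997
        return 0;
    }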