Implementation of a sequence repetition penalty

Author: KerfuffleV2
Date:   2023-08-12 14:30:45 -06:00
Parent: 875fb42871
Commit: 11fa3dfd69

14 changed files with 2192 additions and 3 deletions


@@ -96,6 +96,8 @@ option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging"
option(LLAMA_MPI "llama: use MPI" OFF)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_SEQREP_SAMPLER "llama: build with support for seqrep sampler" ON)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ON)


@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill benchmark-matmult parallel finetune export-lora simple-inference tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = \
@@ -559,6 +559,14 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
train.o: common/train.cpp common/train.h
$(CXX) $(CXXFLAGS) -c $< -o $@
ifndef LLAMA_NO_SEQREP_SAMPLER
COMMON_H_DEFS += common/seqrep-sampler.h
COMMON_DEPS += seqrep-sampler.o
seqrep-sampler.o: common/seqrep-sampler.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@
endif
libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
@@ -581,6 +589,9 @@ infill: examples/infill/infill.cpp ggml.o llama.o $(C
simple: examples/simple/simple.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
simple-inference: examples/simple-inference/simple-inference.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
batched: examples/batched/batched.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)


@@ -111,6 +111,8 @@ pub fn build(b: *std.build.Builder) !void {
var make = try Maker.init(b);
make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
try make.addFlag("-DLLAMA_NO_SEQREP_SAMPLER");
const ggml = make.obj("ggml", "ggml.c");
const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");


@@ -54,6 +54,12 @@ add_library(${TARGET} STATIC
train.cpp
)
if (LLAMA_SEQREP_SAMPLER)
target_sources(${TARGET} PRIVATE seqrep-sampler.h seqrep-sampler.cpp)
else()
add_compile_definitions(LLAMA_NO_SEQREP_SAMPLER)
endif()
if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()


@@ -1,6 +1,10 @@
#include "common.h"
#include "llama.h"
#ifndef LLAMA_NO_SEQREP_SAMPLER
#include "seqrep-sampler.h"
#endif
#include <algorithm>
#include <cassert>
#include <cmath>
@@ -335,6 +339,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.penalty_present = std::stof(argv[i]);
#ifndef LLAMA_NO_SEQREP_SAMPLER
} else if (arg == "-seqrep" || arg == "--seqrep-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
if (std::strcmp(argv[i], "help") == 0) {
seqrep_sampler_help();
exit(0);
}
llama_sampler_seqrep_params sr_params;
seqrep_sampler_params_init(&sr_params);
if (!seqrep_sampler_params_parse(argv[i], &sr_params)) {
seqrep_sampler_help();
exit(1);
}
sparams.seqrep_params.push_back(sr_params);
#endif
} else if (arg == "--mirostat") { } else if (arg == "--mirostat") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -764,6 +786,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
#ifndef LLAMA_NO_SEQREP_SAMPLER
printf(" -seqrep CFG, --seqrep-penalty CFG\n");
printf(" add a copy of the sequence repetition penalty sampler. may be specified multiple times. for help: -seqrep help\n");
#endif
printf(" --mirostat N use Mirostat sampling.\n"); printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);


@@ -103,7 +103,8 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx,
const std::vector<llama_token> & all_last_tokens) {
const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -155,6 +156,13 @@ llama_token llama_sampling_sample(
prev.data() + prev.size() - penalty_last_n,
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
#ifndef LLAMA_NO_SEQREP_SAMPLER
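// Rewind-mode seqrep samplers don't adjust logits here; they act later by rewinding
// already-generated tokens (see llama_seqrep_handle_rewind), so they are skipped.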
for (auto & sr_params : params.seqrep_params) {
if ((sr_params.flags & LLAMA_SEQREP_REWIND_MODE) != 0) continue;
llama_sample_seqrep_penalty(ctx_main, &cur_p, all_last_tokens, &sr_params);
}
#endif
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {


@@ -4,6 +4,10 @@
#include "grammar-parser.h"
#ifndef LLAMA_NO_SEQREP_SAMPLER
#include "seqrep-sampler.h"
#endif
#include <string>
#include <vector>
#include <unordered_map>
@@ -35,6 +39,11 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // how strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
#ifndef LLAMA_NO_SEQREP_SAMPLER
std::vector<llama_sampler_seqrep_params> seqrep_params;
#endif
} llama_sampling_params;
// general sampler context
@@ -101,7 +110,8 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0,
const std::vector<llama_token> & all_last_tokens = {});
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,

common/seqrep-sampler.cpp (new file, 1014 lines)
File diff suppressed because it is too large.

common/seqrep-sampler.h (new file, 196 lines)

@ -0,0 +1,196 @@
#pragma once
#include <stddef.h>
#include <vector>
#include <regex>
#include "llama.h"
enum llama_sampler_seqrep_flags {
// Tolerance charges can't be used consecutively.
LLAMA_SEQREP_TOLERANCE_NO_CONSECUTIVE = (1 << 0),
// Tolerance charges can't be used before the first actual match.
LLAMA_SEQREP_TOLERANCE_NO_FIRST = (1 << 1),
// When applying the length penalty, use the length of the longest observed
// sequence matching the token rather than the total length of
// sequences matching the token. In other words, if we find a sequence
// of length 3 and a sequence of length 4 continued by token 69 then
// with this flag on we penalize based on length 4, with it off we
// penalize based on length 7 (3 + 4).
LLAMA_SEQREP_PENALIZE_LENGTH_MAX_SEEN = (1 << 2),
// Apply an absolute penalty rather than dividing the logit by the penalty.
LLAMA_SEQREP_ABSOLUTE_PENALTY = (1 << 3),
// Rewind to cut off the head of sequences rather than the end.
// Ignored when min_length < 2.
// Since it wouldn't make sense to rewind and then let sampling pick
// the same token again, penalty values and mid_word_scale have no
// effect.
LLAMA_SEQREP_REWIND_MODE = (1 << 4),
// When rewinding, skip past whitespace and punctuation. For example,
// if the matched sequence was "<NL>'hello" then we will rewind to the
// token starting with 'h' and ban it.
LLAMA_SEQREP_REWIND_SKIP_WS_PUNCT = (1 << 5),
// Rewind to the shortest matching sequence of at least min_length rather than the longest.
LLAMA_SEQREP_REWIND_USE_SHORTEST_MATCH = (1 << 6),
// Rewinding requires a word boundary. Only has an effect when rewind_seek_word_boundary isn't 0.
LLAMA_SEQREP_REWIND_REQUIRE_WBOUND = (1 << 7),
// Persisted bans are only applied if at a word bound.
LLAMA_SEQREP_REWIND_PERSIST_REQUIRE_WBOUND = (1 << 8),
};
typedef struct llama_sampler_seqrep_params {
// The minimum length of a matching sequence of tokens. When this is < 2 then
// the sampler works in single token mode and tolerance is ignored.
size_t min_length;
// Maximum length for a matching sequence of tokens.
size_t max_length;
// Starting offset for matching against the end of the sequence. This can be used
// to only match against sequences in the initial prompt, for example. Matching
// starts at the offset and moves toward the beginning of the list.
// An offset of 0 means the penultimate token when min_length > 1, otherwise the last token.
size_t start_offset;
// Window of last tokens to consider, starting from the end. < 0 means
// the whole list.
int last_n;
// Flags based on llama_sampler_seqrep_flags enum values ORed together.
int flags;
// Tolerance for non-matching tokens in a sequence.
float tolerance;
// Flat penalty applied to the token that can continue a repeated sequence.
float presence_penalty;
// Scaling penalty applied to the token that can continue a repeated sequence.
// The penalty is multiplied by the total length of sequences that are continued by this token unless
// the PENALIZE_LENGTH_MAX_SEEN flag is set.
float length_penalty;
// Scale applied to the penalty for tokens from repeated sequences that aren't at (or don't form) a word boundary.
float mid_word_scale;
// Tolerance credit per real match, i.e. 0.5 means +1 tolerance per 2 matched tokens.
float tolerance_match_credit;
// Caps tolerance at the specified value. Only meaningful when tolerance_match_credit > 0
float tolerance_cap;
// Ensure the sequence is at least the specified length in rewind mode after
// whitespace skipping and other modifications.
size_t rewind_min_length;
// When rewinding, try to find a word boundary within the specified distance, starting with tokens earlier than the rewind point.
size_t rewind_seek_word_boundary;
// A position is limited to the specified number of rewinds. When the limit is exceeded, future rewinds cannot target it or earlier tokens.
size_t rewind_max_visits;
// Tokens banned by rewind remain banned for an additional number of positions equal to the value. i.e. setting this to 1 would mean the token is banned for 2 positions.
size_t rewind_persist_bans;
// Number of tokens from the sequence to ban when rewinding.
size_t rewind_ban_length;
std::vector<std::wregex> include_re;
std::vector<std::wregex> exclude_re;
} llama_sampler_seqrep_params;
enum seqrep_check_word_flags {
SEQREP_CW_START_IS_WBOUND = 1 << 0,
SEQREP_CW_END_IS_WBOUND = 1 << 1,
SEQREP_CW_ALL_WS_PUNCT = 1 << 2,
SEQREP_CW_START_IS_INVALID = 1 << 3, // Start of token is invalid/incomplete UTF8
SEQREP_CW_END_IS_INVALID = 1 << 4 // End of token is invalid/incomplete UTF8
};
struct seqrep_logit_info {
const int n_vocab;
std::vector<llama_token_data> token_data;
seqrep_logit_info(llama_context * ctx, const size_t k, const int32_t ith);
const std::vector<llama_token_data> & get_token_data(void);
llama_token_data get_token_id(const llama_token token_id) const;
void rebuild(llama_context *ctx, const size_t k, const int32_t ith);
void populate_logits(float * logits);
// Yoinked from beam search code.
// Return top k token_data by logit.
std::vector<llama_token_data> top_k(const float * const logits, const size_t k);
};
struct seqrep_rewind_slot {
size_t count;
std::vector<llama_token> tokens;
struct llama_sampling_context * ctx_sampling = nullptr;
};
struct seqrep_rewind_state {
const size_t n_vocab;
const size_t n_ctx;
const size_t k;
std::vector<seqrep_logit_info> logit_slots;
std::vector<seqrep_rewind_slot> rewind_slots;
seqrep_rewind_state(
const size_t n_vocab,
const size_t n_ctx,
const size_t k = 2000);
struct seqrep_rewind_slot & get_rewind_slot(const size_t idx);
void set_logits_slot(llama_context * ctx, const size_t idx, const int32_t ith = 0);
void populate_logits(llama_context * ctx, const size_t idx, const int32_t ith = 0);
};
// Sequence repetition penalty with semi-fuzzy matching. Note: Handles the last_n window itself.
size_t llama_sample_seqrep_penalty(
struct llama_context * ctx,
llama_token_data_array * candidates,
const std::vector<llama_token> & last_tokens,
const llama_sampler_seqrep_params * params);
int llama_seqrep_check_word(
const struct llama_context * ctx,
const llama_token token,
std::vector<char> & buf);
size_t llama_seqrep_handle_rewind(
struct llama_context * ctx,
struct seqrep_rewind_state & rewind_state,
const std::vector<llama_token> & generated_tokens,
const size_t n_generated,
const std::vector<llama_token> & prompt_tokens,
const std::vector<llama_sampler_seqrep_params> & params_list,
size_t * high_water_mark,
const int32_t ith = 0);
void seqrep_sampler_help();
void seqrep_sampler_params_init(llama_sampler_seqrep_params * params);
void seqrep_sampler_params_dump(const llama_sampler_seqrep_params * params);
bool seqrep_sampler_params_parse(char * s, llama_sampler_seqrep_params * params);
struct llama_sampler_seqrep_params llama_seqrep_merge_params(
const std::vector<llama_sampler_seqrep_params> & params_list,
const int and_flags,
const int not_flags);
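To make the relationship between these declarations concrete, here is a hedged sketch of using the penalty sampler on its own; only the seqrep types and functions above are taken from this header, while the helper name, the chosen field values, and the surrounding ctx/candidates/last_tokens objects are illustrative assumptions:

#include <vector>
#include "llama.h"
#include "seqrep-sampler.h"

// Hypothetical helper: apply one seqrep penalty pass to an existing candidate list.
static void apply_seqrep_example(
        llama_context * ctx,
        llama_token_data_array * candidates,
        const std::vector<llama_token> & last_tokens) {
    llama_sampler_seqrep_params sr_params;
    seqrep_sampler_params_init(&sr_params);   // start from the built-in defaults
    sr_params.min_length     = 3;             // illustrative: only penalize repeats of 3+ tokens
    sr_params.last_n         = 256;           // illustrative: match against the last 256 tokens
    sr_params.length_penalty = 1.2f;          // illustrative: scale the penalty by matched length
    seqrep_sampler_params_dump(&sr_params);   // print the effective configuration
    llama_sample_seqrep_penalty(ctx, candidates, last_tokens, &sr_params);
}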


@@ -30,6 +30,7 @@ else()
add_subdirectory(quantize-stats)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(simple-inference)
add_subdirectory(speculative)
add_subdirectory(train-text-from-scratch)
if (LLAMA_METAL)


@@ -415,6 +415,11 @@ int main(int argc, char ** argv) {
}
}
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
#ifndef LLAMA_NO_SEQREP_SAMPLER
for (auto & sr_params : sparams.seqrep_params) {
seqrep_sampler_params_dump(&sr_params);
}
#endif
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n"); LOG_TEE("\n\n");


@ -0,0 +1,8 @@
set(TARGET simple-inference)
add_executable(${TARGET} simple-inference.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()


@ -0,0 +1,819 @@
// Defines sigaction on msys:
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include "common.h"
#include "console.h"
#include "llama.h"
#include "build-info.h"
#include "grammar-parser.h"
#include <algorithm>
#include <atomic>
#include <limits>
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#define SI_DUMP_SEQUENCES_INTERVAL 40
static std::atomic<bool> interrupted {false};
static std::atomic<bool> done {false};
typedef struct tokens_chunk {
bool is_input;
size_t consumed;
std::vector<llama_token> tokens;
tokens_chunk(const bool is_input = false, const size_t consumed = 0, const std::vector<llama_token> & tokens = {})
: is_input(is_input)
, consumed(consumed)
, tokens(tokens)
{}
} tokens_chunk;
enum seq_state {
SEQ_GENERATING,
SEQ_SHARE_PROMPT,
SEQ_INPUT,
SEQ_DONE,
};
typedef struct seq_ctx {
llama_seq_id seq_id;
int32_t batch_idx;
enum seq_state state;
size_t n_remain;
size_t n_generated;
llama_sampling_context *ctx_sampling;
llama_token last_sampled;
std::vector<tokens_chunk> chunks;
// std::vector<llama_token> pending_input;
// std::vector<llama_token> output;
#ifndef LLAMA_NO_SEQREP_SAMPLER
size_t high_water_mark;
struct seqrep_rewind_state rewind_state;
size_t rewind_count;
size_t rewind_tokens;
#endif
} seq_ctx;
typedef struct gen_ctx {
llama_context * ctx = nullptr;
llama_model * model = nullptr;
llama_sampling_context * ctx_sampling = nullptr;
llama_batch batch;
gpt_params params;
llama_sampling_params & sparams = params.sparams;
int n_ctx;
int n_vocab;
std::vector<llama_token> scratch;
std::vector<llama_token> prompt_tokens;
size_t prompt_size = 0;
llama_seq_id focused_sequence = 0;
size_t decode_count = 0;
int64_t decode_time_total = 0, decode_time_last = 0;
std::vector<seq_ctx> ctxs_seq;
private:
bool init_params(const int argc, char ** argv);
bool init_model();
bool init_prompt();
bool init_handlers();
bool init_sampling();
bool init_batches();
public:
gen_ctx(const int argc, char ** argv);
~gen_ctx();
void dump_batches(const size_t prompt_start = 0);
void dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset = 0);
void dump_batch(const size_t seq);
void handle_seq(seq_ctx & sctx);
#ifndef LLAMA_NO_SEQREP_SAMPLER
void handle_seq_seqrep(seq_ctx & sctx);
#endif
bool feed_prompt(
const std::vector<llama_token> & tokens,
llama_pos pos = 0,
llama_seq_id seq = 0);
bool go();
} gen_ctx;
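// Append to `dst` the tokens stored in `chunks`, skipping the first `start_offset` tokens overall.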
static void concat_chunks(const std::vector<tokens_chunk> & chunks, std::vector<llama_token> & dst, const size_t start_offset) {
size_t offset = 0;
for (const tokens_chunk & chunk : chunks) {
if (offset + chunk.tokens.size() <= start_offset) {
offset += chunk.tokens.size();
continue;
}
const size_t chunk_offset = start_offset - offset;
const size_t chunk_size = chunk.tokens.size() - chunk_offset;
const llama_token * tp = chunk.tokens.data() + chunk_offset;
for (size_t i = 0; i < chunk_size; i++, tp++) {
dst.push_back(*tp);
}
offset += chunk_size;
}
}
static void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model,
const std::vector<llama_token> input_tokens, const std::string output, const std::vector<llama_token> output_tokens) {
if (params.logdir.empty()) {
return;
}
const std::string timestamp = get_sortable_timestamp();
const bool success = create_directory_with_parents(params.logdir);
if (!success) {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
return;
}
const std::string logfile_path = params.logdir + timestamp + ".yml";
FILE * logfile = fopen(logfile_path.c_str(), "w");
if (logfile == NULL) {
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
return;
}
fprintf(logfile, "binary: simple-inference\n");
char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc);
fprintf(logfile, "\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "# Generation Results #\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "\n");
dump_string_yaml_multiline(logfile, "output", output.c_str());
dump_vector_int_yaml(logfile, "output_tokens", output_tokens);
llama_dump_timing_info_yaml(logfile, ctx);
fclose(logfile);
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
if (interrupted) {
done.store(true);
} else {
interrupted.store(true);
}
}
}
#endif
static bool check_unsupported(const gpt_params * params) {
std::string nope;
const llama_sampling_params & sparams = params->sparams;
if (params->embedding)
nope = "embedding";
else if (!sparams.grammar.empty())
nope = "grammar"; // Currently broken most likely
else if (sparams.cfg_scale != 1.0f)
nope = "cfg_scale";
else if (!sparams.cfg_negative_prompt.empty())
nope = "cfg_negative_prompt";
else if (!params->path_prompt_cache.empty())
nope = "prompt cache";
else if (params->escape)
nope = "prompt escaping";
else if (params->interactive || params->interactive_first || params->instruct)
nope = "interactive mode";
else if (!params->input_prefix.empty() || !params->input_suffix.empty() || params->input_prefix_bos)
nope = "input prefix or suffix";
else if (params->hellaswag)
nope = "hellaswag";
else if (params->n_keep != 0)
nope = "keep";
else if (!params->antiprompt.empty())
nope = "reverse prompt";
if (!nope.empty()) {
LOG_TEE("%s: error: We don't support %s here.\n", __func__, nope.c_str());
return false;
}
return true;
}
bool gen_ctx::init_params(const int argc, char ** argv) {
if (gpt_params_parse(argc, argv, params) == false) {
return false;
}
if (!check_unsupported(&params)) {
return false;
}
if (params.rope_freq_base != 10000.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
}
if (params.rope_freq_scale != 1.0) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
}
if (params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
}
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
std::mt19937 rng(params.seed);
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}
return true;
}
bool gen_ctx::init_model() {
LOG("%s: llama backend init\n", __func__);
llama_backend_init(params.numa);
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
return false;
}
// print system information
{
LOG_TEE("\n");
LOG_TEE("system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
n_ctx = llama_n_ctx(ctx);
n_vocab = llama_n_vocab(llama_get_model(ctx));
return true;
}
bool gen_ctx::init_prompt() {
const bool add_bos = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
LOG("add_bos: %d\n", add_bos);
if (!params.prompt.empty()) {
LOG("tokenize the prompt\n");
prompt_tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
}
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, prompt_tokens).c_str());
// Should not run without any tokens
if (prompt_tokens.empty()) {
prompt_tokens.push_back(llama_token_bos(model));
LOG("input was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, prompt_tokens).c_str());
}
LOG("n_ctx: %d\n", n_ctx);
if ((int) prompt_tokens.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) prompt_tokens.size(), n_ctx - 4);
return false;
}
prompt_size = prompt_tokens.size();
if (params.verbose_prompt) {
LOG_TEE("\n");
LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, prompt_tokens.size());
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
LOG_TEE("%6d -> '%s'\n", prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
LOG_TEE("\n");
}
return true;
}
bool gen_ctx::init_handlers() {
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
return true;
}
bool gen_ctx::init_sampling() {
LOG_TEE("sampling: %s\n", llama_sampling_print(sparams).c_str());
#ifndef LLAMA_NO_SEQREP_SAMPLER
for (auto & sr_params : sparams.seqrep_params) {
seqrep_sampler_params_dump(&sr_params);
}
#endif
ctx_sampling = llama_sampling_init(sparams);
return true;
}
bool gen_ctx::init_batches() {
batch = llama_batch_init(std::max(int32_t(prompt_size), params.n_batch), 0, 1);
int n_remain = params.n_predict;
if (n_remain < 0 || n_remain + int(prompt_size) > n_ctx) {
n_remain = n_ctx - prompt_size;
}
ctxs_seq.reserve(params.n_parallel);
for (int32_t i = 0; i < params.n_parallel; i++) {
seq_ctx && bs = {
llama_seq_id(i),
-1,
i == 0 ? SEQ_INPUT : SEQ_SHARE_PROMPT,
size_t(n_remain),
0,
llama_sampling_init(params.sparams),
-1,
{},
#ifndef LLAMA_NO_SEQREP_SAMPLER
prompt_size + 1,
seqrep_rewind_state(n_vocab, n_ctx, 2000),
0,
0,
#endif
};
GGML_ASSERT(prompt_size > 0);
bs.chunks.emplace_back(true, 0, prompt_tokens);
if (i > 0) {
bs.chunks.emplace_back(false, 0, std::vector<llama_token>());
}
#ifndef LLAMA_NO_SEQREP_SAMPLER
seqrep_rewind_slot & rw_slot = bs.rewind_state.get_rewind_slot(0);
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
// llama_sampling_cp(bs.ctx_sampling, rw_slot.ctx_sampling);
// bs.rewind_state.set_logits_slot(ctx, 0, (prompt_size - 1) % params.n_batch);
#endif
ctxs_seq.push_back(bs);
// if (i > 0) llama_kv_cache_seq_cp(ctx, 0, i, 0, prompt_size);
}
return true;
}
gen_ctx::gen_ctx(const int argc, char ** argv) {
bool result = true;
result = result && init_params(argc, argv);
result = result && init_model();
result = result && init_prompt();
result = result && init_handlers();
result = result && init_sampling();
result = result && init_batches();
if (!result) {
throw std::runtime_error("Initialization failed");
}
}
gen_ctx::~gen_ctx() {
for (auto & sctx : ctxs_seq) {
llama_sampling_free(sctx.ctx_sampling);
#ifndef LLAMA_NO_SEQREP_SAMPLER
for (auto & rs : sctx.rewind_state.rewind_slots) {
if (rs.ctx_sampling != nullptr) {
llama_sampling_free(rs.ctx_sampling);
rs.ctx_sampling = nullptr;
}
}
#endif
}
llama_sampling_free(ctx_sampling);
llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
}
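// Feed the prompt to sequence `seq` in n_batch-sized pieces, echoing the tokens as they are
// evaluated; logits are only requested for the final token of the last piece.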
bool gen_ctx::feed_prompt(const std::vector<llama_token> & tokens, llama_pos pos, llama_seq_id seq) {
int32_t tokens_remain = tokens.size();
const llama_token * tokens_curr = tokens.data();
console::set_display(console::prompt);
while (tokens_remain > 0 && !interrupted) {
const int32_t chunk_size = std::min(int32_t(tokens_remain), params.n_batch);
llama_batch_clear(batch);
for (int i = 0; i < chunk_size; i++) {
llama_batch_add(batch, tokens_curr[i], pos + i, {seq}, false);
}
pos += batch.n_tokens;
tokens_remain -= batch.n_tokens;
batch.logits[batch.n_tokens - 1] = tokens_remain < 1;
if (llama_decode(ctx, batch) != 0) {
console::set_display(console::reset);
LOG_TEE("%s : failed to eval\n", __func__);
return false;
}
decode_count++;
// display text
for (int i = 0; i < batch.n_tokens; i++) {
const std::string token_str = llama_token_to_piece(ctx, tokens_curr[i]);
fputs(token_str.c_str(), stdout);
}
fflush(stdout);
tokens_curr += batch.n_tokens;
}
console::set_display(console::reset);
return true;
}
void gen_ctx::dump_chunks(const std::vector<tokens_chunk> & chunks, const size_t start_offset) {
size_t offset = 0;
bool prompt_mode = false;
console::set_display(console::reset);
for (const tokens_chunk & chunk : chunks) {
if (offset + chunk.tokens.size() < start_offset) {
offset += chunk.tokens.size();
continue;
}
const size_t chunk_offset = start_offset - offset;
const size_t chunk_size = chunk.tokens.size() - chunk_offset;
const llama_token * tp = chunk.tokens.data() + chunk_offset;
if (chunk.is_input != prompt_mode) {
prompt_mode = chunk.is_input;
console::set_display(prompt_mode ? console::prompt : console::reset);
}
for (size_t i = 0; i < chunk_size; i++, tp++) {
const std::string token_str = llama_token_to_piece(ctx, *tp);
fputs(token_str.c_str(), stdout);
}
}
if (prompt_mode) {
console::set_display(console::reset);
}
fflush(stdout);
}
void gen_ctx::dump_batches(const size_t prompt_start) {
bool first = true;
for (seq_ctx & sctx : ctxs_seq) {
if (sctx.seq_id == focused_sequence) continue;
printf("\n\n%s Result #%d (size: %zu",
!first ? "====================" : "####################",
sctx.seq_id + 1, prompt_size + sctx.n_generated);
#ifndef LLAMA_NO_SEQREP_SAMPLER
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
#endif
puts("):");
dump_chunks(sctx.chunks, prompt_start);
first = false;
}
seq_ctx & sctx = ctxs_seq[focused_sequence];
printf("\n\n%s Result #%d (size: %zu",
!first ? "====================" : "####################",
sctx.seq_id + 1, prompt_size + sctx.n_generated);
#ifndef LLAMA_NO_SEQREP_SAMPLER
printf(", rewind cnt/toks: %zu/%zu", sctx.rewind_count, sctx.rewind_tokens);
#endif
puts("):");
dump_chunks(sctx.chunks, prompt_start);
}
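// Advance a single sequence by one step: generating sequences sample their next token and queue
// it in the shared batch, input sequences queue the next slice of pending input, and done or
// prompt-sharing sequences are left alone.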
void gen_ctx::handle_seq(seq_ctx & sctx) {
switch (sctx.state) {
case SEQ_DONE:
case SEQ_SHARE_PROMPT: break;
case SEQ_GENERATING: {
GGML_ASSERT(sctx.batch_idx >= 0);
scratch.resize(prompt_size);
concat_chunks(sctx.chunks, scratch, prompt_size);
#ifndef LLAMA_NO_SEQREP_SAMPLER
handle_seq_seqrep(sctx);
#endif
sctx.last_sampled = llama_sampling_sample(ctx_sampling, ctx, NULL, sctx.batch_idx, scratch);
llama_sampling_accept(sctx.ctx_sampling, ctx, sctx.last_sampled, true);
if (sctx.seq_id == focused_sequence) {
const std::string token_str = llama_token_to_piece(ctx, sctx.last_sampled);
fputs(token_str.c_str(), stdout);
fflush(stdout);
}
sctx.n_generated++;
sctx.n_remain--;
if (sctx.chunks.empty() || sctx.chunks.back().is_input) {
sctx.chunks.emplace_back(0, false, std::vector<llama_token>());
}
sctx.chunks.back().tokens.push_back(sctx.last_sampled);
if (sctx.last_sampled == llama_token_eos(model) || sctx.n_remain == 0) {
sctx.state = SEQ_DONE;
sctx.batch_idx = -1;
// LOG_TEE(" [end of text]\n");
// break;
} else {
sctx.batch_idx = batch.n_tokens;
llama_batch_add(batch, sctx.last_sampled, prompt_size + sctx.n_generated, {sctx.seq_id}, true);
}
} break;
case SEQ_INPUT: {
sctx.last_sampled = -1;
GGML_ASSERT(!sctx.chunks.empty());
tokens_chunk & chunk = sctx.chunks.back();
GGML_ASSERT(chunk.is_input);
const size_t remain = chunk.tokens.size() - chunk.consumed;
const size_t to_consume = std::min(size_t(params.n_batch), remain);
for (size_t i = chunk.consumed; i < chunk.consumed + to_consume; ++i) {
llama_batch_add(batch, chunk.tokens[i], llama_pos(i), {sctx.seq_id}, false);
}
chunk.consumed += to_consume;
if (chunk.consumed == chunk.tokens.size()) {
sctx.batch_idx = batch.n_tokens - 1;
batch.logits[sctx.batch_idx] = true;
} else {
sctx.batch_idx = -1;
}
} break;
default:
throw std::runtime_error("Unexpected state in handle_seq");
}
}
#ifndef LLAMA_NO_SEQREP_SAMPLER
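// Snapshot the sampling state and logits for the current position, then let
// llama_seqrep_handle_rewind decide whether to rewind. On a rewind, restore the saved
// sampling state, trim the rewound tokens from the chunks and the KV cache, and (for the
// focused sequence) echo the removed text between 【 】 markers.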
void gen_ctx::handle_seq_seqrep(seq_ctx & sctx) {
if (sctx.n_generated > 0) {
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(sctx.n_generated);
if (rw_slot.ctx_sampling == nullptr) {
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
}
llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
sctx.rewind_state.set_logits_slot(ctx, sctx.n_generated, sctx.batch_idx);
} else {
return;
}
std::vector<llama_token> seq_last_tokens;
seq_last_tokens.reserve(sctx.n_generated);
concat_chunks(sctx.chunks, seq_last_tokens, prompt_size);
size_t rewind_distance =
llama_seqrep_handle_rewind(
ctx, sctx.rewind_state, seq_last_tokens, sctx.n_generated, prompt_tokens,
sparams.seqrep_params, &sctx.high_water_mark, sctx.batch_idx);
if (rewind_distance < 1) {
return;
}
// if (sctx.seq_id != 0) printf("<%d:%zu>", sctx.seq_id + 1, rewind_distance);
GGML_ASSERT(rewind_distance <= sctx.n_generated && "Rewind index out of bounds somehow?");
const size_t slot_idx = sctx.n_generated - rewind_distance;
const llama_token nl_id = llama_token_nl(model);
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(slot_idx);
llama_sampling_cp(rw_slot.ctx_sampling, sctx.ctx_sampling);
if (sctx.seq_id == focused_sequence) {
console::set_display(console::error);
fputs("\u3010", stdout);
// printf("%zu,%zu,%zu", rewind_distance, sctx.n_generated, sctx.generated_tokens.size());
for (size_t i = seq_last_tokens.size() - rewind_distance; i < seq_last_tokens.size(); i++) {
if (seq_last_tokens[i] == nl_id) {
fputs("\\n", stdout);
continue;
}
const std::string token_str = llama_token_to_piece(ctx, seq_last_tokens[i]);
// fputs("|", stdout);
fputs(token_str.c_str(), stdout);
}
fputs("\u3011", stdout);
console::set_display(console::reset);
fflush(stdout);
}
sctx.n_remain += rewind_distance;
sctx.n_generated -= rewind_distance;
sctx.rewind_count++;
sctx.rewind_tokens += rewind_distance;
llama_kv_cache_seq_rm(ctx, sctx.seq_id, prompt_size + sctx.n_generated + 1, -1);
while (!sctx.chunks.empty() && rewind_distance > 0) {
tokens_chunk & last_chunk = sctx.chunks.back();
GGML_ASSERT(!last_chunk.is_input);
if (last_chunk.tokens.size() >= rewind_distance) {
last_chunk.tokens.resize(last_chunk.tokens.size() - rewind_distance);
rewind_distance = 0;
break;
}
rewind_distance -= last_chunk.tokens.size();
sctx.chunks.pop_back();
}
}
#endif
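// Run one decode step. The first call feeds the shared prompt and forks the parallel sequences
// off sequence 0; every call then collects one batch entry per live sequence and decodes it.
// Returns false when there is nothing left to decode or decoding fails.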
bool gen_ctx::go() {
if (ctxs_seq.empty()) {
return false;
}
if (decode_count == 0) {
scratch.reserve(n_ctx);
scratch.resize(prompt_size);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), scratch.begin());
// FIXME: Hacky.
if (!feed_prompt(prompt_tokens)) {
throw std::runtime_error("Prompt processing failed");
}
for (auto & sctx : ctxs_seq) {
sctx.batch_idx = batch.n_tokens - 1;
sctx.state = SEQ_GENERATING;
if (sctx.seq_id == 0) {
sctx.chunks.emplace_back(false, 0, std::vector<llama_token>());
} else {
llama_kv_cache_seq_cp(ctx, 0, sctx.seq_id, 0, prompt_size);
}
#ifndef LLAMA_NO_SEQREP_SAMPLER
seqrep_rewind_slot & rw_slot = sctx.rewind_state.get_rewind_slot(0);
rw_slot.ctx_sampling = llama_sampling_init(params.sparams);
llama_sampling_cp(sctx.ctx_sampling, rw_slot.ctx_sampling);
sctx.rewind_state.set_logits_slot(ctx, 0, sctx.batch_idx);
#endif
}
}
llama_batch_clear(batch);
for (auto & sctx : ctxs_seq) {
handle_seq(sctx);
}
if (batch.n_tokens == 0) return false;
decode_time_last = ggml_time_us();
const int decode_result = llama_decode(ctx, batch);
decode_time_last = std::max(int64_t(0), ggml_time_us() - decode_time_last);
decode_time_total += decode_time_last;
if (decode_result != 0) {
LOG_TEE("%s : failed to eval batch of size %d: %s\n", __func__, batch.n_tokens,
decode_result == 1 ? "couldn't find slot" : "unknown error");
return false;
}
decode_count++;
return true;
}
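// Minimal command loop entered via Ctrl+C: `/quit` exits, `/N` focuses sequence N, and
// `/kill N` marks sequence N as done. Returns false when the program should stop.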
static bool handle_commands(gen_ctx & gctx) {
std::string line;
line.reserve(1024);
puts("");
fflush(stdout);
while (1) {
printf("> ");
console::readline(line, false);
console::set_display(console::reset);
while (!line.empty() && std::isspace(line.back())) {
line.pop_back();
}
if (line.empty()) break;
if (line.size() < 2 || line.front() != '/') {
printf("\n* Bad command\n");
continue;
}
size_t sep_idx = line.find(' ');
std::string command, rest;
if (sep_idx != std::string::npos) {
command = line.substr(1, sep_idx - 1);
rest = line.substr(sep_idx + 1);
} else {
command = line.substr(1);
}
if (command == "quit") return false;
// Focus
if (isdigit(command[0])) {
const int target = std::atoi(command.c_str());
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
printf("\n* Focus: Bad seq id\n");
} else {
gctx.focused_sequence = llama_seq_id(target - 1);
}
continue;
}
if (command == "kill") {
const int target = std::atoi(rest.c_str());
if (target < 1 || size_t(target) > gctx.ctxs_seq.size()) {
printf("\n* Kill: Bad seq id\n");
} else if (target - 1 == gctx.focused_sequence) {
printf("\n* Kill: Can't kill focus\n");
} else {
gctx.ctxs_seq[target - 1].state = SEQ_DONE;
}
continue;
}
printf("\n* Bad command\n");
}
return true;
}
int main(int argc, char ** argv) {
gen_ctx gctx(argc, argv);
while (gctx.go() && !done) {
bool need_dump = gctx.params.n_parallel > 1 && gctx.decode_count % SI_DUMP_SEQUENCES_INTERVAL == 0;
if (interrupted) {
if (!handle_commands(gctx)) break;
if (done) break;
interrupted = false;
need_dump = true;
}
if (need_dump) {
printf("\n-- Last decode[%zu]: %.3f, avg: %.3f",
gctx.decode_count, double(gctx.decode_time_last) / 1000000,
(double(gctx.decode_time_total) / 1000000) / double(gctx.decode_count));
gctx.dump_batches((gctx.prompt_size > 20) ? (gctx.prompt_size - 10) : 0);
}
}
gctx.focused_sequence = gctx.ctxs_seq.size() - 1;
gctx.dump_batches();
puts("");
console::cleanup();
llama_print_timings(gctx.ctx);
}


@@ -1,10 +1,14 @@
#include "ggml.h"
#include "llama.h"
#ifndef LLAMA_NO_SEQREP_SAMPLER
#include "common/seqrep-sampler.h"
#endif
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <cstring>
#include <cmath>
#include <numeric>
#include <cassert>
@@ -128,6 +132,79 @@ static void test_repetition_penalties(
}
}
// FIXME: This should probably just be moved to a separate test executable.
#ifndef LLAMA_NO_SEQREP_SAMPLER
// NOTE: Compares expected_probs at id position, not sorted position like the other
// test functions.
static void test_seqrep_penalty(
const std::vector<float> & probs,
const std::vector<llama_token> & last_tokens,
const std::vector<float> & expected_probs,
const llama_sampler_seqrep_params * params) {
assert(probs.size() == expected_probs.size());
size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
float logit = log(probs[token_id]);
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax(nullptr, &candidates_p);
DUMP(&candidates_p);
llama_sample_seqrep_penalty(nullptr, &candidates_p, last_tokens, params);
llama_sample_softmax(nullptr, &candidates_p);
DUMP(&candidates_p);
assert(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) {
assert(fabs(candidates_p.data[i].p - expected_probs[candidates_p.data[i].id]) < 1e-3);
}
}
static void run_seqrep_tests(void) {
llama_sampler_seqrep_params params;
// Compatible with frequency/presence penalty
memset(&params, 0, sizeof(llama_sampler_seqrep_params));
params.last_n = 1024;
params.min_length = 1;
params.mid_word_scale = 1.0f;
params.presence_penalty = 5.0f;
params.length_penalty = 5.0f;
params.flags |= LLAMA_SEQREP_ABSOLUTE_PENALTY;
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.000011f, 0.249997f, 0.249997f, 0.249997f, 0.249997f}, &params);
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.000023f, 0.000023f, 0.000023f, 0.499966f, 0.499966f}, &params);
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.000000f, 0.000023f, 0.000023f, 0.499977f, 0.499977f}, &params);
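// Rough sanity check of the expected values above (inferred from these cases, not from a spec):
// in absolute-penalty single-token mode the penalty appears to be
// presence_penalty + length_penalty * (number of matches), subtracted from the logit. For
// token 0 seen once: log(0.2) - (5 + 5*1) ~= -11.61, which softmaxes to ~0.000011 against the
// four untouched log(0.2) logits.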
// Compatible with repetition penalty
memset(&params, 0, sizeof(llama_sampler_seqrep_params));
params.last_n = 1024;
params.min_length = 1;
params.mid_word_scale = 1.0f;
params.presence_penalty = 50.0f;
params.length_penalty = 1.0f;
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0, 0.25f, 0.25f, 0.25f, 0.25f}, &params);
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0, 0, 0, 0.5f, 0.5f}, &params);
test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0, 0, 0, 0.5f, 0.5f}, &params);
// Seqrep mode
// memset(&params, 0, sizeof(llama_sampler_seqrep_params));
// params.last_n = 1024;
// params.min_length = 3;
// params.mid_word_scale = 1.0f;
// params.presence_penalty = 50.0f;
// params.length_penalty = 1.0f;
// test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
// test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.20f, 0.20f, 0.20f, 0.20f, 0.20f}, &params);
// params.tolerance = 1.0f;
// test_seqrep_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 2, 2, 3, 0, 1, 2}, {0.25f, 0.25f, 0.25f, 0, 0.25f}, &params);
}
#endif
int main(void) {
ggml_time_init();
@@ -154,6 +231,10 @@ int main(void) {
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
#ifndef LLAMA_NO_SEQREP_SAMPLER
run_seqrep_tests();
#endif
printf("OK\n"); printf("OK\n");
return 0; return 0;