Use ring buffer to store prev in sampling
This commit is contained in:
parent
48607c7a77
commit
3b23ea74e2
4 changed files with 110 additions and 8 deletions
|
@ -40,7 +40,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
|
||||||
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
|
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
result->prev = ring_buffer<llama_token>(params.n_prev);
|
||||||
|
|
||||||
result->n_valid = 0;
|
result->n_valid = 0;
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
|
||||||
void llama_sampling_reset(llama_sampling_context * ctx) {
|
void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||||
llama_sampling_reset(ctx->smpl);
|
llama_sampling_reset(ctx->smpl);
|
||||||
|
|
||||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
ctx->prev.clear();
|
||||||
ctx->cur.clear();
|
ctx->cur.clear();
|
||||||
ctx->n_valid = 0;
|
ctx->n_valid = 0;
|
||||||
}
|
}
|
||||||
|
@ -384,7 +384,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
|
||||||
// apply penalties
|
// apply penalties
|
||||||
const auto & penalty_tokens = prev;
|
const auto & penalty_tokens = prev.to_vector();
|
||||||
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||||
if (penalty_tokens_used_size) {
|
if (penalty_tokens_used_size) {
|
||||||
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||||
|
@ -434,7 +434,9 @@ void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
llama_token id,
|
llama_token id,
|
||||||
bool apply_grammar) {
|
bool apply_grammar) {
|
||||||
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
if (!ctx_sampling->prev.empty()) {
|
||||||
|
ctx_sampling->prev.pop_front();
|
||||||
|
}
|
||||||
ctx_sampling->prev.push_back(id);
|
ctx_sampling->prev.push_back(id);
|
||||||
|
|
||||||
if (apply_grammar) {
|
if (apply_grammar) {
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
// sampler types
|
// sampler types
|
||||||
enum class llama_sampler_type : char {
|
enum class llama_sampler_type : char {
|
||||||
|
@ -58,6 +59,106 @@ typedef struct gpt_sampling_params {
|
||||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
} gpt_sampling_params;
|
} gpt_sampling_params;
|
||||||
|
|
||||||
|
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||||
|
template<typename T>
|
||||||
|
struct ring_buffer {
|
||||||
|
ring_buffer() {}
|
||||||
|
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
||||||
|
|
||||||
|
T & front() {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[first];
|
||||||
|
}
|
||||||
|
|
||||||
|
const T & front() const {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[first];
|
||||||
|
}
|
||||||
|
|
||||||
|
T & back() {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
const T & back() const {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
return data[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
void push_back(const T & value) {
|
||||||
|
if (sz == capacity) {
|
||||||
|
// advance the start when buffer is full
|
||||||
|
first = (first + 1) % capacity;
|
||||||
|
} else {
|
||||||
|
sz++;
|
||||||
|
}
|
||||||
|
data[pos] = value;
|
||||||
|
pos = (pos + 1) % capacity;
|
||||||
|
}
|
||||||
|
|
||||||
|
T pop_front() {
|
||||||
|
if (sz == 0) {
|
||||||
|
throw std::runtime_error("ring buffer is empty");
|
||||||
|
}
|
||||||
|
T value = data[first];
|
||||||
|
first = (first + 1) % capacity;
|
||||||
|
sz--;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
T & operator[](size_t i) {
|
||||||
|
if (i >= sz) {
|
||||||
|
throw std::runtime_error("ring buffer: index out of bounds");
|
||||||
|
}
|
||||||
|
return data[(first + i) % capacity];
|
||||||
|
}
|
||||||
|
|
||||||
|
const T & operator[](size_t i) const {
|
||||||
|
if (i >= sz) {
|
||||||
|
throw std::runtime_error("ring buffer: index out of bounds");
|
||||||
|
}
|
||||||
|
return data[(first + i) % capacity];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<T> to_vector() const {
|
||||||
|
std::vector<T> result;
|
||||||
|
result.reserve(sz);
|
||||||
|
for (size_t i = 0; i < sz; i++) {
|
||||||
|
result.push_back(data[(first + i) % capacity]);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() {
|
||||||
|
// here only reset the status of the buffer
|
||||||
|
sz = 0;
|
||||||
|
first = 0;
|
||||||
|
pos = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool empty() const {
|
||||||
|
return sz == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return sz;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t capacity = 0;
|
||||||
|
size_t sz = 0;
|
||||||
|
size_t first = 0;
|
||||||
|
size_t pos = 0;
|
||||||
|
std::vector<T> data;
|
||||||
|
};
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
// TODO: move to llama.h
|
// TODO: move to llama.h
|
||||||
struct llama_sampling_context {
|
struct llama_sampling_context {
|
||||||
|
@ -69,8 +170,7 @@ struct llama_sampling_context {
|
||||||
|
|
||||||
llama_sampling * smpl;
|
llama_sampling * smpl;
|
||||||
|
|
||||||
// TODO: replace with ring-buffer
|
ring_buffer<llama_token> prev;
|
||||||
std::vector<llama_token> prev;
|
|
||||||
std::vector<llama_token_data> cur;
|
std::vector<llama_token_data> cur;
|
||||||
|
|
||||||
size_t n_valid; // Number of correct top tokens with correct probabilities.
|
size_t n_valid; // Number of correct top tokens with correct probabilities.
|
||||||
|
|
|
@ -421,7 +421,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, id, true);
|
llama_sampling_accept(ctx_sampling, id, true);
|
||||||
|
|
||||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
|
||||||
|
|
||||||
embd.push_back(id);
|
embd.push_back(id);
|
||||||
|
|
||||||
|
|
|
@ -733,7 +733,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);
|
llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);
|
||||||
|
|
||||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
|
||||||
|
|
||||||
embd.push_back(id);
|
embd.push_back(id);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue