speculative : refactor sampling

This commit is contained in:
Georgi Gerganov 2023-10-15 22:30:59 +03:00
parent 32a67cbd16
commit 4a7f43f28c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 210 additions and 205 deletions

View file

@ -2,6 +2,8 @@
#include "llama.h"
#include "grammar-parser.h"
#include <string>
#include <vector>
#include <unordered_map>
@ -35,7 +37,8 @@ typedef struct llama_sampling_params {
} llama_sampling_params;
// general sampler context
typedef struct llama_sampling_context {
// TODO: move to llama.h
struct llama_sampling_context {
// parameters that will be used for sampling
llama_sampling_params params;
@ -43,45 +46,50 @@ typedef struct llama_sampling_context {
float mirostat_mu;
llama_grammar * grammar;
} llama_sampling_context;
// internal
grammar_parser::parse_state parsed_grammar;
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
};
#include "common.h"
// Create a new sampling context instance.
// - params:  gpt_params, passed by const reference
// - grammar: optional grammar pointer, defaults to NULL
// Returns the new llama_sampling_context by value.
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar = NULL);
// - params: gpt_params, passed by const reference
// Returns a pointer to a llama_sampling_context.
// NOTE(review): presumably the returned context must be released with llama_sampling_free - confirm against the implementation
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
// Reset the sampler context for the supplied sequence id (defaults to 0).
// This is necessary to reuse a sequence id or free memory used by sequences
// that are no longer required.
// - ctx_sampling: sampler context, passed by non-const reference
// - seq:          sequence id, defaults to 0
// NOTE(review): the meaning of the bool return value is not visible from this declaration - confirm against the implementation
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
// NOTE(review): presumably releases a context obtained from llama_sampling_init - confirm against the implementation
// - ctx: pointer to the sampling context
void llama_sampling_free(struct llama_sampling_context * ctx);
// Reset the sampler context
// - clear prev tokens
// - reset grammar
//
// - ctx: pointer to the sampling context to reset
void llama_sampling_reset(llama_sampling_context * ctx);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
// llama_sampling_context_reset when a sequence ends
// llama_sampling_reset when a sequence ends
//
// required:
// - ctx: context to use for sampling
// - ctx_main: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - idx: sample from llama_get_logits_ith(ctx, idx)
// - seq: sequence id to associate sampler state with
// - ctx_guidance: context to use for guidance
// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
// NOTE(review): the parameter lines below appear to hold two alternative lists for this one
// declaration - one taking ctx_sampling by reference together with last_tokens and candidates,
// and one taking ctx_sampling by pointer together with ctx_main and ctx_guidance.
// Confirm which list is current before relying on this signature.
// Both alternatives default idx to 0.
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx = 0);
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_guidance,
int idx = 0);
// NOTE(review): presumably records the token `id` as accepted in the sampling context
// (e.g. updating its prev tokens and grammar state) - confirm against the implementation
//
// - ctx_sampling: pointer to the sampling context
// - ctx_main:     pointer to a llama_context
// - id:           the token passed to the sampling context
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id);