rwkv compile fix (+1 squashed commits)
Squashed commits: [8b0ebb1] upgraded rwkv + added memory overheads + added state_out bufs
This commit is contained in:
parent
120851df53
commit
860fb026df
2 changed files with 1279 additions and 660 deletions
File diff suppressed because it is too large
Load diff
|
@ -61,6 +61,10 @@ extern "C" {
|
|||
RWKV_ERROR_PARAM_MISSING = 14
|
||||
};
|
||||
|
||||
// RWKV context that can be used for inference.
|
||||
// All functions that operate on rwkv_context are thread-safe.
|
||||
// rwkv_context can be sent to different threads between calls to rwkv_eval.
|
||||
// There is no requirement for rwkv_context to be freed on the creating thread.
|
||||
struct rwkv_context;
|
||||
|
||||
// Sets whether errors are automatically printed to stderr.
|
||||
|
@ -85,14 +89,39 @@ extern "C" {
|
|||
// - n_threads: count of threads to use, must be positive.
|
||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
|
||||
|
||||
// Creates a new context from an existing one.
|
||||
// This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times.
|
||||
// Each rwkv_context can have one eval running at a time.
|
||||
// Every rwkv_context must be freed using rwkv_free.
|
||||
// - ctx: context to be cloned.
|
||||
// - n_threads: count of threads to use, must be positive.
|
||||
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
|
||||
|
||||
// Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
|
||||
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
|
||||
RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);
|
||||
|
||||
// Evaluates the model for a single token.
|
||||
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||
// Returns false on any error. Error messages would be printed to stderr.
|
||||
// - token: next token index, in range 0 <= token < n_vocab.
|
||||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
||||
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
|
||||
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
|
||||
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
|
||||
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
|
||||
RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
|
||||
|
||||
// Evaluates the model for a sequence of tokens.
|
||||
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
|
||||
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
|
||||
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed. (Useful for initialization.)
|
||||
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||
// Returns false on any error. Error messages would be printed to stderr.
|
||||
// - sequence_len: number of tokens to read from the array.
|
||||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count, or NULL if this is a first pass.
|
||||
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
|
||||
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
|
||||
RWKV_API bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
|
||||
|
||||
// Returns count of FP32 elements in state buffer.
|
||||
RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx);
|
||||
|
||||
|
@ -100,6 +129,7 @@ extern "C" {
|
|||
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);
|
||||
|
||||
// Frees all allocated memory and the context.
|
||||
// Does not need to be the same thread that created the rwkv_context.
|
||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||
|
||||
// Quantizes FP32 or FP16 model to one of quantized formats.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue