try out the new rwkv but it seems worse, may revert
This commit is contained in:
parent
632bf27b65
commit
e1a7042943
4 changed files with 825 additions and 370 deletions
|
@ -431,6 +431,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
else //rwkv_2
|
else //rwkv_2
|
||||||
{
|
{
|
||||||
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
|
||||||
|
|
||||||
|
if(inputs.gpulayers>0)
|
||||||
|
{
|
||||||
|
rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
|
||||||
|
}
|
||||||
|
|
||||||
const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
|
const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
|
||||||
const size_t n_vocab = header.n_vocab;
|
const size_t n_vocab = header.n_vocab;
|
||||||
printf("\nDetected Vocab: %d",n_vocab);
|
printf("\nDetected Vocab: %d",n_vocab);
|
||||||
|
@ -811,7 +817,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
{
|
{
|
||||||
params.top_k = 120; //to disable top_k we actually need to increase this value to a very high number
|
params.top_k = 120; //to disable top_k we actually need to increase this value to a very high number
|
||||||
}
|
}
|
||||||
if (params.seed <= 0)
|
if (params.seed <= 0 || params.seed==0xFFFFFFFF)
|
||||||
{
|
{
|
||||||
params.seed = time(NULL);
|
params.seed = time(NULL);
|
||||||
}
|
}
|
||||||
|
@ -1060,14 +1066,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if(embd.size()>1)
|
// if(embd.size()>1)
|
||||||
{
|
// {
|
||||||
evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
// evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||||
}
|
// }
|
||||||
else
|
// else
|
||||||
{
|
// {
|
||||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
|
||||||
}
|
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
|
||||||
|
//}
|
||||||
|
|
||||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||||
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
|
||||||
|
|
|
@ -2204,7 +2204,7 @@ struct llama_v2_context * llama_v2_init_from_file(
|
||||||
|
|
||||||
llama_v2_context * ctx = new llama_v2_context;
|
llama_v2_context * ctx = new llama_v2_context;
|
||||||
|
|
||||||
if (params.seed < 0) {
|
if (params.seed < 0 || params.seed==0xFFFFFFFF) {
|
||||||
params.seed = time(NULL);
|
params.seed = time(NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2552,7 +2552,7 @@ int llama_v2_get_kv_cache_token_count(const struct llama_v2_context * ctx) {
|
||||||
#define LLAMA_V2_MAX_RNG_STATE (64*1024)
|
#define LLAMA_V2_MAX_RNG_STATE (64*1024)
|
||||||
|
|
||||||
void llama_v2_set_rng_seed(struct llama_v2_context * ctx, int seed) {
|
void llama_v2_set_rng_seed(struct llama_v2_context * ctx, int seed) {
|
||||||
if (seed < 0) {
|
if (seed < 0 || seed==0xFFFFFFFF) {
|
||||||
seed = time(NULL);
|
seed = time(NULL);
|
||||||
}
|
}
|
||||||
ctx->rng.seed(seed);
|
ctx->rng.seed(seed);
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -84,7 +84,7 @@ extern "C" {
|
||||||
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
|
RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx);
|
||||||
|
|
||||||
// Loads the model from a file and prepares it for inference.
|
// Loads the model from a file and prepares it for inference.
|
||||||
// Returns NULL on any error. Error messages would be printed to stderr.
|
// Returns NULL on any error.
|
||||||
// - model_file_path: path to model file in ggml format.
|
// - model_file_path: path to model file in ggml format.
|
||||||
// - n_threads: count of threads to use, must be positive.
|
// - n_threads: count of threads to use, must be positive.
|
||||||
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
|
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads);
|
||||||
|
@ -97,39 +97,64 @@ extern "C" {
|
||||||
// - n_threads: count of threads to use, must be positive.
|
// - n_threads: count of threads to use, must be positive.
|
||||||
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
|
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);
|
||||||
|
|
||||||
// Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
|
// Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS.
|
||||||
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
|
// Returns true if at least one layer was offloaded.
|
||||||
RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);
|
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op and always returns false.
|
||||||
|
RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers);
|
||||||
|
|
||||||
// Evaluates the model for a single token.
|
// Evaluates the model for a single token.
|
||||||
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||||
// Returns false on any error. Error messages would be printed to stderr.
|
// Returns false on any error.
|
||||||
|
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
|
||||||
|
// that you do not calculate logits.
|
||||||
// - token: next token index, in range 0 <= token < n_vocab.
|
// - token: next token index, in range 0 <= token < n_vocab.
|
||||||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
|
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
|
||||||
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
|
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
||||||
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
|
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
||||||
RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
|
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
|
||||||
|
|
||||||
// Evaluates the model for a sequence of tokens.
|
// Evaluates the model for a sequence of tokens.
|
||||||
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
|
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
|
||||||
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
|
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
|
||||||
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed. (Useful for initialization.)
|
|
||||||
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
|
||||||
// Returns false on any error. Error messages would be printed to stderr.
|
// Returns false on any error.
|
||||||
|
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10ms per iteration
|
||||||
|
// that you do not calculate logits.
|
||||||
|
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
|
||||||
// - sequence_len: number of tokens to read from the array.
|
// - sequence_len: number of tokens to read from the array.
|
||||||
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count, or NULL if this is a first pass.
|
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
|
||||||
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to if non-NULL.
|
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
||||||
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to if non-NULL.
|
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
||||||
RWKV_API bool rwkv_eval_sequence(const struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
|
RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
|
||||||
|
|
||||||
// Returns count of FP32 elements in state buffer.
|
// Returns the number of tokens in the given model's vocabulary.
|
||||||
RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx);
|
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
|
||||||
|
RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);
|
||||||
|
|
||||||
// Returns count of FP32 elements in logits buffer.
|
// Returns the number of elements in the given model's embedding.
|
||||||
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);
|
// Useful for reading individual fields of a model's hidden state.
|
||||||
|
RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx);
|
||||||
|
|
||||||
|
// Returns the number of layers in the given model.
|
||||||
|
// Useful for always offloading the entire model to GPU.
|
||||||
|
RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx);
|
||||||
|
|
||||||
|
// Returns the number of float elements in a complete state for the given model.
|
||||||
|
// This is the number of elements you'll need to allocate for a call to rwkv_eval, rwkv_eval_sequence, or rwkv_init_state.
|
||||||
|
RWKV_API size_t rwkv_get_state_len(const struct rwkv_context * ctx);
|
||||||
|
|
||||||
|
// Returns the number of float elements in the logits output of a given model.
|
||||||
|
// This is currently always identical to n_vocab.
|
||||||
|
RWKV_API size_t rwkv_get_logits_len(const struct rwkv_context * ctx);
|
||||||
|
|
||||||
|
// Initializes the given state so that passing it to rwkv_eval or rwkv_eval_sequence would be identical to passing NULL.
|
||||||
|
// Useful in cases where tracking the first call to these functions may be annoying or expensive.
|
||||||
|
// State must be initialized for behavior to be defined, passing a zeroed state to rwkv.cpp functions will result in NaNs.
|
||||||
|
// - state: FP32 buffer of size rwkv_get_state_len() to initialize
|
||||||
|
RWKV_API void rwkv_init_state(const struct rwkv_context * ctx, float * state);
|
||||||
|
|
||||||
// Frees all allocated memory and the context.
|
// Frees all allocated memory and the context.
|
||||||
// Does not need to be the same thread that created the rwkv_context.
|
// Does not need to be called on the same thread that created the rwkv_context.
|
||||||
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
RWKV_API void rwkv_free(struct rwkv_context * ctx);
|
||||||
|
|
||||||
// Quantizes FP32 or FP16 model to one of quantized formats.
|
// Quantizes FP32 or FP16 model to one of quantized formats.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue