fix compile errors, rwkv not working
This commit is contained in:
parent
15576bc865
commit
2827920044
7 changed files with 13 additions and 19 deletions
|
@ -563,7 +563,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
rwkv_ctx_v3->logits_out = (float *)malloc(logitbufsiz);
|
rwkv_ctx_v3->logits_out = (float *)malloc(logitbufsiz);
|
||||||
rwkv_ctx_v3->state_in = nullptr;
|
rwkv_ctx_v3->state_in = nullptr;
|
||||||
|
|
||||||
bool testeval = rwkv_eval(rwkv_ctx_v3, 0, rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
bool testeval = rwkv_eval(rwkv_ctx_v3, params.n_threads, 0, rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||||
if (!testeval)
|
if (!testeval)
|
||||||
{
|
{
|
||||||
printf("\nError: RWKV Init Eval Failed!\n");
|
printf("\nError: RWKV Init Eval Failed!\n");
|
||||||
|
@ -1162,12 +1162,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
{
|
{
|
||||||
if(embd.size()>1)
|
if(embd.size()>1)
|
||||||
{
|
{
|
||||||
evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
evalres = rwkv_eval_sequence(rwkv_ctx_v3, params.n_threads, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
|
bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
|
||||||
evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
|
evalres = rwkv_eval(rwkv_ctx_v3, params.n_threads, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
|
||||||
|
|
|
@ -447,7 +447,6 @@ bool gpt2_eval(
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
struct ggml_cgraph gf = {};
|
struct ggml_cgraph gf = {};
|
||||||
gf.n_threads = n_threads;
|
|
||||||
|
|
||||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||||
|
@ -708,7 +707,7 @@ bool gpt2_eval(
|
||||||
|
|
||||||
// run the computation
|
// run the computation
|
||||||
ggml_build_forward_expand(&gf, inpL);
|
ggml_build_forward_expand(&gf, inpL);
|
||||||
ggml_graph_compute (ctx0, &gf);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
//if (n_past%100 == 0) {
|
//if (n_past%100 == 0) {
|
||||||
// ggml_graph_print (&gf);
|
// ggml_graph_print (&gf);
|
||||||
|
|
|
@ -445,7 +445,6 @@ bool gptj_eval(
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
struct ggml_cgraph gf = {};
|
struct ggml_cgraph gf = {};
|
||||||
gf.n_threads = n_threads;
|
|
||||||
|
|
||||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||||
|
@ -620,7 +619,7 @@ bool gptj_eval(
|
||||||
|
|
||||||
// run the computation
|
// run the computation
|
||||||
ggml_build_forward_expand(&gf, inpL);
|
ggml_build_forward_expand(&gf, inpL);
|
||||||
ggml_graph_compute (ctx0, &gf);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
//if (n_past%100 == 0) {
|
//if (n_past%100 == 0) {
|
||||||
// ggml_graph_print (&gf);
|
// ggml_graph_print (&gf);
|
||||||
|
|
|
@ -383,7 +383,6 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
struct ggml_cgraph gf = {};
|
struct ggml_cgraph gf = {};
|
||||||
gf.n_threads = n_threads;
|
|
||||||
|
|
||||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
|
memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
|
||||||
|
@ -543,7 +542,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
||||||
|
|
||||||
// run the computation
|
// run the computation
|
||||||
ggml_build_forward_expand(&gf, inpL);
|
ggml_build_forward_expand(&gf, inpL);
|
||||||
ggml_graph_compute(ctx0, &gf);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
// std::cout << "Qcur" << std::endl;
|
// std::cout << "Qcur" << std::endl;
|
||||||
// print_tensor(Qcur);
|
// print_tensor(Qcur);
|
||||||
|
|
|
@ -461,7 +461,6 @@ bool gpt_neox_eval(
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
struct ggml_cgraph gf = {};
|
struct ggml_cgraph gf = {};
|
||||||
gf.n_threads = n_threads;
|
|
||||||
|
|
||||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
|
||||||
|
@ -639,7 +638,7 @@ bool gpt_neox_eval(
|
||||||
|
|
||||||
// run the computation
|
// run the computation
|
||||||
ggml_build_forward_expand(&gf, inpL);
|
ggml_build_forward_expand(&gf, inpL);
|
||||||
ggml_graph_compute (ctx0, &gf);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
//if (n_past%100 == 0) {
|
//if (n_past%100 == 0) {
|
||||||
// ggml_graph_print (&gf);
|
// ggml_graph_print (&gf);
|
||||||
|
|
|
@ -1511,7 +1511,6 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr<struct rwkv_instance
|
||||||
serial_graph.tokens = ggml_new_i32(serial_graph.ctx.ctx, 0);
|
serial_graph.tokens = ggml_new_i32(serial_graph.ctx.ctx, 0);
|
||||||
serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
|
serial_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
|
||||||
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph");
|
RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, serial_graph.cgraph, "Failed to allocate serial graph");
|
||||||
serial_graph.cgraph->n_threads = n_threads;
|
|
||||||
|
|
||||||
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(
|
RWKV_ASSERT_NULL(RWKV_ERROR_GRAPH, rwkv_build_serial_graph(
|
||||||
serial_graph.ctx.ctx, instance->model,
|
serial_graph.ctx.ctx, instance->model,
|
||||||
|
@ -1609,7 +1608,7 @@ void rwkv_get_outputs(const struct rwkv_context * ctx, float * state_out, float
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
|
bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t token, const float * state_in, float * state_out, float * logits_out) {
|
||||||
ctx->last_error = RWKV_ERROR_NONE;
|
ctx->last_error = RWKV_ERROR_NONE;
|
||||||
|
|
||||||
const struct rwkv_file_header & header = ctx->instance->model.header;
|
const struct rwkv_file_header & header = ctx->instance->model.header;
|
||||||
|
@ -1628,13 +1627,13 @@ bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * st
|
||||||
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
|
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_graph_compute(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get());
|
ggml_graph_compute_with_ctx(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get(),n_threads);
|
||||||
rwkv_get_outputs(ctx, state_out, logits_out);
|
rwkv_get_outputs(ctx, state_out, logits_out);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, const size_t sequence_len, const float * state_in, float * state_out, float * logits_out) {
|
bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const uint32_t * sequence, const size_t sequence_len, const float * state_in, float * state_out, float * logits_out) {
|
||||||
ctx->last_error = RWKV_ERROR_NONE;
|
ctx->last_error = RWKV_ERROR_NONE;
|
||||||
|
|
||||||
const struct rwkv_file_header & header = ctx->instance->model.header;
|
const struct rwkv_file_header & header = ctx->instance->model.header;
|
||||||
|
@ -1690,7 +1689,6 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
|
||||||
sequence_graph.tokens = ggml_new_tensor_1d(sequence_graph.ctx.ctx, GGML_TYPE_I32, sequence_len);
|
sequence_graph.tokens = ggml_new_tensor_1d(sequence_graph.ctx.ctx, GGML_TYPE_I32, sequence_len);
|
||||||
sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
|
sequence_graph.cgraph.reset(new(std::nothrow) struct ggml_cgraph());
|
||||||
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph");
|
RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, sequence_graph.cgraph, "Failed to allocate sequence graph");
|
||||||
sequence_graph.cgraph->n_threads = 1;
|
|
||||||
|
|
||||||
RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(
|
RWKV_ASSERT_FALSE(RWKV_ERROR_GRAPH, rwkv_build_sequence_graph(
|
||||||
sequence_graph.ctx.ctx, ctx->instance->model,
|
sequence_graph.ctx.ctx, ctx->instance->model,
|
||||||
|
@ -1717,7 +1715,7 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * sequence, co
|
||||||
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
|
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_graph_compute(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get());
|
ggml_graph_compute_with_ctx(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get(),n_threads);
|
||||||
rwkv_get_outputs(ctx, state_out, logits_out);
|
rwkv_get_outputs(ctx, state_out, logits_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -111,7 +111,7 @@ extern "C" {
|
||||||
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
|
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
|
||||||
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
||||||
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
||||||
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
|
RWKV_API bool rwkv_eval(struct rwkv_context *, const int n_threads, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
|
||||||
|
|
||||||
// Evaluates the model for a sequence of tokens.
|
// Evaluates the model for a sequence of tokens.
|
||||||
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
|
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
|
||||||
|
@ -125,7 +125,7 @@ extern "C" {
|
||||||
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
|
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
|
||||||
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
|
||||||
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
|
||||||
RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
|
RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
|
||||||
|
|
||||||
// Returns the number of tokens in the given model's vocabulary.
|
// Returns the number of tokens in the given model's vocabulary.
|
||||||
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
|
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue