fixed rwkv, standardized new ctx usage

This commit is contained in:
Concedo 2023-07-10 20:05:53 +08:00
parent 2827920044
commit 523fc3be52
8 changed files with 27 additions and 8 deletions

4
.gitignore vendored
View file

@ -67,4 +67,6 @@ koboldcpp_failsafe.dll
koboldcpp_openblas.dll koboldcpp_openblas.dll
koboldcpp_openblas_noavx2.dll koboldcpp_openblas_noavx2.dll
koboldcpp_clblast.dll koboldcpp_clblast.dll
koboldcpp_cublas.dll koboldcpp_cublas.dll
cublas64_11.dll
cublasLt64_11.dll

View file

@ -707,7 +707,7 @@ bool gpt2_eval(
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); kcpp_graph_compute_helper(&gf, n_threads);
//if (n_past%100 == 0) { //if (n_past%100 == 0) {
// ggml_graph_print (&gf); // ggml_graph_print (&gf);

View file

@ -619,7 +619,7 @@ bool gptj_eval(
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); kcpp_graph_compute_helper(&gf, n_threads);
//if (n_past%100 == 0) { //if (n_past%100 == 0) {
// ggml_graph_print (&gf); // ggml_graph_print (&gf);

View file

@ -542,7 +542,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); kcpp_graph_compute_helper(&gf, n_threads);
// std::cout << "Qcur" << std::endl; // std::cout << "Qcur" << std::endl;
// print_tensor(Qcur); // print_tensor(Qcur);

View file

@ -638,7 +638,7 @@ bool gpt_neox_eval(
// run the computation // run the computation
ggml_build_forward_expand(&gf, inpL); ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); kcpp_graph_compute_helper(&gf, n_threads);
//if (n_past%100 == 0) { //if (n_past%100 == 0) {
// ggml_graph_print (&gf); // ggml_graph_print (&gf);

View file

@ -13,6 +13,8 @@
#include "ggml-opencl.h" #include "ggml-opencl.h"
#endif #endif
#include "utils.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include <cstring> #include <cstring>
@ -729,6 +731,7 @@ struct rwkv_context {
float * logits_out = 0; //stores address of output logit buffer float * logits_out = 0; //stores address of output logit buffer
size_t gpu_layers; size_t gpu_layers;
std::vector<uint8_t> work_buffer;
}; };
// https://stackoverflow.com/a/6458689 // https://stackoverflow.com/a/6458689
@ -1627,7 +1630,7 @@ bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t to
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs; ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
} }
ggml_graph_compute_with_ctx(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get(),n_threads); kcpp_graph_compute_helper(ctx->serial_graph.cgraph.get(),n_threads);
rwkv_get_outputs(ctx, state_out, logits_out); rwkv_get_outputs(ctx, state_out, logits_out);
return true; return true;
@ -1715,7 +1718,7 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const ui
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs; ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
} }
ggml_graph_compute_with_ctx(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get(),n_threads); kcpp_graph_compute_helper(ctx->sequence_graph.cgraph.get(),n_threads);
rwkv_get_outputs(ctx, state_out, logits_out); rwkv_get_outputs(ctx, state_out, logits_out);
} }

View file

@ -221,4 +221,16 @@ bool should_transpose_layer(std::string name)
return true; return true;
} }
return false; return false;
}
static std::vector<uint8_t> kcpp_compute_buf;
void kcpp_graph_compute_helper(ggml_cgraph *graph, int n_threads)
{
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
if (plan.work_size > 0)
{
kcpp_compute_buf.resize(plan.work_size);
plan.work_data = kcpp_compute_buf.data();
}
ggml_graph_compute(graph, &plan);
} }

View file

@ -54,4 +54,6 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
bool should_transpose_layer(std::string name); bool should_transpose_layer(std::string name);
void kcpp_graph_compute_helper(ggml_cgraph * graph, int n_threads);