fixed rwkv, standardized new ctx usage
This commit is contained in:
parent
2827920044
commit
523fc3be52
8 changed files with 27 additions and 8 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -68,3 +68,5 @@ koboldcpp_openblas.dll
|
|||
koboldcpp_openblas_noavx2.dll
|
||||
koboldcpp_clblast.dll
|
||||
koboldcpp_cublas.dll
|
||||
cublas64_11.dll
|
||||
cublasLt64_11.dll
|
|
@ -707,7 +707,7 @@ bool gpt2_eval(
|
|||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
kcpp_graph_compute_helper(&gf, n_threads);
|
||||
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_print (&gf);
|
||||
|
|
|
@ -619,7 +619,7 @@ bool gptj_eval(
|
|||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
kcpp_graph_compute_helper(&gf, n_threads);
|
||||
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_print (&gf);
|
||||
|
|
|
@ -542,7 +542,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
|||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
kcpp_graph_compute_helper(&gf, n_threads);
|
||||
|
||||
// std::cout << "Qcur" << std::endl;
|
||||
// print_tensor(Qcur);
|
||||
|
|
|
@ -638,7 +638,7 @@ bool gpt_neox_eval(
|
|||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, inpL);
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
kcpp_graph_compute_helper(&gf, n_threads);
|
||||
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_print (&gf);
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
#include "ggml-opencl.h"
|
||||
#endif
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
@ -729,6 +731,7 @@ struct rwkv_context {
|
|||
float * logits_out = 0; //stores address of output logit buffer
|
||||
|
||||
size_t gpu_layers;
|
||||
std::vector<uint8_t> work_buffer;
|
||||
};
|
||||
|
||||
// https://stackoverflow.com/a/6458689
|
||||
|
@ -1627,7 +1630,7 @@ bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t to
|
|||
ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
|
||||
}
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get(),n_threads);
|
||||
kcpp_graph_compute_helper(ctx->serial_graph.cgraph.get(),n_threads);
|
||||
rwkv_get_outputs(ctx, state_out, logits_out);
|
||||
|
||||
return true;
|
||||
|
@ -1715,7 +1718,7 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const ui
|
|||
ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
|
||||
}
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get(),n_threads);
|
||||
kcpp_graph_compute_helper(ctx->sequence_graph.cgraph.get(),n_threads);
|
||||
rwkv_get_outputs(ctx, state_out, logits_out);
|
||||
}
|
||||
|
||||
|
|
|
@ -222,3 +222,15 @@ bool should_transpose_layer(std::string name)
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::vector<uint8_t> kcpp_compute_buf;
|
||||
void kcpp_graph_compute_helper(ggml_cgraph *graph, int n_threads)
|
||||
{
|
||||
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
||||
if (plan.work_size > 0)
|
||||
{
|
||||
kcpp_compute_buf.resize(plan.work_size);
|
||||
plan.work_data = kcpp_compute_buf.data();
|
||||
}
|
||||
ggml_graph_compute(graph, &plan);
|
||||
}
|
|
@ -55,3 +55,5 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
|
|||
|
||||
|
||||
bool should_transpose_layer(std::string name);
|
||||
|
||||
void kcpp_graph_compute_helper(ggml_cgraph * graph, int n_threads);
|
Loading…
Add table
Add a link
Reference in a new issue