fixed rwkv, standardized new ctx usage
parent 2827920044
commit 523fc3be52
8 changed files with 27 additions and 8 deletions
.gitignore (vendored) | 4
@@ -67,4 +67,6 @@ koboldcpp_failsafe.dll
 koboldcpp_openblas.dll
 koboldcpp_openblas_noavx2.dll
 koboldcpp_clblast.dll
 koboldcpp_cublas.dll
+cublas64_11.dll
+cublasLt64_11.dll
@@ -707,7 +707,7 @@ bool gpt2_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    kcpp_graph_compute_helper(&gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print (&gf);
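(The same one-line swap is repeated in the gptj_eval, mpt_eval, and gpt_neox_eval hunks below. The reason: newer ggml revisions run a graph through ggml_graph_compute with a caller-supplied ggml_cplan whose work buffer the caller must allocate, while the previous call, ggml_graph_compute_with_ctx, allocates that scratch space inside the ggml_context it is given (ctx0 here). The snippet below is a minimal sketch, not code from this commit, of the plan-based call that the new kcpp_graph_compute_helper wraps; it assumes a ggml build that already exposes ggml_graph_plan/ggml_cplan, and the local buffer name is a placeholder.)

#include <cstdint>
#include <vector>
#include "ggml.h"

// Illustrative stand-in for what the new helper does for each caller.
static void compute_with_plan(struct ggml_cgraph * gf, int n_threads) {
    // ggml reports how much scratch memory this graph needs at this thread count.
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    std::vector<uint8_t> work; // placeholder: the caller now owns the work buffer
    if (plan.work_size > 0) {
        work.resize(plan.work_size);
        plan.work_data = work.data();
    }

    ggml_graph_compute(gf, &plan);
}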
@@ -619,7 +619,7 @@ bool gptj_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    kcpp_graph_compute_helper(&gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print (&gf);
@@ -542,7 +542,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    kcpp_graph_compute_helper(&gf, n_threads);
 
     // std::cout << "Qcur" << std::endl;
     // print_tensor(Qcur);
@@ -638,7 +638,7 @@ bool gpt_neox_eval(
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    kcpp_graph_compute_helper(&gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print (&gf);
@@ -13,6 +13,8 @@
 #include "ggml-opencl.h"
 #endif
 
+#include "utils.h"
+
 #include <string>
 #include <vector>
 #include <cstring>
@@ -729,6 +731,7 @@ struct rwkv_context {
    float * logits_out = 0; //stores address of output logit buffer
 
    size_t gpu_layers;
+   std::vector<uint8_t> work_buffer;
 };
 
 // https://stackoverflow.com/a/6458689
@@ -1627,7 +1630,7 @@ bool rwkv_eval(struct rwkv_context * ctx, const int n_threads, const uint32_t to
        ctx->serial_graph.cgraph->n_leafs = ctx->serial_graph.post_logits_leafs;
    }
 
-   ggml_graph_compute_with_ctx(ctx->serial_graph.ctx.ctx, ctx->serial_graph.cgraph.get(),n_threads);
+   kcpp_graph_compute_helper(ctx->serial_graph.cgraph.get(),n_threads);
    rwkv_get_outputs(ctx, state_out, logits_out);
 
    return true;
@@ -1715,7 +1718,7 @@ bool rwkv_eval_sequence(struct rwkv_context * ctx, const int n_threads, const ui
        ctx->sequence_graph.cgraph->n_leafs = ctx->sequence_graph.post_logits_leafs;
    }
 
-   ggml_graph_compute_with_ctx(ctx->sequence_graph.ctx.ctx, ctx->sequence_graph.cgraph.get(),n_threads);
+   kcpp_graph_compute_helper(ctx->sequence_graph.cgraph.get(),n_threads);
    rwkv_get_outputs(ctx, state_out, logits_out);
 }
 
@@ -221,4 +221,16 @@ bool should_transpose_layer(std::string name)
        return true;
     }
     return false;
+}
+
+static std::vector<uint8_t> kcpp_compute_buf;
+void kcpp_graph_compute_helper(ggml_cgraph *graph, int n_threads)
+{
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+    if (plan.work_size > 0)
+    {
+        kcpp_compute_buf.resize(plan.work_size);
+        plan.work_data = kcpp_compute_buf.data();
+    }
+    ggml_graph_compute(graph, &plan);
 }
@@ -54,4 +54,6 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
 
 
 
 bool should_transpose_layer(std::string name);
+
+void kcpp_graph_compute_helper(ggml_cgraph * graph, int n_threads);
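(For reference, a minimal end-to-end usage sketch of the new helper, in the same style as the *_eval hunks above. It is an illustration, not code from this commit: the function name, context size, and tensor shapes are placeholders, and it assumes the ggml C API already used by this codebase — ggml_init, ggml_new_tensor_1d, ggml_add, ggml_build_forward_expand, ggml_free.)

#include "ggml.h"
#include "utils.h" // declares kcpp_graph_compute_helper, added by this commit

static void tiny_graph_demo() {
    // Small dedicated context; 16 MiB is an arbitrary placeholder size.
    struct ggml_init_params params = { 16u * 1024u * 1024u, NULL, false };
    struct ggml_context * ctx0 = ggml_init(params);

    struct ggml_tensor * a   = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 4);
    struct ggml_tensor * b   = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 4);
    struct ggml_tensor * out = ggml_add(ctx0, a, b);

    struct ggml_cgraph gf = {};
    ggml_build_forward_expand(&gf, out);

    // Previously this would have been: ggml_graph_compute_with_ctx(ctx0, &gf, 4);
    kcpp_graph_compute_helper(&gf, /*n_threads=*/4);

    ggml_free(ctx0);
}

Because kcpp_graph_compute_helper reuses one file-static buffer (kcpp_compute_buf) for every call, the work memory is no longer tied to any single model's ggml_context; the visible trade-off in its definition is that concurrent calls from multiple threads would race on that shared buffer.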