From a0aae8d671f3aa704dd2faf7dcd1e0a689ad73c7 Mon Sep 17 00:00:00 2001
From: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Date: Wed, 17 Apr 2024 14:59:18 +0200
Subject: [PATCH] Add (broken) placeholder graph builder for RWKV

---
 src/llama.cpp | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/llama.cpp b/src/llama.cpp
index 195abba77..ce2f87ef9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -14718,6 +14718,22 @@ struct llm_build_context {
 
         return gf;
     }
+
+    ggml_cgraph * build_rwkv() {
+        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
+        ggml_tensor *input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // Dummy operation, just to copy, we're not doing anything with it right now
+        ggml_tensor *output = ggml_scale(ctx0, input_embeddings, 1.0);
+
+        // Mark the output as being the result
+        cb(output, "result_output", -1);
+        ggml_build_forward_expand(gf, output);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -14964,6 +14980,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_exaone();
             } break;
+        case LLM_ARCH_RWKV:
+            {
+                result = llm.build_rwkv();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
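
Note: the placeholder builder above follows the standard ggml pattern of recording
operations into a compute graph and then expanding the graph backward from the
output tensor. The following is a minimal, self-contained sketch of that same
pattern outside llama.cpp, using only the public ggml API (ggml_init, ggml_scale,
ggml_build_forward_expand, ggml_graph_compute_with_ctx); the tensor shape and
arena size here are illustrative, not taken from the patch.

// sketch: build and run a one-op ggml graph, mirroring the placeholder above
#include <cstdio>
#include "ggml.h"

int main() {
    // allocate a small arena for tensor data and graph metadata
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // stand-in for the input embeddings tensor ({n_embd, n_tokens})
    struct ggml_tensor * inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
    ggml_set_f32(inp, 2.0f);

    // dummy op, as in the placeholder: scale by 1.0 just to produce a graph node
    struct ggml_tensor * out = ggml_scale(ctx, inp, 1.0f);

    // record the graph by walking back from the output tensor
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    // execute the graph on the CPU
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("out[0] = %f\n", ggml_get_f32_1d(out, 0)); // expect 2.0

    ggml_free(ctx);
    return 0;
}

In the real builder, ggml_new_graph_custom with LLAMA_MAX_NODES is used instead
of ggml_new_graph, and graph execution is scheduled by llama.cpp itself rather
than by ggml_graph_compute_with_ctx; only the build-and-expand pattern is the same.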