Add (broken) placeholder graph builder for RWKV

This commit is contained in:
Layl Bongers 2024-04-17 14:59:18 +02:00 committed by Molly Sophia
parent e92c74f4a1
commit a0aae8d671

View file

@ -14718,6 +14718,22 @@ struct llm_build_context {
return gf; return gf;
} }
ggml_cgraph * build_rwkv() {
ggml_cgraph *gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
// Input embeddings, start of the model after tokenizing ({n_embd, n_tokens})
ggml_tensor *input_embeddings = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// Dummy operation, just to copy, we're not doing anything with it right now
ggml_tensor *output = ggml_scale(ctx0, input_embeddings, 1.0);
// Mark the output as being the result
cb(output, "result_output", -1);
ggml_build_forward_expand(gf, output);
return gf;
}
}; };
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) { static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@ -14964,6 +14980,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_exaone(); result = llm.build_exaone();
} break; } break;
case LLM_ARCH_RWKV:
{
result = llm.build_rwkv();
} break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }