llama: add support for QRWKV6 model architecture (#11001)
llama: add support for QRWKV6 model architecture (#11001) * WIP: Add support for RWKV6Qwen2 Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * RWKV: Some graph simplification Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Add support for RWKV6Qwen2 with cpu and cuda GLA Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Fix some typos Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * code format changes Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Fix wkv test & add gla test Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Fix cuda warning Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Update README.md Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * Update ggml/src/ggml-cuda/gla.cu Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Fix fused lerp weights loading with RWKV6 Signed-off-by: Molly Sophia <mollysophia379@gmail.com> * better sanity check skipping for QRWKV6 in llama-quant thanks @compilade Signed-off-by: Molly Sophia <mollysophia379@gmail.com> Co-authored-by: compilade <git@compilade.net> --------- Signed-off-by: Molly Sophia <mollysophia379@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
parent
c6860cc734
commit
ee7136c6d1
23 changed files with 862 additions and 124 deletions
|
@ -37,6 +37,7 @@
|
|||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_cuda_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_cuda_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
break;
|
||||
|
@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue