add appropriate asserts

This commit is contained in:
Zhiyuan Li 2024-11-05 01:20:52 +11:00
parent b81602477b
commit 72e4432577
3 changed files with 4 additions and 3 deletions

View file

@@ -11669,9 +11669,10 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
 float * time_faaaa = (float *) dst->src[3]->data;
 float * time_decay = (float *) dst->src[4]->data;
-size_t t_stride = HEADS * head_size;
+size_t t_stride = HEADS * head_size; // Same to C
 size_t h_stride = C / HEADS;
+GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
 size_t h_stride_2d = head_size * head_size;
 if (params->ith == 0) {

View file

@@ -83,7 +83,7 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
 GGML_ASSERT(C % H == 0);
-GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
+GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
 rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
 }

View file

@@ -112,7 +112,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
 GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
 GGML_ASSERT(C % H == 0);
-GGML_ASSERT(C / H == WKV_BLOCK_SIZE);
+GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
 dpct::queue_ptr stream = ctx.stream();