add appropriate asserts
This commit is contained in:
parent
b81602477b
commit
72e4432577
3 changed files with 4 additions and 3 deletions
|
@ -11669,9 +11669,10 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||||
float * time_faaaa = (float *) dst->src[3]->data;
|
float * time_faaaa = (float *) dst->src[3]->data;
|
||||||
float * time_decay = (float *) dst->src[4]->data;
|
float * time_decay = (float *) dst->src[4]->data;
|
||||||
|
|
||||||
size_t t_stride = HEADS * head_size;
|
size_t t_stride = HEADS * head_size; // Same to C
|
||||||
|
|
||||||
size_t h_stride = C / HEADS;
|
size_t h_stride = C / HEADS;
|
||||||
|
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
||||||
size_t h_stride_2d = head_size * head_size;
|
size_t h_stride_2d = head_size * head_size;
|
||||||
|
|
||||||
if (params->ith == 0) {
|
if (params->ith == 0) {
|
||||||
|
|
|
@ -83,7 +83,7 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||||
|
|
||||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(C % H == 0);
|
GGML_ASSERT(C % H == 0);
|
||||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
|
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||||
|
|
||||||
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,7 +112,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
|
||||||
|
|
||||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(C % H == 0);
|
GGML_ASSERT(C % H == 0);
|
||||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE);
|
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||||
|
|
||||||
dpct::queue_ptr stream = ctx.stream();
|
dpct::queue_ptr stream = ctx.stream();
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue