Fix MUL_MAT_ID matrix matrix shader

This commit is contained in:
0cc4m 2024-05-29 22:42:07 +02:00
parent b4abdbb881
commit 45928e8d21
3 changed files with 38022 additions and 38051 deletions

File diff suppressed because it is too large Load diff

View file

@ -2790,7 +2790,7 @@ static void ggml_vk_matmul(
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) { uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl; std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl;
#endif #endif
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
if (split_k == 1) { if (split_k == 1) {
@ -2818,12 +2818,15 @@ static void ggml_vk_matmul_id(
uint32_t expert_stride_a, uint32_t expert_stride_d, uint32_t expert_stride_a, uint32_t expert_stride_d,
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl; std::cerr << "ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", " <<
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
"expert_stride_a: " << expert_stride_a << ", expert_stride_d: " << expert_stride_d << ", n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")" << std::endl;
#endif #endif
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d, const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d,
expert_stride_a, expert_stride_d, n_as, nei0, nei1, nbi1, ne11 }; expert_stride_a, expert_stride_d, n_as, nei0, nei1, nbi1, ne11 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, n, n_as }); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
} }
static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@ -3363,10 +3366,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", backend=" << ids->backend << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
#endif #endif
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -3383,12 +3386,15 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
const uint64_t nei0 = ids->ne[0]; const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1]; const uint64_t nei1 = ids->ne[1];
GGML_ASSERT(nei0 * nei1 <= 2048);
const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2]; const uint32_t nbi2 = ids->nb[2];
const uint64_t ne20 = dst->ne[0]; const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1]; const uint64_t ne21 = dst->ne[1];
const uint64_t ne22 = dst->ne[2];
const uint64_t ne23 = dst->ne[3];
const uint64_t r2 = ne12 / ne02; const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03; const uint64_t r3 = ne13 / ne03;
@ -3439,14 +3445,14 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
// Not implemented // Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const int x_ne = ne01 * ne00; const uint64_t x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10; const uint64_t y_ne = ne11 * ne10;
const int d_ne = ne21 * ne20; const uint64_t d_ne = ne21 * ne20;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne21)); const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
const bool aligned = ne10 == kpad && ne01 > 8 && ne21 > 8; const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne21, aligned); vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@ -3545,7 +3551,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
ggml_vk_matmul_id( ggml_vk_matmul_id(
ctx, subctx, pipeline, ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { d_ids, ids_buf_offset, ids_sz }, { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01, ne02, ne12, r2, r3, ne01, ne21, ne10, ne10, ne10, ne01, ne02, ne12, r2, r3,
stride_batch_x, stride_batch_y, ne20*ne21, stride_batch_x, stride_batch_y, ne20*ne21,
x_ne, ne20, x_ne, ne20,
@ -3555,10 +3561,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
#ifdef GGML_VULKAN_DEBUG #ifdef GGML_VULKAN_DEBUG
std::cerr << "ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", backend=" << ids->backend << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
#endif #endif
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT

View file

@ -225,7 +225,7 @@ mulmat_head = """#version 450
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#define EXPERT_COUNT 8 #define EXPERT_COUNT 8
#endif #endif
@ -294,7 +294,7 @@ shared FLOAT_TYPE buf_a[BM * (BK+1)];
shared FLOAT_TYPE buf_b[BN * (BK+1)]; shared FLOAT_TYPE buf_b[BN * (BK+1)];
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
shared u8vec2 row_ids[2048]; shared u16vec2 row_ids[2048];
#endif #endif
void main() { void main() {
@ -342,7 +342,7 @@ void main() {
for (uint ii1 = 0; ii1 < p.nei1; ii1++) { for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) { for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u8vec2(ii0, ii1); row_ids[_ne1] = u16vec2(ii0, ii1);
_ne1++; _ne1++;
} }
} }
@ -616,7 +616,7 @@ mulmat_body2 = """
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8 #if LOAD_VEC_B == 8
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
const u8vec2 row_idx = row_ids[ic * BN + loadc_b + l]; const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@ -632,7 +632,7 @@ mulmat_body2 = """
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4 #elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
const u8vec2 row_idx = row_ids[ic * BN + loadc_b + l]; const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@ -651,7 +651,7 @@ mulmat_body2 = """
#else #else
const uint row_i = ic * BN + loadc_b + l; const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) { if (row_i < _ne1) {
const u8vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
@ -705,10 +705,10 @@ mulmat_body2 = """
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
[[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
const uint row_i = ic * BN + dc_warp + cc; const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break; if (row_i >= _ne1) break;
const u8vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
#endif #endif
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID