Fix MUL_MAT_ID matrix-matrix shader
This commit is contained in:
parent
b4abdbb881
commit
45928e8d21
3 changed files with 38022 additions and 38051 deletions
76015
ggml-vulkan-shaders.hpp
76015
ggml-vulkan-shaders.hpp
File diff suppressed because it is too large
Load diff
|
@ -2790,7 +2790,7 @@ static void ggml_vk_matmul(
|
||||||
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
|
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
|
||||||
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
|
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
|
||||||
#ifdef GGML_VULKAN_DEBUG
|
#ifdef GGML_VULKAN_DEBUG
|
||||||
std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl;
|
std::cerr << "ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
ggml_vk_sync_buffers(subctx);
|
ggml_vk_sync_buffers(subctx);
|
||||||
if (split_k == 1) {
|
if (split_k == 1) {
|
||||||
|
@ -2818,12 +2818,15 @@ static void ggml_vk_matmul_id(
|
||||||
uint32_t expert_stride_a, uint32_t expert_stride_d,
|
uint32_t expert_stride_a, uint32_t expert_stride_d,
|
||||||
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
|
uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
|
||||||
#ifdef GGML_VULKAN_DEBUG
|
#ifdef GGML_VULKAN_DEBUG
|
||||||
std::cerr << "ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), c: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ")" << std::endl;
|
std::cerr << "ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
|
||||||
|
"m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", " <<
|
||||||
|
"batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
|
||||||
|
"expert_stride_a: " << expert_stride_a << ", expert_stride_d: " << expert_stride_d << ", n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
ggml_vk_sync_buffers(subctx);
|
ggml_vk_sync_buffers(subctx);
|
||||||
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d,
|
const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, k, ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d,
|
||||||
expert_stride_a, expert_stride_d, n_as, nei0, nei1, nbi1, ne11 };
|
expert_stride_a, expert_stride_d, n_as, nei0, nei1, nbi1, ne11 };
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, n, n_as });
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
|
static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
|
||||||
|
@ -3363,10 +3366,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
|
||||||
|
|
||||||
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||||
#ifdef GGML_VULKAN_DEBUG
|
#ifdef GGML_VULKAN_DEBUG
|
||||||
std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
std::cerr << "ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||||
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
||||||
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", backend=" << ids->backend << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
||||||
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
||||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||||
|
@ -3383,12 +3386,15 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
|
||||||
|
|
||||||
const uint64_t nei0 = ids->ne[0];
|
const uint64_t nei0 = ids->ne[0];
|
||||||
const uint64_t nei1 = ids->ne[1];
|
const uint64_t nei1 = ids->ne[1];
|
||||||
|
GGML_ASSERT(nei0 * nei1 <= 2048);
|
||||||
|
|
||||||
const uint32_t nbi1 = ids->nb[1];
|
const uint32_t nbi1 = ids->nb[1];
|
||||||
const uint32_t nbi2 = ids->nb[2];
|
const uint32_t nbi2 = ids->nb[2];
|
||||||
|
|
||||||
const uint64_t ne20 = dst->ne[0];
|
const uint64_t ne20 = dst->ne[0];
|
||||||
const uint64_t ne21 = dst->ne[1];
|
const uint64_t ne21 = dst->ne[1];
|
||||||
|
const uint64_t ne22 = dst->ne[2];
|
||||||
|
const uint64_t ne23 = dst->ne[3];
|
||||||
|
|
||||||
const uint64_t r2 = ne12 / ne02;
|
const uint64_t r2 = ne12 / ne02;
|
||||||
const uint64_t r3 = ne13 / ne03;
|
const uint64_t r3 = ne13 / ne03;
|
||||||
|
@ -3439,14 +3445,14 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
|
||||||
// Not implemented
|
// Not implemented
|
||||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||||
|
|
||||||
const int x_ne = ne01 * ne00;
|
const uint64_t x_ne = ne01 * ne00;
|
||||||
const int y_ne = ne11 * ne10;
|
const uint64_t y_ne = ne11 * ne10;
|
||||||
const int d_ne = ne21 * ne20;
|
const uint64_t d_ne = ne21 * ne20;
|
||||||
|
|
||||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne21));
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
|
||||||
const bool aligned = ne10 == kpad && ne01 > 8 && ne21 > 8;
|
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne21, aligned);
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
|
||||||
|
|
||||||
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
||||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||||
|
@ -3545,7 +3551,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
|
||||||
ggml_vk_matmul_id(
|
ggml_vk_matmul_id(
|
||||||
ctx, subctx, pipeline,
|
ctx, subctx, pipeline,
|
||||||
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
|
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
|
||||||
{ d_D, d_buf_offset, d_sz * ne12 * ne13 }, { d_ids, ids_buf_offset, ids_sz },
|
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
|
||||||
ne01, ne21, ne10, ne10, ne10, ne01, ne02, ne12, r2, r3,
|
ne01, ne21, ne10, ne10, ne10, ne01, ne02, ne12, r2, r3,
|
||||||
stride_batch_x, stride_batch_y, ne20*ne21,
|
stride_batch_x, stride_batch_y, ne20*ne21,
|
||||||
x_ne, ne20,
|
x_ne, ne20,
|
||||||
|
@ -3555,10 +3561,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context *
|
||||||
|
|
||||||
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||||
#ifdef GGML_VULKAN_DEBUG
|
#ifdef GGML_VULKAN_DEBUG
|
||||||
std::cerr << "ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
std::cerr << "ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||||
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
||||||
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", backend=" << ids->backend << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
|
||||||
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
|
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)" << std::endl;
|
||||||
#endif
|
#endif
|
||||||
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
|
||||||
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
|
||||||
|
|
|
@ -225,7 +225,7 @@ mulmat_head = """#version 450
|
||||||
#extension GL_EXT_shader_16bit_storage : require
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
||||||
|
|
||||||
#define EXPERT_COUNT 8
|
#define EXPERT_COUNT 8
|
||||||
#endif
|
#endif
|
||||||
|
@ -294,7 +294,7 @@ shared FLOAT_TYPE buf_a[BM * (BK+1)];
|
||||||
shared FLOAT_TYPE buf_b[BN * (BK+1)];
|
shared FLOAT_TYPE buf_b[BN * (BK+1)];
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
shared u8vec2 row_ids[2048];
|
shared u16vec2 row_ids[2048];
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
@ -342,7 +342,7 @@ void main() {
|
||||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||||
row_ids[_ne1] = u8vec2(ii0, ii1);
|
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||||
_ne1++;
|
_ne1++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -616,7 +616,7 @@ mulmat_body2 = """
|
||||||
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
||||||
#if LOAD_VEC_B == 8
|
#if LOAD_VEC_B == 8
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
const u8vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||||
#else
|
#else
|
||||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||||
|
@ -632,7 +632,7 @@ mulmat_body2 = """
|
||||||
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
|
||||||
#elif LOAD_VEC_B == 4
|
#elif LOAD_VEC_B == 4
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
const u8vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||||
#else
|
#else
|
||||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||||
|
@ -651,7 +651,7 @@ mulmat_body2 = """
|
||||||
#else
|
#else
|
||||||
const uint row_i = ic * BN + loadc_b + l;
|
const uint row_i = ic * BN + loadc_b + l;
|
||||||
if (row_i < _ne1) {
|
if (row_i < _ne1) {
|
||||||
const u8vec2 row_idx = row_ids[row_i];
|
const u16vec2 row_idx = row_ids[row_i];
|
||||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||||
} else {
|
} else {
|
||||||
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
|
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
|
||||||
|
@ -705,10 +705,10 @@ mulmat_body2 = """
|
||||||
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
const uint row_i = ic * BN + dc_warp + cc;
|
const uint row_i = dc_warp + cc;
|
||||||
if (row_i >= _ne1) break;
|
if (row_i >= _ne1) break;
|
||||||
|
|
||||||
const u8vec2 row_idx = row_ids[row_i];
|
const u16vec2 row_idx = row_ids[row_i];
|
||||||
#endif
|
#endif
|
||||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue