add metal impl

slaren 2024-04-12 23:33:46 +02:00
parent 137fbb8f59
commit fc363e4afc
4 changed files with 481 additions and 565 deletions

ggml-metal.m

@@ -1683,15 +1683,10 @@ static enum ggml_status ggml_metal_graph_compute(
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
-                //GGML_ASSERT(ne00 == ne10);
-                //GGML_ASSERT(ne03 == ne13);
                 const int n_as = src0->ne[2];

-                // max size of the src1ids array in the kernel shared buffer
-                GGML_ASSERT(ne11 <= 4096);
-
                 // src2 = ids
-                const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+                const int64_t ne20 = src2->ne[0];
                 const int64_t ne21 = src2->ne[1];
                 const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
                 const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1712,15 +1707,13 @@ static enum ggml_status ggml_metal_graph_compute(

                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
-                int ne11_mm_min = n_as;
+                // ne20 = n_used_experts
+                // ne21 = n_rows
+                const int dst_rows = ne20*ne21;
+                const int dst_rows_min = n_as;

-                const int idx = ((int32_t *) dst->op_params)[0];
-
-                // batch size
-                GGML_ASSERT(ne21 == ne11); // ?
-                GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
-                const uint r2 = 1;
-                const uint r3 = 1;
+                // max size of the rowids array in the kernel shared buffer
+                GGML_ASSERT(dst_rows <= 2048);

                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1730,7 +1723,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 // !!!
                 if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                     ne00 % 32 == 0 && ne00 >= 64 &&
-                    ne11 > ne11_mm_min) {
+                    dst_rows > dst_rows_min) {

                     // some Metal matrix data types require aligned pointers
                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1772,26 +1765,26 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
-                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
-                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
-                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
-                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
-                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
-                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
-                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
-                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
-                    [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
-                    [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
-                    [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
+                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];

-                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
+                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];

-                    [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                     int nth0 = 32;
                     int nth1 = 1;
@@ -1944,72 +1937,72 @@ static enum ggml_status ggml_metal_graph_compute(
                         GGML_ASSERT(ne00 >= nth0*nth1);
                     }

-                    const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
                     [encoder setComputePipelineState:pipeline];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
-                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
-                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
-                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
-                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
-                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
-                    [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
-                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
-                    [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
-                    [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
-                    [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
+                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
+
+                    const int64_t _ne1 = 1;
+                    const int tgz = dst_rows;

                     if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                         src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                         src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
                         const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
                         const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                         const int mem_size = 32*sizeof(float);
                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                     }
                     else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
-                        const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                 }
             } break;
@@ -2714,8 +2707,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
         return NULL;
     }

-    GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-    ggml_backend_metal_log_allocated_size(device);
+    //GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+    //ggml_backend_metal_log_allocated_size(device);

     return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
 }
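
For reference, the shared-memory sizing used by the new matrix-matrix MUL_MAT_ID path boils down to the arithmetic below. This is a sketch in plain C, not code from the commit: the names ne20, ne21, dst_rows and the constants (8192-byte base, 4-byte ushort2 row ids, 2048-row cap) are taken from the diff above, and GGML_PAD is reproduced as defined in ggml.h.

    #include <stddef.h>
    #include <stdint.h>

    // round x up to a multiple of n (n is a power of two), as in ggml.h
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    // ne20 = n_used_experts, ne21 = n_rows (see the comments added in the diff).
    // each dst row gets one ushort2 (4 bytes) entry in the shared rowids array,
    // on top of the fixed 8192-byte tile buffer; dst_rows itself must stay <= 2048
    static size_t mmid_threadgroup_mem(int64_t ne20, int64_t ne21) {
        const int dst_rows = (int)(ne20*ne21);
        return GGML_PAD((size_t)8192 + (size_t)dst_rows*4 /* sizeof(ushort2) */, 16);
    }

    // the matrix-matrix dispatch then tiles the dst rows by 32 and the src0 rows
    // by 64, with one grid plane per expert:
    //   grid = ((ne21 + 31)/32, (ne01 + 63)/64, n_as), threadgroup = (128, 1, 1)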

File diff suppressed because it is too large

ggml.c

@@ -11120,7 +11120,6 @@ static void ggml_compute_forward_mul_mat_id(
                 }

                 memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
             }
-
         }
     }

tests/test-backend-ops.cpp

@@ -478,9 +478,8 @@ struct test_case {
             }

             double err = nmse(f1.data(), f2.data(), f1.size());
-            printf("[%s] NMSE = %.9f > %.9f \n", ggml_op_desc(t1), err, ud->max_err);
             if (err > ud->max_err) {
-                //printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
+                printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
                 //for (int i = 0; i < (int) f1.size(); i++) {
                 //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
                 //}
@@ -956,7 +955,7 @@ struct test_mul_mat_id : public test_case {
     const int64_t k;

     std::string vars() override {
-        return VARS_TO_STR7(type_a, type_b, n_mats, b, m, n, k);
+        return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
     }

     double max_nmse_err() override {
@@ -976,13 +975,17 @@ struct test_mul_mat_id : public test_case {
             int n_mats = 8, int n_used = 2, bool b = false,
             int64_t m = 32, int64_t n = 32, int64_t k = 32)
         : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
-          m(m), n(n), k(k) {}
+          m(m), n(n), k(k) {
+        GGML_ASSERT(n_used <= n_mats);
+    }

     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
         ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
-        ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
+        if (n_used != n_mats) {
+            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
+        }
         ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
         ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
         return out;
@@ -1958,9 +1961,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
     };

-    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
-    test_cases.emplace_back(new test_moe(8, 2, 32, 4096, 8*1024));
-
     // unary ops
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
         test_cases.emplace_back(new test_unary((ggml_unary_op) op));
@@ -2100,17 +2100,17 @@ test_cases.emplace_back(new test_moe(8, 2, 32, 4096, 8*1024));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));

     for (ggml_type type_a : all_types) {
+    //for (ggml_type type_a : {GGML_TYPE_F16}) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
-            for (int n_mats : {2, 4, 8}) {
+            for (int n_mats : {4, 8}) {
+                for (int n_used : {1, 2, 4}) {
                 for (bool b : {false, true}) {
-                    // cur shape: 4096 1 32 1
-                    // ffn_up_exps shape: 4096 8192 8 1
-                    // selected_experts shape: 2 32 1 1
-                    int m = 8192;
-                    int n = 32;
-                    int k = 4096;
-                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, 2, b, m, n, k));
-                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, 2, b, m, 1, k));
+                    for (int n : {1, 32}) {
+                        int m = 512;//8192;
+                        int k = 256;//4096;
+                        test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
+                    }
+                }
                 }
             }
         }
@@ -2193,6 +2193,8 @@ test_cases.emplace_back(new test_moe(8, 2, 32, 4096, 8*1024));
     test_cases.emplace_back(new test_llama(2));
     test_cases.emplace_back(new test_falcon(1));
     test_cases.emplace_back(new test_falcon(2));
+    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
+    test_cases.emplace_back(new test_moe(8, 2, 32, 4096, 8*1024));
 #endif

     // run tests
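
As a rough cross-check of this test sweep against the Metal limit above (a sketch under the assumption, supported by the comments in the diff, that the test's n maps to ne21 and n_used to ne20):

    // dst_rows as the Metal backend computes it: ne20*ne21 = n_used*n.
    // the largest swept case is n_used = 4, n = 32 -> 128 row ids, and the
    // test_moe cases use n_used = 2 with up to 32 rows -> 64, both well under
    // the new GGML_ASSERT(dst_rows <= 2048) cap
    static int mmid_dst_rows(int n_used, int n_rows) {
        return n_used*n_rows;
    }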