diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index ee1f8f1bf..5fb2cb603 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.seed  = 1234;
-    ctx_params.n_ctx = n_kv_req;
+    ctx_params.n_ctx = n_kv_req;
     ctx_params.n_batch = std::max(n_len, n_parallel);
     ctx_params.n_seq_max = n_parallel;
     ctx_params.n_threads = params.n_threads;
diff --git a/ggml-metal.m b/ggml-metal.m
index 109e5fe6b..e9598ddff 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1383,7 +1383,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     !ggml_is_transposed(src0) &&
                     !ggml_is_transposed(src1) &&
                     src1t == GGML_TYPE_F32 &&
-                    ne00 % 32 == 0 && ne00 >= 64 &&
+                    ne00 % 32 == 0 && ne00 >= 128 &&
                     (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
                     //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
@@ -1698,7 +1698,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 // indirect matrix multiplication
                 // !!!
                 if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                    ne20 % 32 == 0 && ne20 >= 64 &&
+                    ne20 % 32 == 0 && ne20 >= 128 &&
                     ne11 > ne11_mm_min) {
 
                     id<MTLComputePipelineState> pipeline = nil;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c2916c3e4..1998e1cbc 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2091,6 +2091,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64,  2, 128, { 8, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83,  2, 128, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64,  2,  64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83,  2,  64, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8, 1}, {4, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8, 1}, {4, 1}));
+
     for (ggml_type type_a : all_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
            for (int n_mats : {2, 4, 8}) {
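
A note on how the new tests relate to the kernel-selection change: assuming the usual `test_mul_mat(type_a, type_b, m, n, k, bs, nr)` mapping in the harness, `src0` is built with `ne00 = k`, so the `k = 128` cases still qualify for the Metal mat-mat kernels under the tightened `ne00 >= 128` check, while the `k = 64` cases now fall through to the mat-vec path. The sketch below (not part of this patch) reproduces the shape of the first new case on the CPU backend through ggml's public C API; the context size and thread count are arbitrary illustration values.

```cpp
// Standalone sketch (not part of this patch): builds the same F16 x F32
// batched mat-mul shape as test_mul_mat(F16, F32, 64, 2, 128, { 8, 1}, {1, 1}).
// Tensor data is left uninitialized; only the shapes and dispatch matter here.
#include "ggml.h"

#include <cstdio>

int main() {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64u*1024*1024, // arbitrary; plenty for these tensors
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t m = 64, n = 2, k = 128;

    // ggml_mul_mat contracts over dimension 0, so src0->ne[0] is the ne00
    // that the Metal backend checks against the new >= 128 threshold
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, k, m, 8, 1);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, k, n, 8, 1);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b); // result: (m, n, 8, 1)

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4); // thread count arbitrary

    printf("ne00 = %lld -> %s the ne00 >= 128 mat-mat threshold\n",
           (long long) a->ne[0], a->ne[0] >= 128 ? "meets" : "falls below");

    ggml_free(ctx);
    return 0;
}
```

Dropping `k` to 64 in the sketch keeps the result correct on every backend; only the kernel choice on Apple7+ devices changes, which is what the paired `k = 64` / `k = 128` test cases appear intended to cover.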