metal : add mul_mat_id BF16 support
ggml-ci
This commit is contained in:
parent
6109cf151e
commit
c915d0add5
3 changed files with 39 additions and 26 deletions
|
@ -147,10 +147,10 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
|
||||||
|
@ -175,10 +175,11 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
||||||
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
||||||
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
|
||||||
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
|
||||||
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
|
||||||
|
@ -222,6 +223,7 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
||||||
|
@ -310,6 +312,7 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||||
|
@ -654,10 +657,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, support_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, support_simdgroup_reduction);
|
||||||
|
@ -678,10 +681,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, support_simdgroup_reduction);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, support_simdgroup_reduction);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, support_simdgroup_reduction);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, support_simdgroup_reduction);
|
||||||
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, support_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, support_simdgroup_reduction);
|
||||||
|
@ -725,6 +729,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, support_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, support_simdgroup_mm);
|
||||||
|
@ -813,6 +818,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||||
|
@ -902,17 +908,6 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
|
||||||
for (size_t i = 0, n = 3; i < n; ++i) {
|
|
||||||
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16 &&
|
|
||||||
op->op != GGML_OP_GET_ROWS &&
|
|
||||||
op->op != GGML_OP_MUL_MAT &&
|
|
||||||
op->op != GGML_OP_VIEW &&
|
|
||||||
op->op != GGML_OP_CPY) {
|
|
||||||
GGML_LOG_ERROR("unsupported BF16 op = %s, src[%zu] = %s\n", ggml_op_name(op->op), i, ggml_type_name(op->src[i]->type));
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
const bool support_simdgroup_mm = ctx_dev->support_simdgroup_mm;
|
||||||
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
|
||||||
|
|
||||||
|
@ -1002,10 +997,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_BF16:
|
|
||||||
switch (op->type) {
|
switch (op->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
|
switch (op->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
|
@ -2203,12 +2204,12 @@ static void ggml_metal_encode_node(
|
||||||
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||||
dst_rows > dst_rows_min) {
|
dst_rows > dst_rows_min) {
|
||||||
|
|
||||||
// some Metal matrix data types require aligned pointers
|
// some Metal matrix data types require aligned pointers
|
||||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||||
|
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2217,6 +2218,7 @@ static void ggml_metal_encode_node(
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
||||||
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
||||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
||||||
|
@ -2286,6 +2288,13 @@ static void ggml_metal_encode_node(
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
|
nth0 = 32;
|
||||||
|
nth1 = 1;
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
||||||
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
|
@ -3305,6 +3314,7 @@ static void ggml_metal_encode_node(
|
||||||
{
|
{
|
||||||
switch (dstt) {
|
switch (dstt) {
|
||||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
|
||||||
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
|
|
|
@ -3576,12 +3576,13 @@ kernel void kernel_cpy(
|
||||||
|
|
||||||
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
||||||
|
|
||||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
||||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
||||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
|
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
|
||||||
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
||||||
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
||||||
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
|
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
|
||||||
|
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
||||||
|
|
||||||
kernel void kernel_cpy_f32_q8_0(
|
kernel void kernel_cpy_f32_q8_0(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
|
@ -6547,6 +6548,7 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
||||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
||||||
|
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
||||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
||||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
||||||
|
@ -6770,6 +6772,7 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
||||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
||||||
|
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
|
||||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
||||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||||
|
|
|
@ -3599,7 +3599,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
for (int n_mats : {4}) {
|
for (int n_mats : {4}) {
|
||||||
for (int n_used : {2}) {
|
for (int n_used : {2}) {
|
||||||
for (bool b : {false}) {
|
for (bool b : {false}) {
|
||||||
for (int n : {1}) {
|
for (int n : {1, 32}) {
|
||||||
int m = 512;
|
int m = 512;
|
||||||
int k = 256;
|
int k = 256;
|
||||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue