Merge branch 'master' into gg/flash-attn
This commit is contained in:
commit
013721df2b
157 changed files with 19090 additions and 15488 deletions
41
ggml-metal.m
41
ggml-metal.m
|
@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
||||
|
@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
||||
|
@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
||||
|
@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
||||
|
@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
||||
|
@ -499,6 +504,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||
|
@ -526,6 +532,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
||||
|
@ -549,6 +556,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
||||
|
@ -569,6 +577,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
||||
|
@ -589,6 +598,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
||||
|
@ -853,6 +863,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||
struct ggml_tensor * dst = gf->nodes[i];
|
||||
|
||||
if (ggml_is_empty(dst)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
|
@ -1439,6 +1453,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
||||
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
||||
|
@ -1593,6 +1608,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
nth0 = 4;
|
||||
|
@ -1637,9 +1658,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
||||
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
||||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
|
@ -1761,6 +1782,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
||||
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||
|
@ -1918,6 +1940,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
nth0 = 4;
|
||||
|
@ -1978,9 +2006,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
||||
}
|
||||
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
||||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
||||
src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 ||
|
||||
src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
|
||||
src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
||||
|
@ -2042,6 +2070,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
||||
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue