From 3047229758f6e8306f5f282ef1c7aaecb054a2fa Mon Sep 17 00:00:00 2001 From: Reinforce-II Date: Thu, 23 May 2024 01:17:36 +0800 Subject: [PATCH] basic implementation --- ggml.c | 175 +++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 159 insertions(+), 16 deletions(-) diff --git a/ggml.c b/ggml.c index 5145ceec9..52192c782 100644 --- a/ggml.c +++ b/ggml.c @@ -26,6 +26,12 @@ #include #if defined(__gnu_linux__) #include +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#endif #endif #ifdef GGML_USE_METAL @@ -36,6 +42,12 @@ #undef GGML_USE_LLAMAFILE #endif +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +#undef GGML_USE_LLAMAFILE +#define AMX_TILE_MN 16 +#define AMX_TILE_K 16 +#endif + #ifdef GGML_USE_LLAMAFILE #include "sgemm.h" #endif @@ -1834,7 +1846,84 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * *s = sumf; } +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +static inline void ggml_transpose_8x8xpack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs) { + __m256 row0 = _mm256_loadu_ps((const float *)((const int8_t *)s + 0*bs)); + __m256 row1 = _mm256_loadu_ps((const float *)((const int8_t *)s + 1*bs)); + __m256 row2 = _mm256_loadu_ps((const float *)((const int8_t *)s + 2*bs)); + __m256 row3 = _mm256_loadu_ps((const float *)((const int8_t *)s + 3*bs)); + __m256 row4 = _mm256_loadu_ps((const float *)((const int8_t *)s + 4*bs)); + __m256 row5 = _mm256_loadu_ps((const float *)((const int8_t *)s + 5*bs)); + __m256 row6 = _mm256_loadu_ps((const float *)((const int8_t *)s + 6*bs)); + __m256 row7 = _mm256_loadu_ps((const float *)((const int8_t *)s + 7*bs)); + + __m256 tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7; + __m256 tr8, tr9, tr10, tr11, tr12, tr13, tr14, tr15; + tr0 = _mm256_unpacklo_ps(row0, row1); + tr1 = _mm256_unpackhi_ps(row0, row1); + tr2 = _mm256_unpacklo_ps(row2, row3); + tr3 = _mm256_unpackhi_ps(row2, row3); + tr4 = _mm256_unpacklo_ps(row4, row5); + tr5 = _mm256_unpackhi_ps(row4, row5); + tr6 = _mm256_unpacklo_ps(row6, row7); + tr7 = _mm256_unpackhi_ps(row6, row7); + tr8 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(1, 0, 1, 0)); + tr9 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(3, 2, 3, 2)); + tr10 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(1, 0, 1, 0)); + tr11 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(3, 2, 3, 2)); + tr12 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(1, 0, 1, 0)); + tr13 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(3, 2, 3, 2)); + tr14 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(1, 0, 1, 0)); + tr15 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(3, 2, 3, 2)); + row0 = _mm256_permute2f128_ps(tr8, tr12, 0x20); + row1 = _mm256_permute2f128_ps(tr9, tr13, 0x20); + row2 = _mm256_permute2f128_ps(tr10, tr14, 0x20); + row3 = _mm256_permute2f128_ps(tr11, tr15, 0x20); + row4 = _mm256_permute2f128_ps(tr8, tr12, 0x31); + row5 = _mm256_permute2f128_ps(tr9, tr13, 0x31); + row6 = _mm256_permute2f128_ps(tr10, tr14, 0x31); + row7 = _mm256_permute2f128_ps(tr11, tr15, 0x31); + + _mm256_storeu_ps((float *)((int8_t *)d + 0*bd), row0); + _mm256_storeu_ps((float *)((int8_t *)d + 1*bd), row1); + _mm256_storeu_ps((float *)((int8_t *)d + 2*bd), row2); + _mm256_storeu_ps((float *)((int8_t *)d + 3*bd), row3); + _mm256_storeu_ps((float *)((int8_t *)d + 4*bd), row4); + _mm256_storeu_ps((float *)((int8_t *)d + 5*bd), row5); + _mm256_storeu_ps((float *)((int8_t *)d + 6*bd), row6); + _mm256_storeu_ps((float *)((int8_t *)d + 7*bd), row7); +} + +static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs, const size_t nr, const size_t nc) { + assert(nr % 8 == 0); + assert(nc % 8 == 0); + for (size_t bi = 0; bi < nr; bi += 8) { + for (size_t bj = 0; bj < nc; bj += 8) { + ggml_transpose_8x8xpack4((int8_t *)d + bj*bd + bi*4, bd, (const int8_t *)s + bi*bs + bj*4, bs); + } + } +} +#endif + static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (nrc == AMX_TILE_MN) { + assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0); + __tile1024i tileyt = {AMX_TILE_MN, AMX_TILE_K*4}; + __tile1024i tilext = {AMX_TILE_K, AMX_TILE_MN*4}; + __tile1024i tilezt = {AMX_TILE_MN, AMX_TILE_MN*sizeof(float)}; + __tile_zero(&tilezt); + for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) { + ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)]; + ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K); + __tile_loadd(&tileyt, y + i, by); + __tile_loadd(&tilext, axt, AMX_TILE_MN*4); + __tile_dpbf16ps(&tilezt, tileyt, tilext); + } + __tile_stored(s, bs*sizeof(float), tilezt); + return; + } +#endif assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -12249,6 +12338,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( const bool src1_cont = ggml_is_contiguous(src1); ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + ggml_to_float_t const to_float = type_traits[type].to_float; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; // broadcast factors @@ -12275,8 +12365,25 @@ static void ggml_compute_forward_mul_mat_one_chunk( const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; // attempt to reduce false-sharing (does not seem to make a difference) +#ifdef __ARM_FEATURE_MATMUL_INT8 // 16 * 2, accounting for mmla kernels float tmp[32]; +#elif defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (num_rows_per_vec_dot == AMX_TILE_MN) { + assert(AMX_TILE_MN <= blck_0 && AMX_TILE_MN <= blck_1); + assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0); + assert(src1->type == GGML_TYPE_F32); + } + // 16 * AMX_TILE_MN, accounting for amx kernels + float tmp[16*AMX_TILE_MN]; + uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096); + ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase); + ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)); + float * xf32 = (float *) (wbase + 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)); + xf32 = (float *) (((size_t)xf32 + 4095) & ~4095); +#else + float tmp[16]; +#endif for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { @@ -12293,24 +12400,39 @@ static void ggml_compute_forward_mul_mat_one_chunk( const int64_t i2 = i12; const int64_t i3 = i13; - const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); + const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); + float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (num_rows_per_vec_dot == AMX_TILE_MN) { + const uint8_t * src1_col = (const uint8_t *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13; + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + for (int cn = 0; cn < AMX_TILE_MN; ++cn) { + to_float((const uint8_t *)src0_row + ir0*nb01 + cn*nb01, xf32, ne00); + ggml_fp32_to_bf16_row(xf32, xbf16 + cn*ne00, ne00); + ggml_fp32_to_bf16_row((const float *)(src1_col + cn*nb11), ybf16 + cn*ne00, ne00); + } + ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], 16, xbf16, ne00*sizeof(ggml_bf16_t), ybf16, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN); + } + } + else +#endif + if (true) { + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char*)wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char*)wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size - : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} - - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + } } for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { @@ -12473,6 +12595,11 @@ UseGgmlGemm1:; } // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. atomic_store(&state->shared->current_chunk, nth); +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (ne01 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) { + return; + } +#endif if (src1->type != vec_dot_type) { char * wdata = params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -12540,6 +12667,11 @@ UseGgmlGemm2:; if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { num_rows_per_vec_dot = 1; } +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (nr0 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) { + num_rows_per_vec_dot = AMX_TILE_MN; + } +#endif // Now select a reasonable chunk size. int chunk_size = 16; @@ -19328,6 +19460,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { set_numa_thread_affinity(state->ith); +#if defined(__gnu_linux__) +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); +#endif +#endif + int node_n = -1; int task_phase = GGML_TASK_TYPE_FINALIZE; @@ -19525,6 +19663,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa * node->src[1]->ne[2]*node->src[1]->ne[3]; } } else +#endif +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) { + cur = n_threads*(2*AMX_TILE_MN*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096); + } else #endif if (node->src[1]->type != vec_dot_type) { cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));