basic implementation

Reinforce-II 2024-05-23 01:17:36 +08:00
parent d6ef0e77dd
commit 3047229758

ggml.c

@@ -26,6 +26,12 @@
#include <signal.h>
#if defined(__gnu_linux__)
#include <syscall.h>
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
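// arch_prctl constants for the dynamically enabled AMX tile state (XSTATE component permission),
// defined here since they are not always provided by the system headers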
#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18
#endif
#endif
#ifdef GGML_USE_METAL
@@ -36,6 +42,12 @@
#undef GGML_USE_LLAMAFILE
#endif
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
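// disable the llamafile sgemm path in favor of the AMX kernels;
// a tile covers AMX_TILE_MN rows by AMX_TILE_K groups of 4 bytes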
#undef GGML_USE_LLAMAFILE
#define AMX_TILE_MN 16
#define AMX_TILE_K 16
#endif
#ifdef GGML_USE_LLAMAFILE
#include "sgemm.h"
#endif
@@ -1834,7 +1846,84 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
*s = sumf;
}
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
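// transpose an 8x8 block of 4-byte elements (here, pairs of bf16) using the usual AVX
// unpack/shuffle/permute sequence; bd and bs are the destination/source row strides in bytes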
static inline void ggml_transpose_8x8xpack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs) {
__m256 row0 = _mm256_loadu_ps((const float *)((const int8_t *)s + 0*bs));
__m256 row1 = _mm256_loadu_ps((const float *)((const int8_t *)s + 1*bs));
__m256 row2 = _mm256_loadu_ps((const float *)((const int8_t *)s + 2*bs));
__m256 row3 = _mm256_loadu_ps((const float *)((const int8_t *)s + 3*bs));
__m256 row4 = _mm256_loadu_ps((const float *)((const int8_t *)s + 4*bs));
__m256 row5 = _mm256_loadu_ps((const float *)((const int8_t *)s + 5*bs));
__m256 row6 = _mm256_loadu_ps((const float *)((const int8_t *)s + 6*bs));
__m256 row7 = _mm256_loadu_ps((const float *)((const int8_t *)s + 7*bs));
__m256 tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7;
__m256 tr8, tr9, tr10, tr11, tr12, tr13, tr14, tr15;
tr0 = _mm256_unpacklo_ps(row0, row1);
tr1 = _mm256_unpackhi_ps(row0, row1);
tr2 = _mm256_unpacklo_ps(row2, row3);
tr3 = _mm256_unpackhi_ps(row2, row3);
tr4 = _mm256_unpacklo_ps(row4, row5);
tr5 = _mm256_unpackhi_ps(row4, row5);
tr6 = _mm256_unpacklo_ps(row6, row7);
tr7 = _mm256_unpackhi_ps(row6, row7);
tr8 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(1, 0, 1, 0));
tr9 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(3, 2, 3, 2));
tr10 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(1, 0, 1, 0));
tr11 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(3, 2, 3, 2));
tr12 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(1, 0, 1, 0));
tr13 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(3, 2, 3, 2));
tr14 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(1, 0, 1, 0));
tr15 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(3, 2, 3, 2));
row0 = _mm256_permute2f128_ps(tr8, tr12, 0x20);
row1 = _mm256_permute2f128_ps(tr9, tr13, 0x20);
row2 = _mm256_permute2f128_ps(tr10, tr14, 0x20);
row3 = _mm256_permute2f128_ps(tr11, tr15, 0x20);
row4 = _mm256_permute2f128_ps(tr8, tr12, 0x31);
row5 = _mm256_permute2f128_ps(tr9, tr13, 0x31);
row6 = _mm256_permute2f128_ps(tr10, tr14, 0x31);
row7 = _mm256_permute2f128_ps(tr11, tr15, 0x31);
_mm256_storeu_ps((float *)((int8_t *)d + 0*bd), row0);
_mm256_storeu_ps((float *)((int8_t *)d + 1*bd), row1);
_mm256_storeu_ps((float *)((int8_t *)d + 2*bd), row2);
_mm256_storeu_ps((float *)((int8_t *)d + 3*bd), row3);
_mm256_storeu_ps((float *)((int8_t *)d + 4*bd), row4);
_mm256_storeu_ps((float *)((int8_t *)d + 5*bd), row5);
_mm256_storeu_ps((float *)((int8_t *)d + 6*bd), row6);
_mm256_storeu_ps((float *)((int8_t *)d + 7*bd), row7);
}
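// transpose an nr x nc matrix of 4-byte packs, one 8x8 block at a time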
static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs, const size_t nr, const size_t nc) {
assert(nr % 8 == 0);
assert(nc % 8 == 0);
for (size_t bi = 0; bi < nr; bi += 8) {
for (size_t bj = 0; bj < nc; bj += 8) {
ggml_transpose_8x8xpack4((int8_t *)d + bj*bd + bi*4, bd, (const int8_t *)s + bi*bs + bj*4, bs);
}
}
}
#endif
static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
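// AMX path: accumulate a 16x16 block of dot products in a single result tile.
// The y rows are loaded directly; each K slice of x is transposed by 4-byte pairs
// into the layout consumed by __tile_dpbf16ps.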
if (nrc == AMX_TILE_MN) {
assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0);
__tile1024i tileyt = {AMX_TILE_MN, AMX_TILE_K*4};
__tile1024i tilext = {AMX_TILE_K, AMX_TILE_MN*4};
__tile1024i tilezt = {AMX_TILE_MN, AMX_TILE_MN*sizeof(float)};
__tile_zero(&tilezt);
for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) {
ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)];
ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K);
__tile_loadd(&tileyt, y + i, by);
__tile_loadd(&tilext, axt, AMX_TILE_MN*4);
__tile_dpbf16ps(&tilezt, tileyt, tilext);
}
__tile_stored(s, bs*sizeof(float), tilezt);
return;
}
#endif
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
@@ -12249,6 +12338,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
const bool src1_cont = ggml_is_contiguous(src1);
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
ggml_to_float_t const to_float = type_traits[type].to_float;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
// broadcast factors
@@ -12275,8 +12365,25 @@ static void ggml_compute_forward_mul_mat_one_chunk(
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
// attempt to reduce false-sharing (does not seem to make a difference)
#ifdef __ARM_FEATURE_MATMUL_INT8
// 16 * 2, accounting for mmla kernels
float tmp[32];
#elif defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (num_rows_per_vec_dot == AMX_TILE_MN) {
assert(AMX_TILE_MN <= blck_0 && AMX_TILE_MN <= blck_1);
assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0);
assert(src1->type == GGML_TYPE_F32);
}
// 16 * AMX_TILE_MN, accounting for amx kernels
float tmp[16*AMX_TILE_MN];
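// per-thread scratch carved out of wdata: 16 bf16 rows of src0, 16 bf16 rows of src1,
// and one page-aligned f32 row used when dequantizing src0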
uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase);
ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
float * xf32 = (float *) (wbase + 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
xf32 = (float *) (((size_t)xf32 + 4095) & ~4095);
#else
float tmp[16];
#endif
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
@@ -12293,24 +12400,39 @@ static void ggml_compute_forward_mul_mat_one_chunk(
const int64_t i2 = i12;
const int64_t i3 = i13;
const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (num_rows_per_vec_dot == AMX_TILE_MN) {
const uint8_t * src1_col = (const uint8_t *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13;
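// convert the next 16 rows of src0 (dequantize, then f32 -> bf16) and 16 f32 rows of src1
// to bf16, then compute a 16x16 block of dot products into tmp with a row stride of 16 floats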
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
for (int cn = 0; cn < AMX_TILE_MN; ++cn) {
to_float((const uint8_t *)src0_row + ir0*nb01 + cn*nb01, xf32, ne00);
ggml_fp32_to_bf16_row(xf32, xbf16 + cn*ne00, ne00);
ggml_fp32_to_bf16_row((const float *)(src1_col + cn*nb11), ybf16 + cn*ne00, ne00);
}
ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], 16, xbf16, ne00*sizeof(ggml_bf16_t), ybf16, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
}
}
else
#endif
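// generic path; the if (true) keeps this block valid whether or not the AMX branch above is compiled in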
if (true) {
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char*)wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
}
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
@@ -12473,6 +12595,11 @@ UseGgmlGemm1:;
}
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
atomic_store(&state->shared->current_chunk, nth);
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (ne01 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) {
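// the AMX kernel converts rows to bf16 on the fly, so the usual src1 -> vec_dot_type conversion below is not needed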
return;
}
#endif
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -12540,6 +12667,11 @@ UseGgmlGemm2:;
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
num_rows_per_vec_dot = 1;
}
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
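// if the problem dimensions fit the AMX tile geometry, hand 16 rows of src0 to each vec_dot call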
if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (nr0 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) {
num_rows_per_vec_dot = AMX_TILE_MN;
}
#endif
// Now select a reasonable chunk size.
int chunk_size = 16;
@@ -19328,6 +19460,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
set_numa_thread_affinity(state->ith);
#if defined(__gnu_linux__)
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
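// AMX tile data is an opt-in XSAVE feature on Linux: request permission for XTILEDATA
// before this thread executes its first tile instruction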
syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
#endif
#endif
int node_n = -1;
int task_phase = GGML_TASK_TYPE_FINALIZE;
@@ -19525,6 +19663,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
* node->src[1]->ne[2]*node->src[1]->ne[3];
}
} else
#endif
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
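// reserve per-thread scratch for the AMX path: two sets of 16 bf16 rows plus an f32 row
// and alignment slack, mirroring the layout used in ggml_compute_forward_mul_mat_one_chunk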
if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) {
cur = n_threads*(2*AMX_TILE_MN*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
} else
#endif
if (node->src[1]->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));