basic implementation

parent d6ef0e77dd
commit 3047229758

1 changed file with 159 additions and 16 deletions:

ggml.c (175 changed lines)
@@ -26,6 +26,12 @@
 #include <signal.h>
 #if defined(__gnu_linux__)
 #include <syscall.h>
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+#define ARCH_GET_XCOMP_PERM 0x1022
+#define ARCH_REQ_XCOMP_PERM 0x1023
+#define XFEATURE_XTILECFG  17
+#define XFEATURE_XTILEDATA 18
+#endif
 #endif
 
 #ifdef GGML_USE_METAL
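For context on the constants above: on Linux 5.16+ a process must ask the kernel for permission to use the AMX tile-data state before executing any tile instruction, via arch_prctl(2). The commit issues the request later in ggml_graph_compute_thread; a minimal request-and-verify sketch using these constants (the helper name is hypothetical, not from the commit; compile with -D_GNU_SOURCE):

#include <sys/syscall.h>
#include <unistd.h>

// Returns 1 if the kernel granted AMX tile-data permission, 0 otherwise.
static int amx_request_permission(void) {  // hypothetical helper
    unsigned long bitmask = 0;
    // Ask for the XTILEDATA component; the kernel may then grow the XSAVE area.
    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
        return 0;  // old kernel or policy refusal
    }
    // Read back the permission bitmask to confirm the grant.
    if (syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask)) {
        return 0;
    }
    return (bitmask >> XFEATURE_XTILEDATA) & 1;
}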
@@ -36,6 +42,12 @@
 #undef GGML_USE_LLAMAFILE
 #endif
 
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+#undef GGML_USE_LLAMAFILE
+#define AMX_TILE_MN 16
+#define AMX_TILE_K  16
+#endif
+
 #ifdef GGML_USE_LLAMAFILE
 #include "sgemm.h"
 #endif
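Note that the AMX build also undefines GGML_USE_LLAMAFILE, so the llamafile sgemm path does not compete with the tile kernels. On the two constants (my reading, not stated in the commit): an AMX tile register holds at most 16 rows x 64 bytes, which with bf16 operands and f32 accumulation gives:

// Tile shapes implied by AMX_TILE_MN = AMX_TILE_K = 16 (assumed interpretation):
//   accumulator C : 16 x 16 f32        = 16 rows x 64 bytes (full tile)
//   operand     A : 16 x 32 bf16       = 16 rows x 64 bytes (full tile)
//   operand     B : 16 rows x 16 cells = 4-byte cells, each two bf16 ("pack4")
// K is counted in 4-byte groups of two bf16 values, so one tile
// multiply-accumulate consumes 32 bf16 elements of depth per pass.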
@@ -1834,7 +1846,84 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
     *s = sumf;
 }
 
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+static inline void ggml_transpose_8x8xpack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs) {
+    __m256 row0 = _mm256_loadu_ps((const float *)((const int8_t *)s + 0*bs));
+    __m256 row1 = _mm256_loadu_ps((const float *)((const int8_t *)s + 1*bs));
+    __m256 row2 = _mm256_loadu_ps((const float *)((const int8_t *)s + 2*bs));
+    __m256 row3 = _mm256_loadu_ps((const float *)((const int8_t *)s + 3*bs));
+    __m256 row4 = _mm256_loadu_ps((const float *)((const int8_t *)s + 4*bs));
+    __m256 row5 = _mm256_loadu_ps((const float *)((const int8_t *)s + 5*bs));
+    __m256 row6 = _mm256_loadu_ps((const float *)((const int8_t *)s + 6*bs));
+    __m256 row7 = _mm256_loadu_ps((const float *)((const int8_t *)s + 7*bs));
+
+    __m256 tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7;
+    __m256 tr8, tr9, tr10, tr11, tr12, tr13, tr14, tr15;
+    tr0  = _mm256_unpacklo_ps(row0, row1);
+    tr1  = _mm256_unpackhi_ps(row0, row1);
+    tr2  = _mm256_unpacklo_ps(row2, row3);
+    tr3  = _mm256_unpackhi_ps(row2, row3);
+    tr4  = _mm256_unpacklo_ps(row4, row5);
+    tr5  = _mm256_unpackhi_ps(row4, row5);
+    tr6  = _mm256_unpacklo_ps(row6, row7);
+    tr7  = _mm256_unpackhi_ps(row6, row7);
+    tr8  = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(1, 0, 1, 0));
+    tr9  = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(3, 2, 3, 2));
+    tr10 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(1, 0, 1, 0));
+    tr11 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(3, 2, 3, 2));
+    tr12 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(1, 0, 1, 0));
+    tr13 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(3, 2, 3, 2));
+    tr14 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(1, 0, 1, 0));
+    tr15 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(3, 2, 3, 2));
+    row0 = _mm256_permute2f128_ps(tr8,  tr12, 0x20);
+    row1 = _mm256_permute2f128_ps(tr9,  tr13, 0x20);
+    row2 = _mm256_permute2f128_ps(tr10, tr14, 0x20);
+    row3 = _mm256_permute2f128_ps(tr11, tr15, 0x20);
+    row4 = _mm256_permute2f128_ps(tr8,  tr12, 0x31);
+    row5 = _mm256_permute2f128_ps(tr9,  tr13, 0x31);
+    row6 = _mm256_permute2f128_ps(tr10, tr14, 0x31);
+    row7 = _mm256_permute2f128_ps(tr11, tr15, 0x31);
+
+    _mm256_storeu_ps((float *)((int8_t *)d + 0*bd), row0);
+    _mm256_storeu_ps((float *)((int8_t *)d + 1*bd), row1);
+    _mm256_storeu_ps((float *)((int8_t *)d + 2*bd), row2);
+    _mm256_storeu_ps((float *)((int8_t *)d + 3*bd), row3);
+    _mm256_storeu_ps((float *)((int8_t *)d + 4*bd), row4);
+    _mm256_storeu_ps((float *)((int8_t *)d + 5*bd), row5);
+    _mm256_storeu_ps((float *)((int8_t *)d + 6*bd), row6);
+    _mm256_storeu_ps((float *)((int8_t *)d + 7*bd), row7);
+}
+
+static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs, const size_t nr, const size_t nc) {
+    assert(nr % 8 == 0);
+    assert(nc % 8 == 0);
+    for (size_t bi = 0; bi < nr; bi += 8) {
+        for (size_t bj = 0; bj < nc; bj += 8) {
+            ggml_transpose_8x8xpack4((int8_t *)d + bj*bd + bi*4, bd, (const int8_t *)s + bi*bs + bj*4, bs);
+        }
+    }
+}
+#endif
+
 static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+    if (nrc == AMX_TILE_MN) {
+        assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0);
+        __tile1024i tileyt = {AMX_TILE_MN, AMX_TILE_K*4};
+        __tile1024i tilext = {AMX_TILE_K, AMX_TILE_MN*4};
+        __tile1024i tilezt = {AMX_TILE_MN, AMX_TILE_MN*sizeof(float)};
+        __tile_zero(&tilezt);
+        for (int i = 0; i < n; i += AMX_TILE_K*4/sizeof(ggml_bf16_t)) {
+            ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)];
+            ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K);
+            __tile_loadd(&tileyt, y + i, by);
+            __tile_loadd(&tilext, axt, AMX_TILE_MN*4);
+            __tile_dpbf16ps(&tilezt, tileyt, tilext);
+        }
+        __tile_stored(s, bs*sizeof(float), tilezt);
+        return;
+    }
+#endif
     assert(nrc == 1);
     UNUSED(nrc);
     UNUSED(bx);
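Two reading aids for the block above (sketches of mine, not part of the commit). First, ggml_transpose_8x8xpack4 is an ordinary AVX 8x8 f32 transpose, except each "float" is really a 4-byte cell holding two adjacent bf16 values; a scalar equivalent of ggml_transpose_pack4 makes the intended layout explicit:

#include <stdint.h>
#include <string.h>

// Scalar reference: dst cell (j,i) = src cell (i,j), where a cell is 4 bytes
// (two adjacent bf16). This produces the VNNI-style interleaved operand layout
// that __tile_dpbf16ps expects for its second argument.
static void transpose_pack4_ref(void * d, size_t bd, const void * s, size_t bs,
                                size_t nr, size_t nc) {
    for (size_t i = 0; i < nr; ++i) {
        for (size_t j = 0; j < nc; ++j) {
            memcpy((uint8_t *)d + j*bd + i*4, (const uint8_t *)s + i*bs + j*4, 4);
        }
    }
}

Second, the AMX branch of ggml_vec_dot_bf16 computes a full 16x16 tile rather than a single dot product: s[m*bs + n] = sum_k y[m][k] * x[n][k], walking the shared dimension n in chunks of 32 bf16 values (one tile of depth per __tile_dpbf16ps).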
@@ -12249,6 +12338,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     const bool src1_cont = ggml_is_contiguous(src1);
 
     ggml_vec_dot_t    const vec_dot      = type_traits[type].vec_dot;
+    ggml_to_float_t   const to_float     = type_traits[type].to_float;
     enum ggml_type    const vec_dot_type = type_traits[type].vec_dot_type;
 
     // broadcast factors
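The added line fetches the type's dequantization routine, which the AMX branch below uses to expand src0 rows to f32 before converting them to bf16, so the path works for any src0 type that provides one:

// Shape of ggml_to_float_t, as assumed from its use below:
//   typedef void (*ggml_to_float_t)(const void * x, float * y, int64_t k);
// i.e. dequantize k elements starting at x into the f32 buffer y.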
@@ -12275,8 +12365,25 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
 
     // attempt to reduce false-sharing (does not seem to make a difference)
+#ifdef __ARM_FEATURE_MATMUL_INT8
     // 16 * 2, accounting for mmla kernels
     float tmp[32];
+#elif defined(__AMX_TILE__) && defined(__AMX_BF16__)
+    if (num_rows_per_vec_dot == AMX_TILE_MN) {
+        assert(AMX_TILE_MN <= blck_0 && AMX_TILE_MN <= blck_1);
+        assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0);
+        assert(src1->type == GGML_TYPE_F32);
+    }
+    // 16 * AMX_TILE_MN, accounting for amx kernels
+    float tmp[16*AMX_TILE_MN];
+    uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t) + ne00*sizeof(float) + 4096);
+    ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase);
+    ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
+    float * xf32 = (float *) (wbase + 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
+    xf32 = (float *) (((size_t)xf32 + 4095) & ~4095);
+#else
+    float tmp[16];
+#endif
 
     for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
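The pointer arithmetic above carves a per-thread scratch region out of params->wdata; spelled out (my annotation of the expressions, not text from the commit):

// Per-thread scratch, stride S = 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)
//                               + ne00*sizeof(float) + 4096 bytes:
//
//   xbf16 : AMX_TILE_MN * ne00 bf16  -- 16 src0 rows converted to bf16
//   ybf16 : AMX_TILE_MN * ne00 bf16  -- 16 src1 rows converted to bf16
//   xf32  : ne00 f32, rounded up to the next 4096-byte boundary
//
// The extra 4096 bytes per thread absorb the alignment round-up of xf32.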
@@ -12293,24 +12400,39 @@ static void ggml_compute_forward_mul_mat_one_chunk(
             const int64_t i2 = i12;
             const int64_t i3 = i13;
 
             const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
-            // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-            //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-            //       the original src1 data pointer, so we should index using the indices directly
-            // TODO: this is a bit of a hack, we should probably have a better way to handle this
-            const char * src1_col = (const char*)wdata +
-                (src1_cont || src1->type != vec_dot_type
-                    ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
-                    : (i11 * nb11 + i12 * nb12 + i13 * nb13));
-            float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
-
-            //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
-            //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-            //}
-
-            for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
-            }
+            float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+            if (num_rows_per_vec_dot == AMX_TILE_MN) {
+                const uint8_t * src1_col = (const uint8_t *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13;
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                    for (int cn = 0; cn < AMX_TILE_MN; ++cn) {
+                        to_float((const uint8_t *)src0_row + ir0*nb01 + cn*nb01, xf32, ne00);
+                        ggml_fp32_to_bf16_row(xf32, xbf16 + cn*ne00, ne00);
+                        ggml_fp32_to_bf16_row((const float *)(src1_col + cn*nb11), ybf16 + cn*ne00, ne00);
+                    }
+                    ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], 16, xbf16, ne00*sizeof(ggml_bf16_t), ybf16, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
+                }
+            }
+            else
+#endif
+            if (true) {
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char*)wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
+                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
+
+                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                //}
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                }
+            }
 
             for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
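How the AMX branch hands its result back (my reading of the call, not commit text): ggml_vec_dot_bf16 is invoked with nrc = AMX_TILE_MN and a destination row stride of 16 floats, so tmp receives one 16x16 f32 tile per call:

// After the ggml_vec_dot_bf16 call above:
//   tmp[(ir0 - iir0) + m*16 + n] == dot(src1 row (i11 + m), src0 row (ir0 + n))
// The unchanged copy loop that follows scatters these rows into dst the same
// way it already does for the 2-row ARM mmla kernels.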
@@ -12473,6 +12595,11 @@ UseGgmlGemm1:;
         }
         // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
         atomic_store(&state->shared->current_chunk, nth);
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+        if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (ne01 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) {
+            return;
+        }
+#endif
         if (src1->type != vec_dot_type) {
             char * wdata = params->wdata;
             const size_t row_size = ggml_row_size(vec_dot_type, ne10);
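Why the early return (my inference, not stated in the commit): when the AMX path will be taken, src1 is consumed as f32 and converted to bf16 on the fly into the per-thread scratch, and ggml_graph_plan (below) sizes wdata for that scratch rather than for a converted copy of src1, so the usual src1-to-vec_dot_type conversion must be skipped to avoid writing past the smaller buffer.

// This gate repeats the conditions under which UseGgmlGemm2 and
// ggml_graph_plan select the AMX path, keeping all three sites in sync.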
@@ -12540,6 +12667,11 @@ UseGgmlGemm2:;
     if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
         num_rows_per_vec_dot = 1;
     }
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+    if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (nr0 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) {
+        num_rows_per_vec_dot = AMX_TILE_MN;
+    }
+#endif
 
     // Now select a reasonable chunk size.
     int chunk_size = 16;
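Worked out, the divisibility gate reduces to simple multiples (with sizeof(ggml_bf16_t) == 2):

// AMX_TILE_K*4/sizeof(ggml_bf16_t) = 16*4/2 = 32
// => take the AMX path only when the shared dimension ne00 is a multiple of
//    32 elements and both row counts (nr0, ne11) are multiples of 16, so every
//    chunk decomposes into whole 16x16 tiles with no edge handling.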
@@ -19328,6 +19460,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
     set_numa_thread_affinity(state->ith);
 
+#if defined(__gnu_linux__)
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+    syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
+#endif
+#endif
+
     int node_n = -1;
     int task_phase = GGML_TASK_TYPE_FINALIZE;
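A note on placement (my understanding of the kernel interface, not from the commit): ARCH_REQ_XCOMP_PERM grants permission process-wide, but it must happen before a thread first touches tile state so the kernel can grow that thread's XSAVE area; issuing it at the top of every compute thread is a cheap way to guarantee that ordering.

// Equivalent with an explicit check (sketch; the commit ignores failure, in
// which case the AMX path would fault on first tile use):
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) != 0) {
    // Permission denied (old kernel or policy): a robust build would fall
    // back to the non-AMX path instead of executing tile instructions.
}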
@@ -19525,6 +19663,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                                 * node->src[1]->ne[2]*node->src[1]->ne[3];
                         }
                     } else
 #endif
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+                    if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) {
+                        cur = n_threads*(2*AMX_TILE_MN*node->src[0]->ne[0]*sizeof(ggml_bf16_t) + node->src[0]->ne[0]*sizeof(float) + 4096);
+                    } else
+#endif
                     if (node->src[1]->type != vec_dot_type) {
                         cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
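Cross-checking the work-buffer size against the compute path (my annotation): the per-thread stride used by ggml_compute_forward_mul_mat_one_chunk is the same expression, so the planned buffer is exactly n_threads strides:

// wsize (graph plan)        = n_threads * S
// wbase (compute, thread i) = wdata + i * S
// where S = 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t) + ne00*sizeof(float) + 4096
// (ne00 == node->src[0]->ne[0]); each thread's scratch ends exactly where the
// next thread's begins.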