use larger block size
This commit is contained in:
parent
3047229758
commit
9a166331e0
1 changed file with 40 additions and 15 deletions
55
ggml.c
55
ggml.c
|
@ -46,6 +46,7 @@
|
||||||
#undef GGML_USE_LLAMAFILE
|
#undef GGML_USE_LLAMAFILE
|
||||||
#define AMX_TILE_MN 16
|
#define AMX_TILE_MN 16
|
||||||
#define AMX_TILE_K 16
|
#define AMX_TILE_K 16
|
||||||
|
#define AMX_BLCK_SIZE 64
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef GGML_USE_LLAMAFILE
|
#ifdef GGML_USE_LLAMAFILE
|
||||||
|
@ -12359,8 +12360,13 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
assert(ne13 % ne03 == 0);
|
assert(ne13 % ne03 == 0);
|
||||||
|
|
||||||
// block-tiling attempt
|
// block-tiling attempt
|
||||||
|
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
|
||||||
|
const int64_t blck_0 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
|
||||||
|
const int64_t blck_1 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
|
||||||
|
#else
|
||||||
const int64_t blck_0 = 16;
|
const int64_t blck_0 = 16;
|
||||||
const int64_t blck_1 = 16;
|
const int64_t blck_1 = 16;
|
||||||
|
#endif
|
||||||
|
|
||||||
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
|
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
|
||||||
|
|
||||||
|
@ -12374,12 +12380,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0);
|
assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0);
|
||||||
assert(src1->type == GGML_TYPE_F32);
|
assert(src1->type == GGML_TYPE_F32);
|
||||||
}
|
}
|
||||||
// 16 * AMX_TILE_MN, accounting for amx kernels
|
// AMX_BLCK_SIZE * AMX_TILE_MN, accounting for amx kernels
|
||||||
float tmp[16*AMX_TILE_MN];
|
float tmp[AMX_BLCK_SIZE*AMX_TILE_MN];
|
||||||
uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
|
uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
|
||||||
ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase);
|
ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase);
|
||||||
ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
|
ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
|
||||||
float * xf32 = (float *) (wbase + 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
|
float * xf32 = (float *) (wbase + 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
|
||||||
xf32 = (float *) (((size_t)xf32 + 4095) & ~4095);
|
xf32 = (float *) (((size_t)xf32 + 4095) & ~4095);
|
||||||
#else
|
#else
|
||||||
float tmp[16];
|
float tmp[16];
|
||||||
|
@ -12387,6 +12393,27 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
|
|
||||||
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||||
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||||
|
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
|
||||||
|
if (num_rows_per_vec_dot == AMX_TILE_MN) {
|
||||||
|
const int64_t ii13 = (iir1 / (ne12 * ne1));
|
||||||
|
const int64_t ii12 = (iir1 - ii13 * ne12 * ne1) / ne1;
|
||||||
|
const int64_t ii11 = (iir1 - ii13 * ne12 * ne1 - ii12 * ne1);
|
||||||
|
|
||||||
|
// broadcast src0 into src1
|
||||||
|
const int64_t ii03 = ii13 / r3;
|
||||||
|
const int64_t ii02 = ii12 / r2;
|
||||||
|
|
||||||
|
const char * src0_row = (const char*)src0->data + (0 + ii02 * nb02 + ii03 * nb03);
|
||||||
|
const uint8_t * src1_col = (const uint8_t *)src1->data + ii11 * nb11 + ii12 * nb12 + ii13 * nb13;
|
||||||
|
for (int i = 0; i < blck_0 && iir0 + i < ir0_end; ++i) {
|
||||||
|
to_float((const uint8_t *)src0_row + iir0*nb01 + i*nb01, xf32, ne00);
|
||||||
|
ggml_fp32_to_bf16_row(xf32, xbf16 + i*ne00, ne00);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < blck_1 && iir1 + i < ir1_end; ++i) {
|
||||||
|
ggml_fp32_to_bf16_row((const float *)(src1_col + i*nb11), ybf16 + i*ne00, ne00);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
|
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
|
||||||
const int64_t i13 = (ir1 / (ne12 * ne1));
|
const int64_t i13 = (ir1 / (ne12 * ne1));
|
||||||
const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
|
const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
|
||||||
|
@ -12404,14 +12431,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
||||||
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
|
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
|
||||||
if (num_rows_per_vec_dot == AMX_TILE_MN) {
|
if (num_rows_per_vec_dot == AMX_TILE_MN) {
|
||||||
const uint8_t * src1_col = (const uint8_t *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13;
|
|
||||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
||||||
for (int cn = 0; cn < AMX_TILE_MN; ++cn) {
|
ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], blck_0, xbf16 + (ir0-iir0)*ne00, ne00*sizeof(ggml_bf16_t), ybf16 + (ir1-iir1)*ne00, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
|
||||||
to_float((const uint8_t *)src0_row + ir0*nb01 + cn*nb01, xf32, ne00);
|
|
||||||
ggml_fp32_to_bf16_row(xf32, xbf16 + cn*ne00, ne00);
|
|
||||||
ggml_fp32_to_bf16_row((const float *)(src1_col + cn*nb11), ybf16 + cn*ne00, ne00);
|
|
||||||
}
|
|
||||||
ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], 16, xbf16, ne00*sizeof(ggml_bf16_t), ybf16, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -12431,12 +12452,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
//}
|
//}
|
||||||
|
|
||||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
||||||
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? blck_0 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
||||||
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
|
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12674,7 +12695,11 @@ UseGgmlGemm2:;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Now select a reasonable chunk size.
|
// Now select a reasonable chunk size.
|
||||||
|
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
|
||||||
|
int chunk_size = AMX_BLCK_SIZE;
|
||||||
|
#else
|
||||||
int chunk_size = 16;
|
int chunk_size = 16;
|
||||||
|
#endif
|
||||||
|
|
||||||
// We need to step up the size if it's small
|
// We need to step up the size if it's small
|
||||||
if (nr0 == 1 || nr1 == 1) {
|
if (nr0 == 1 || nr1 == 1) {
|
||||||
|
@ -19666,7 +19691,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
|
||||||
#endif
|
#endif
|
||||||
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
|
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
|
||||||
if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) {
|
if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) {
|
||||||
cur = n_threads*(2*AMX_TILE_MN*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
|
cur = n_threads*(2*AMX_BLCK_SIZE*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
if (node->src[1]->type != vec_dot_type) {
|
if (node->src[1]->type != vec_dot_type) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue