use larger block size
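
Raise the block-tiling size for the AMX bf16 mul_mat path from 16 to AMX_BLCK_SIZE (64). The fp32 -> bf16 conversion of src0/src1 rows is hoisted out of the per-tile dot-product loop: each block's rows are converted once up front into the xbf16/ybf16 scratch panels, and ggml_vec_dot_bf16 then indexes into those panels. The per-thread scratch reserved in ggml_graph_plan, the on-stack tmp buffer, and the default chunk size on the AMX path all grow to match.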

commit 9a166331e0
parent 3047229758
Author: Reinforce-II
Date:   2024-05-23 01:17:51 +08:00

ggml.c: 55 changed lines (40 additions, 15 deletions)

@@ -46,6 +46,7 @@
 #undef GGML_USE_LLAMAFILE
 #define AMX_TILE_MN 16
 #define AMX_TILE_K 16
+#define AMX_BLCK_SIZE 64
 #endif
 #ifdef GGML_USE_LLAMAFILE
@@ -12359,8 +12360,13 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     assert(ne13 % ne03 == 0);

     // block-tiling attempt
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+    const int64_t blck_0 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
+    const int64_t blck_1 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
+#else
     const int64_t blck_0 = 16;
     const int64_t blck_1 = 16;
+#endif

     const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
@@ -12374,12 +12380,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
         assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0);
         assert(src1->type == GGML_TYPE_F32);
     }
-    // 16 * AMX_TILE_MN, accounting for amx kernels
-    float tmp[16*AMX_TILE_MN];
-    uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
+    // AMX_BLCK_SIZE * AMX_TILE_MN, accounting for amx kernels
+    float tmp[AMX_BLCK_SIZE*AMX_TILE_MN];
+    uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
     ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase);
-    ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
-    float * xf32 = (float *) (wbase + 2*AMX_TILE_MN*ne00*sizeof(ggml_bf16_t));
+    ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
+    float * xf32 = (float *) (wbase + 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
     xf32 = (float *) (((size_t)xf32 + 4095) & ~4095);
 #else
     float tmp[16];
@@ -12387,6 +12393,27 @@ static void ggml_compute_forward_mul_mat_one_chunk(

     for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+            if (num_rows_per_vec_dot == AMX_TILE_MN) {
+                const int64_t ii13 = (iir1 / (ne12 * ne1));
+                const int64_t ii12 = (iir1 - ii13 * ne12 * ne1) / ne1;
+                const int64_t ii11 = (iir1 - ii13 * ne12 * ne1 - ii12 * ne1);
+
+                // broadcast src0 into src1
+                const int64_t ii03 = ii13 / r3;
+                const int64_t ii02 = ii12 / r2;
+
+                const char * src0_row = (const char*)src0->data + (0 + ii02 * nb02 + ii03 * nb03);
+                const uint8_t * src1_col = (const uint8_t *)src1->data + ii11 * nb11 + ii12 * nb12 + ii13 * nb13;
+                for (int i = 0; i < blck_0 && iir0 + i < ir0_end; ++i) {
+                    to_float((const uint8_t *)src0_row + iir0*nb01 + i*nb01, xf32, ne00);
+                    ggml_fp32_to_bf16_row(xf32, xbf16 + i*ne00, ne00);
+                }
+                for (int i = 0; i < blck_1 && iir1 + i < ir1_end; ++i) {
+                    ggml_fp32_to_bf16_row((const float *)(src1_col + i*nb11), ybf16 + i*ne00, ne00);
+                }
+            }
+#endif
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
                 const int64_t i13 = (ir1 / (ne12 * ne1));
                 const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
@@ -12404,14 +12431,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
                 float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
 #if defined(__AMX_TILE__) && defined(__AMX_BF16__)
                 if (num_rows_per_vec_dot == AMX_TILE_MN) {
-                    const uint8_t * src1_col = (const uint8_t *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13;
                     for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                        for (int cn = 0; cn < AMX_TILE_MN; ++cn) {
-                            to_float((const uint8_t *)src0_row + ir0*nb01 + cn*nb01, xf32, ne00);
-                            ggml_fp32_to_bf16_row(xf32, xbf16 + cn*ne00, ne00);
-                            ggml_fp32_to_bf16_row((const float *)(src1_col + cn*nb11), ybf16 + cn*ne00, ne00);
-                        }
-                        ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], 16, xbf16, ne00*sizeof(ggml_bf16_t), ybf16, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
+                        ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], blck_0, xbf16 + (ir0-iir0)*ne00, ne00*sizeof(ggml_bf16_t), ybf16 + (ir1-iir1)*ne00, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
                     }
                 }
                 else
@@ -12431,12 +12452,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
                 //}

                 for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? blck_0 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
                 }
             }

             for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
-                memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+                memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
             }
         }
     }
@@ -12674,7 +12695,11 @@ UseGgmlGemm2:;
 #endif

     // Now select a reasonable chunk size.
+#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
+    int chunk_size = AMX_BLCK_SIZE;
+#else
     int chunk_size = 16;
+#endif

     // We need to step up the size if it's small
     if (nr0 == 1 || nr1 == 1) {
@@ -19666,7 +19691,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
 #endif
 #if defined(__AMX_TILE__) && defined(__AMX_BF16__)
                 if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) {
-                    cur = n_threads*(2*AMX_TILE_MN*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
+                    cur = n_threads*(2*AMX_BLCK_SIZE*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
                 } else
 #endif
                 if (node->src[1]->type != vec_dot_type) {
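
For orientation, here is a minimal sketch (not part of the commit) of the per-thread scratch layout the sizing above implies. The layout is read directly off the wbase/xbf16/ybf16/xf32 pointer arithmetic in the diff; amx_scratch_per_thread is an illustrative helper name, not a function in ggml.c.

// Sketch only: per-thread AMX scratch sizing implied by the diff.
#include <stddef.h>
#include <stdint.h>

typedef uint16_t ggml_bf16_t;      // bf16 payload is 16 bits, as in ggml

#define AMX_TILE_MN   16
#define AMX_BLCK_SIZE 64

// Mirrors cur = n_threads*(2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)
//                          + ne00*sizeof(float) + 4096) in ggml_graph_plan,
// with S = AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t):
//   [0,   S)   xbf16: AMX_BLCK_SIZE converted rows of src0, bf16
//   [S, 2*S)   ybf16: AMX_BLCK_SIZE converted rows of src1, bf16
//   [2*S, ..)  xf32:  one f32 conversion row, plus 4096 bytes of slack
//              so xf32 can be rounded up to a 4 KiB boundary
static size_t amx_scratch_per_thread(int64_t ne00) {
    const size_t panel = (size_t)AMX_BLCK_SIZE * (size_t)ne00 * sizeof(ggml_bf16_t);
    return 2*panel + (size_t)ne00*sizeof(float) + 4096;
}

Note the stack-side cost grows too: float tmp[AMX_BLCK_SIZE*AMX_TILE_MN] is 64*16*4 = 4096 bytes, up from 16*16*4 = 1024 bytes before this change.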