ggml : avoid recomputing alibi slopes (CPU)
This commit is contained in:
parent
8084d55440
commit
7e0c3778fb
1 changed files with 25 additions and 28 deletions
21
ggml.c
21
ggml.c
|
@ -11694,14 +11694,8 @@ static void ggml_compute_forward_alibi_f32(
|
|||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||
|
||||
for (int64_t i = 0; i < ne0; i++) {
|
||||
for (int64_t j = 0; j < ne1; j++) {
|
||||
for (int64_t k = 0; k < ne2_ne3; k++) {
|
||||
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
|
||||
// TODO: k*nb2 or k*nb3
|
||||
|
||||
float m_k;
|
||||
|
||||
if (k < n_heads_log2_floor) {
|
||||
|
@ -11710,6 +11704,10 @@ static void ggml_compute_forward_alibi_f32(
|
|||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < ne0; i++) {
|
||||
for (int64_t j = 0; j < ne1; j++) {
|
||||
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
pdst[0] = i * m_k + src[0];
|
||||
}
|
||||
}
|
||||
|
@ -11754,14 +11752,8 @@ static void ggml_compute_forward_alibi_f16(
|
|||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||
|
||||
for (int i = 0; i < ne0; i++) {
|
||||
for (int j = 0; j < ne1; j++) {
|
||||
for (int k = 0; k < ne2_ne3; k++) {
|
||||
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
|
||||
// TODO: k*nb2 or k*nb3
|
||||
|
||||
float m_k;
|
||||
|
||||
if (k < n_heads_log2_floor) {
|
||||
|
@ -11770,6 +11762,11 @@ static void ggml_compute_forward_alibi_f16(
|
|||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
|
||||
for (int i = 0; i < ne0; i++) {
|
||||
for (int j = 0; j < ne1; j++) {
|
||||
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
|
||||
// we return F32
|
||||
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue