refact : fix convert script + zero out KV cache to avoid nans (#3523)

* refact : fix convert script + zero out KV cache to avoid nans

* ggml : silu(-inf) should never happen

* metal : assert various kernel requirements
This commit is contained in:
Georgi Gerganov 2023-10-09 14:32:17 +03:00 committed by GitHub
parent dcc09d2596
commit fcca0a7004
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 91 deletions

27
ggml.c
View file

@ -11233,7 +11233,7 @@ static void ggml_compute_forward_silu_f32(
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
@ -13066,17 +13066,17 @@ static void ggml_compute_forward_alibi_f32(
assert(n_past >= 0);
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
//const int64_t ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0);
const int ne2_ne3 = n/ne1; // ne2*ne3
const int64_t n = ggml_nrows(src0);
const int64_t ne2_ne3 = n/ne1; // ne2*ne3
const int nb0 = src0->nb[0];
const int nb1 = src0->nb[1];
const int nb2 = src0->nb[2];
const size_t nb0 = src0->nb[0];
const size_t nb1 = src0->nb[1];
const size_t nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
GGML_ASSERT(nb0 == sizeof(float));
@ -13088,9 +13088,9 @@ static void ggml_compute_forward_alibi_f32(
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
for (int i = 0; i < ne0; i++) {
for (int j = 0; j < ne1; j++) {
for (int k = 0; k < ne2_ne3; k++) {
for (int64_t i = 0; i < ne0; i++) {
for (int64_t j = 0; j < ne1; j++) {
for (int64_t k = 0; k < ne2_ne3; k++) {
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
@ -13105,7 +13105,6 @@ static void ggml_compute_forward_alibi_f32(
}
pdst[0] = i * m_k + src[0];
}
}
}