ggml : change ggml_scale to take a float instead of tensor

This commit is contained in:
Georgi Gerganov 2023-12-21 20:50:24 +02:00
parent 8fe03ffdda
commit 199f6bdc46
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
10 changed files with 68 additions and 186 deletions

View file

@ -575,10 +575,7 @@ static struct ggml_tensor * forward(
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]
@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch(
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
// KQ_masked = mask_past(KQ_scaled)
@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora(
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, 1]