fix ggml_acc_or_set to return tensor of correct shape

This commit is contained in:
xaedes 2023-08-29 21:02:10 +02:00
parent b1aa26f718
commit a76e66ac8d
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

3
ggml.c
View file

@ -16255,7 +16255,8 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg
static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) { static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
if (hash_contains(zero_table, a)) { if (hash_contains(zero_table, a)) {
return b; struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
} else { } else {
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
} }