bug fix for add_at forward

required for view backward pass

src0 values must be copied to dst first because, unlike the normal add function, add_at does not touch every element of dst during the addition.
This commit is contained in:
xaedes 2023-04-27 16:58:22 +02:00
parent 83fa6b3bcb
commit cecd6c7665
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

16
ggml.c
View file

@ -5055,7 +5055,8 @@ struct ggml_tensor * ggml_add_at_impl(
struct ggml_tensor * b,
size_t offset,
bool inplace) {
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
GGML_ASSERT(ggml_is_contiguous(a));
bool is_node = false;
@ -7860,8 +7861,8 @@ static void ggml_compute_forward_add_at_f32(
const int ith = params->ith;
const int nth = params->nth;
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
const int n = ggml_nrows(src1);
const int nc = src1->ne[0];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];
@ -7884,7 +7885,7 @@ static void ggml_compute_forward_add_at_f32(
(float *) ((char *) dst->data + j*nb1 + offset), 1, nc);
#else
ggml_vec_add_f32(nc,
(float *) ((char *) dst->data + j*nb1 + offset),
(float *) ((char *) dst->data + j*nb1 + offset),
(float *) ((char *) src0->data + j*nb01 + offset),
(float *) ((char *) src1->data + j*nb11));
#endif
@ -7892,7 +7893,7 @@ static void ggml_compute_forward_add_at_f32(
} else {
// src1 is not contiguous
for (int j = ith; j < n; j += nth) {
float * dst_ptr = (float *) ((char *) dst->data + j*nb1 + offset);
float * dst_ptr = (float *) ((char *) dst->data + j*nb1 + offset);
float * src0_ptr = (float *) ((char *) src0->data + j*nb01 + offset);
for (int i = 0; i < nc; i++) {
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
@ -8121,7 +8122,8 @@ static void ggml_compute_forward_add_at(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
size_t offset;
memcpy(&offset, dst->padding, sizeof(size_t));
memcpy(&offset, dst->padding, sizeof(offset));
ggml_compute_forward_dup_same_cont(params, src0, dst);
switch (src0->type) {
case GGML_TYPE_F32:
{
@ -12963,7 +12965,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
size_t offset;
memcpy(&offset, tensor->padding, sizeof(size_t));
memcpy(&offset, tensor->padding, sizeof(offset));
src0->grad = ggml_add_at_impl(ctx, src0->grad, tensor->grad, offset, inplace);
}
} break;