add sanity check to ggml_compute_backward, asserting the correct shape of gradients

This commit is contained in:
xaedes 2023-08-29 21:01:17 +02:00
parent 5fcfa7e49e
commit b1aa26f718
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

6
ggml.c
View file

@ -17147,6 +17147,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
} }
for (int i = 0; i < GGML_MAX_SRC; ++i) {
if (tensor->src[i] && tensor->src[i]->grad) {
GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
}
}
} }
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {