remove unnecessary src tensor from ggml_get_rows_back

we don't need the data of src[2] for the computation, only to set up the correct output shape.
remove the dependency on src[2] so that the allocator can work more freely.

the computational graph is still completely determined, because the output shape is fixed when the op is created.
this is similar to how ggml_reshape does it.
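a rough sketch of what this means for graph construction (not the verbatim upstream function; the ggml_new_tensor_2d call and the omitted grad bookkeeping are assumptions for illustration): the op only reads c->ne when it is created, so recording c as src[2] adds nothing:

#include "ggml.h"

// sketch: ggml_get_rows_back needs c only for its shape, never for c->data
struct ggml_tensor * ggml_get_rows_back_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,    // gradient of the ggml_get_rows result
        struct ggml_tensor  * b,    // row indices used by the forward ggml_get_rows
        struct ggml_tensor  * c) {  // tensor whose shape the output must match
    // the output shape is fixed here, at graph-construction time, so the
    // graph stays fully determined without recording c as a src tensor
    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);

    result->op     = GGML_OP_GET_ROWS_BACK;
    result->src[0] = a;
    result->src[1] = b;
    // no result->src[2] = c: the allocator may reuse c's buffer freely

    return result;
}

the same pattern is used by ggml_reshape, where the second argument only contributes its ne[] to the result and is never read at compute time.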
commit  65b0561637
parent  6c98640035
author  xaedes
date    2023-08-18 20:25:03 +02:00

ggml.c (25 lines changed)

@@ -6564,7 +6564,9 @@ struct ggml_tensor * ggml_get_rows_back(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
+    // we don't need data of c for computation, only to setup the correct output shape.
+    // break dependency on c, so that allocator can work more freely.
+    //result->src[2] = c;
 
     return result;
 }
@@ -11282,14 +11284,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
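note: with opt0 gone, dst is no longer seeded by duplicating opt0; it is cleared once in the GGML_TASK_INIT phase and the rows are then accumulated on top, while the previously accumulated gradient is added separately by ggml_add_or_set in the backward pass (see the last hunk). a minimal sketch of that phase handling, as it would sit inside ggml.c where struct ggml_compute_params is defined; the accumulation loop is elided and the function name is made up for illustration:

// sketch of the task-phase pattern in the hunk above; not the upstream loop body
static void get_rows_back_phases_sketch(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,   // incoming gradient rows
        const struct ggml_tensor * src1,   // row indices
              struct ggml_tensor * dst) {  // gradient w.r.t. the row-source matrix
    if (params->type == GGML_TASK_INIT) {
        // start the accumulation from zero instead of from a copy of opt0
        memset(dst->data, 0, ggml_nbytes(dst));
    }
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }
    // ... add each row of src0 into the dst row selected by src1 ...
    (void) src0; (void) src1;
}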
@@ -11315,11 +11318,8 @@ static void ggml_compute_forward_get_rows_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -11353,16 +11353,15 @@ static void ggml_compute_forward_get_rows_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -14867,7 +14866,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_GET_ROWS_BACK:
             {
-                ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_DIAG:
             {
@@ -15524,6 +15523,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     src0->grad =
                         ggml_add_or_set(ctx, src0->grad,
+                            // last ggml_get_rows_back argument src0->grad is only
+                            // necessary to setup correct output shape
                             ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
                             zero_table);
                 }
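for context, this is how the forward/backward pair reads once the change is in. the fragment below is illustrative only: the variable names embd and tokens are made up, and ggml_add_or_set / zero_table are internals of ggml_compute_backward as shown in the hunk above:

// forward: look up rows of the matrix
// embd: [n_embd, n_vocab], tokens: [n_tokens]
struct ggml_tensor * rows = ggml_get_rows(ctx, embd, tokens);   // rows: [n_embd, n_tokens]

// backward, as built by ggml_compute_backward: scatter rows->grad back into an
// embd-shaped tensor and accumulate; embd->grad is passed to ggml_get_rows_back
// only so the result gets embd's shape, its data is never read
embd->grad = ggml_add_or_set(ctx, embd->grad,
        ggml_get_rows_back(ctx, rows->grad, tokens, embd->grad),
        zero_table);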