remove unnecessary src tensor from ggml_get_rows_back
we don't need the data of src[2] for the computation, only to set up the correct output shape. remove the dependency on src[2], so that the allocator can work more freely. the computational graph is still completely determined, because the output shape is naturally included in the op's result tensor. this is similar to how ggml_reshape does it.
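To make the shape-only role of c concrete, here is a minimal plain-C sketch of what GGML_OP_GET_ROWS_BACK computes for contiguous f32 rows (illustrative only, not the ggml implementation; the names dy, idx, da, row_size and a_nrows are hypothetical): the gradient rows are scattered back into a zero-initialized buffer whose row count comes from c's shape, so c's data is never read.

#include <stddef.h>  /* size_t */
#include <string.h>  /* memset */

/* backward of get_rows: da[idx[r]] += dy[r] for every selected row r.
 * a_nrows plays the role of c here: only the shape is needed to size and zero da. */
static void get_rows_back_f32_sketch(
        const float * dy,       /* gradient of the get_rows output: n_idx rows      */
        const int   * idx,      /* row indices used in the forward get_rows         */
        int           n_idx,    /* number of selected rows                          */
        int           row_size, /* elements per row                                 */
        float       * da,       /* gradient w.r.t. a: a_nrows rows, zeroed here     */
        int           a_nrows) {
    /* the patch does the same with memset(dst->data, ...) during GGML_TASK_INIT */
    memset(da, 0, (size_t) a_nrows * row_size * sizeof(float));

    for (int r = 0; r < n_idx; ++r) {
        const float * src = dy + (size_t) r      * row_size;
        float       * dst = da + (size_t) idx[r] * row_size;
        for (int k = 0; k < row_size; ++k) {
            dst[k] += src[k];  /* accumulate: the same row may be gathered more than once */
        }
    }
}

Since the result is fully determined by its own shape (copied from c at graph-build time), the allocator is free to reuse or release c's buffer before this op runs.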
parent 6c98640035
commit 65b0561637
1 changed file with 13 additions and 12 deletions
ggml.c
@@ -6564,7 +6564,9 @@ struct ggml_tensor * ggml_get_rows_back(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
+    // we don't need data of c for computation, only to setup the correct output shape.
+    // break dependency on c, so that allocator can work more freely.
+    //result->src[2] = c;
 
     return result;
 }
@@ -11282,14 +11284,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    ggml_compute_forward_dup_same_cont(params, opt0, dst);
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    if (params->type == GGML_TASK_INIT) {
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -11315,11 +11318,8 @@ static void ggml_compute_forward_get_rows_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_are_same_shape(opt0, dst));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
     // ggml_compute_forward_dup_same_cont(params, opt0, dst);
@@ -11353,16 +11353,15 @@ static void ggml_compute_forward_get_rows_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -14867,7 +14866,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_GET_ROWS_BACK:
             {
-                ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor);
             } break;
        case GGML_OP_DIAG:
             {
@@ -15524,6 +15523,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             if (src0->grad) {
                 src0->grad =
                     ggml_add_or_set(ctx, src0->grad,
+                        // last ggml_get_rows_back argument src0->grad is only
+                        // necessary to setup correct output shape
                         ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
                         zero_table);
             }