fix view backward pass

Add nb parameters to add_at, like in view. Together with offset they define how to view dst and src0 during the add_at operation.
parent f0302fa71b
commit 84436383eb

2 changed files with 111 additions and 272 deletions

ggml.c (371 lines changed)
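For orientation, a minimal sketch (not part of this commit or of ggml) of the addressing that offset and nb1/nb2/nb3 define: during add_at, element (i0, i1, i2, i3) of the updated region lives at byte offset + i0*element_size + i1*nb1 + i2*nb2 + i3*nb3 inside the contiguous dst and src0, the same way a view addresses its parent. The helper name region_ptr_f32 below is made up for illustration.

#include <stddef.h>
#include <stdint.h>

// Illustrative helper, not a ggml function: address of element (i0,i1,i2,i3)
// of the viewed region inside a contiguous F32 tensor, given the byte offset
// and the nb1/nb2/nb3 strides passed to add_at. The i0 stride is implicitly
// sizeof(float) because dst and src0 are contiguous.
static float * region_ptr_f32(void * data, size_t offset,
                              size_t nb1, size_t nb2, size_t nb3,
                              int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    return (float *) ((char *) data + offset
                      + i0*sizeof(float) + i1*nb1 + i2*nb2 + i3*nb3);
}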
@@ -5054,9 +5054,14 @@ struct ggml_tensor * ggml_add_at_impl(
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        size_t               offset,
        size_t               nb1,
        size_t               nb2,
        size_t               nb3,
        bool inplace) {
    GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
    GGML_ASSERT(ggml_is_contiguous(a));
    GGML_ASSERT(a->type == GGML_TYPE_F32);
    GGML_ASSERT(b->type == GGML_TYPE_F32);

    bool is_node = false;
@@ -5065,12 +5070,18 @@ struct ggml_tensor * ggml_add_at_impl(
    }

    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
    ((int32_t *) c->data)[0] = offset;
    ((int32_t *) c->data)[1] = nb1;
    ((int32_t *) c->data)[2] = nb2;
    ((int32_t *) c->data)[3] = nb3;
    ((int32_t *) c->data)[4] = inplace ? 1 : 0;

    result->op   = GGML_OP_ADD_AT;
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src0 = a;
    result->src1 = b;
    memcpy(result->padding, &offset, sizeof(size_t));
    result->opt[0] = c;

    return result;
}
@@ -5079,16 +5090,22 @@ struct ggml_tensor * ggml_add_at(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        size_t offset) {
    return ggml_add_at_impl(ctx, a, b, offset, false);
        size_t offset,
        size_t nb1,
        size_t nb2,
        size_t nb3) {
    return ggml_add_at_impl(ctx, a, b, offset, nb1, nb2, nb3, false);
}

struct ggml_tensor * ggml_add_at_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        size_t offset) {
    return ggml_add_at_impl(ctx, a, b, offset, true);
        size_t offset,
        size_t nb1,
        size_t nb2,
        size_t nb3) {
    return ggml_add_at_impl(ctx, a, b, offset, nb1, nb2, nb3, true);
}

// ggml_sub
@@ -5951,7 +5968,7 @@ struct ggml_tensor * ggml_view_1d(
    result->src1 = NULL;

    if (is_node) {
        memcpy(result->padding, &offset, sizeof(size_t));
        memcpy(result->padding, &offset, sizeof(offset));
    }

    return result;
@@ -5987,7 +6004,7 @@ struct ggml_tensor * ggml_view_2d(
    result->src1 = NULL;

    if (is_node) {
        memcpy(result->padding, &offset, sizeof(size_t));
        memcpy(result->padding, &offset, sizeof(offset));
    }

    return result;
@@ -6025,7 +6042,7 @@ struct ggml_tensor * ggml_view_3d(
    result->src1 = NULL;

    if (is_node) {
        memcpy(result->padding, &offset, sizeof(size_t));
        memcpy(result->padding, &offset, sizeof(offset));
    }

    return result;
@@ -7946,12 +7963,30 @@ static void ggml_compute_forward_add_at_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst,
        size_t offset) {
    // GGML_ASSERT(ggml_are_same_shape(src0, src1)); // TODO: assert that offset+len(src1) <= len(src1)
        const struct ggml_tensor * opt0,
        struct ggml_tensor * dst) {
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_nelements(opt0) == 5);

    ggml_compute_forward_dup_same_cont(params, src0, dst);
    // view src0 and dst with these strides and data offset inbytes during add_at
    // nb0 is implicitely element_size because src0 and dst are contiguous
    size_t offset = ((int32_t *) opt0->data)[0];
    size_t nb1 = ((int32_t *) opt0->data)[1];
    size_t nb2 = ((int32_t *) opt0->data)[2];
    size_t nb3 = ((int32_t *) opt0->data)[3];
    bool inplace = (bool) ((int32_t *) opt0->data)[4];

    if (!inplace && (params->type == GGML_TASK_INIT)) {
        // memcpy needs to be synchronized across threads to avoid race conditions.
        // => do it in INIT phase
        memcpy(
            ((char *)  dst->data),
            ((char *) src0->data),
            ggml_nbytes(dst));
    }

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
@@ -7960,228 +7995,32 @@ static void ggml_compute_forward_add_at_f32(
    const int ith = params->ith;
    const int nth = params->nth;

    const int n  = ggml_nrows(src1);
    const int nr = ggml_nrows(src1);
    const int nc = src1->ne[0];

    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];

    const size_t nb10 = src1->nb[0];
    const size_t nb11 = src1->nb[1];

    const size_t nb0 = dst->nb[0];
    const size_t nb1 = dst->nb[1];

    GGML_ASSERT( nb0 == sizeof(float));
    GGML_ASSERT(nb00 == sizeof(float));

    if (nb10 == sizeof(float)) {
        for (int j = ith; j < n; j += nth) {
#ifdef GGML_USE_ACCELERATE
            vDSP_vadd(
                    (float *) ((char *) src0->data + j*nb01 + offset), 1,
                    (float *) ((char *) src1->data + j*nb11), 1,
                    (float *) ((char *) dst->data  + j*nb1  + offset), 1, nc);
#else
            ggml_vec_add_f32(nc,
                    (float *) ((char *) dst->data  + j*nb1  + offset),
                    (float *) ((char *) src0->data + j*nb01 + offset),
                    (float *) ((char *) src1->data + j*nb11));
#endif
        }
    } else {
        // src1 is not contiguous
        for (int j = ith; j < n; j += nth) {
            float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1  + offset);
            float * src0_ptr = (float *) ((char *) src0->data + j*nb01 + offset);
            for (int i = 0; i < nc; i++) {
                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);

                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
            }
        }
    }
}

static void ggml_compute_forward_add_at_f16_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst,
        size_t offset) {
    // GGML_ASSERT(ggml_are_same_shape(src0, src1)); // TODO: assert that offset+len(src1) <= len(src1)
    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    ggml_compute_forward_dup_same_cont(params, src0, dst);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int ith = params->ith;
    const int nth = params->nth;

    const int n  = ggml_nrows(src0);
    const int nc = src0->ne[0];

    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];

    const size_t nb10 = src1->nb[0];
    const size_t nb11 = src1->nb[1];

    const size_t nb0 = dst->nb[0];
    const size_t nb1 = dst->nb[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);

    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));

    if (nb10 == sizeof(float)) {
        for (int j = ith; j < n; j += nth) {
            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1  + offset);
            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01 + offset);
            for (int i = 0; i < nc; i++) {
                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
            }
        }
    }
    else {
        // src1 is not contiguous
        GGML_ASSERT(false);
    }
}

static void ggml_compute_forward_add_at_f16_f16(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst,
        size_t offset) {
    // GGML_ASSERT(ggml_are_same_shape(src0, src1)); // TODO: assert that offset+len(src1) <= len(src1)
    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    ggml_compute_forward_dup_same_cont(params, src0, dst);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int ith = params->ith;
    const int nth = params->nth;

    const int n  = ggml_nrows(src0);
    const int nc = src0->ne[0];

    const size_t nb00 = src0->nb[0];
    const size_t nb01 = src0->nb[1];

    const size_t nb10 = src1->nb[0];
    const size_t nb11 = src1->nb[1];

    const size_t nb0 = dst->nb[0];
    const size_t nb1 = dst->nb[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F16);
    GGML_ASSERT(src1->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type  == GGML_TYPE_F16);

    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));

    if (nb10 == sizeof(ggml_fp16_t)) {
        for (int j = ith; j < n; j += nth) {
            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1  + offset);
            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01 + offset);
            for (int i = 0; i < nc; i++) {
                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
            }
        }
    }
    else {
        // src1 is not contiguous
        GGML_ASSERT(false);
    }
}

static void ggml_compute_forward_add_at_q_f32(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst,
        size_t offset) {
    // GGML_ASSERT(ggml_are_same_shape(src0, src1)); // TODO: assert that offset+len(src1) <= len(src1)
    GGML_ASSERT(ggml_are_same_shape(src0, dst));

    ggml_compute_forward_dup_same_cont(params, src0, dst);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    //const int64_t ne10 = src1->ne[0];
    //const int64_t ne11 = src1->ne[1];
    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne12 = src1->ne[2];
    const int64_t ne13 = src1->ne[3];

    //const int64_t ne0 = dst->ne[0];
    //const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];
    const size_t nb10 = src1->nb[0];
    const size_t nb11 = src1->nb[1];
    const size_t nb12 = src1->nb[2];
    const size_t nb13 = src1->nb[3];

    const int nb00 = src0->nb[0];
    const int nb01 = src0->nb[1];
    const int nb02 = src0->nb[2];
    const int nb03 = src0->nb[3];
    // src0 and dst as viewed during add_at
    const size_t nb0 = ggml_element_size(src0);

    const int nb10 = src1->nb[0];
    const int nb11 = src1->nb[1];
    const int nb12 = src1->nb[2];
    const int nb13 = src1->nb[3];
    const size_t nb00 = nb0;
    const size_t nb01 = nb1;
    const size_t nb02 = nb2;
    const size_t nb03 = nb3;

    const int nb0 = dst->nb[0];
    const int nb1 = dst->nb[1];
    const int nb2 = dst->nb[2];
    const int nb3 = dst->nb[3];
    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));

    const int ith = params->ith;
    const int nth = params->nth;

    GGML_ASSERT(ne02 == ne12);
    GGML_ASSERT(ne03 == ne13);
    GGML_ASSERT(ne2  == ne12);
    GGML_ASSERT(ne3  == ne13);

    const enum ggml_type type = src0->type;
    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
    GGML_ASSERT(nb10 == sizeof(float));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    GGML_ASSERT(ggml_is_quantized(src0->type));
    GGML_ASSERT(dst->type == src0->type);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);

    // total rows in src0
    const int nr = ne01*ne02*ne03;

    // rows per thread
    const int dr = (nr + nth - 1)/nth;
@@ -8189,35 +8028,24 @@ static void ggml_compute_forward_add_at_q_f32(
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

    for (int ir = ir0; ir < ir1; ++ir) {
        // src0 indices
        const int i03 = ir/(ne02*ne01);
        const int i02 = (ir - i03*ne02*ne01)/ne01;
        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
        // src0 and dst are viewed with shape of src1 and offset
        // => same indices
        const int i3 = ir/(ne12*ne11);
        const int i2 = (ir - i3*ne12*ne11)/ne11;
        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);

        // src1 and dst are same shape as src0 => same indices
        const int i13 = i03;
        const int i12 = i02;
        const int i11 = i01;

        const int i3 = i03;
        const int i2 = i02;
        const int i1 = i01;

        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03) + offset);
        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
        void  * dst_row  = (void *) ((char *) dst->data  + ( i1*nb1  +  i2*nb2  +  i3*nb0) + offset);

        assert(ne00 % 32 == 0);

        // unquantize row from src0 to temp buffer
        dequantize_row_q(src0_row, wdata, ne00);
        // add src1
        ggml_vec_acc_f32(ne00, wdata, src1_row);
        // quantize row to dst
        quantize_row_q(wdata, dst_row, ne00);
#ifdef GGML_USE_ACCELERATE
        vDSP_vadd(
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
#else
        ggml_vec_add_f32(nc,
                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset),
                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
#endif
    }
}
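As an aside, the flat-row decomposition used in the loop above can be written on its own. A small sketch follows; split_row_index is a made-up name, not a ggml function.

// Illustrative only: decompose a flat row index ir over src1's rows into
// (i1, i2, i3), using the same arithmetic as the forward pass above.
// ne11 and ne12 are src1->ne[1] and src1->ne[2].
static void split_row_index(int ir, int ne11, int ne12,
                            int * i1, int * i2, int * i3) {
    *i3 =  ir/(ne12*ne11);
    *i2 = (ir - (*i3)*ne12*ne11)/ne11;
    *i1 = (ir - (*i3)*ne12*ne11 - (*i2)*ne11);
}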
@@ -8225,33 +8053,19 @@ static void ggml_compute_forward_add_at(
        const struct ggml_compute_params * params,
        const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
        const struct ggml_tensor * opt0,
        struct ggml_tensor * dst) {
    size_t offset;
    memcpy(&offset, dst->padding, sizeof(offset));

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_add_at_f32(params, src0, src1, dst, offset);
                ggml_compute_forward_add_at_f32(params, src0, src1, opt0, dst);
            } break;
        case GGML_TYPE_F16:
            {
                if (src1->type == GGML_TYPE_F16) {
                    ggml_compute_forward_add_at_f16_f16(params, src0, src1, dst, offset);
                }
                else if (src1->type == GGML_TYPE_F32) {
                    ggml_compute_forward_add_at_f16_f32(params, src0, src1, dst, offset);
                }
                else {
                    GGML_ASSERT(false);
                }
            } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q4_2:
        case GGML_TYPE_Q4_3:
            {
                ggml_compute_forward_add_at_q_f32(params, src0, src1, dst, offset);
            } break;
        default:
            {
                GGML_ASSERT(false);
@@ -12749,7 +12563,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
            } break;
        case GGML_OP_ADD_AT:
            {
                ggml_compute_forward_add_at(params, tensor->src0, tensor->src1, tensor);
                ggml_compute_forward_add_at(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
            } break;
        case GGML_OP_SUB:
            {
@@ -13283,7 +13097,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                if (src0->grad) {
                    size_t offset;
                    memcpy(&offset, tensor->padding, sizeof(offset));
                    src0->grad = ggml_add_at_impl(ctx, src0->grad, tensor->grad, offset, inplace);

                    size_t nb1 = tensor->nb[1];
                    size_t nb2 = tensor->nb[2];
                    size_t nb3 = tensor->nb[3];

                    if (src0->type != src0->grad->type) {
                        // gradient is typically F32, but src0 could be other type
                        size_t ng = ggml_element_size(src0->grad);
                        size_t n0 = ggml_element_size(src0);
                        GGML_ASSERT(offset % n0 == 0);
                        GGML_ASSERT(nb1 % n0 == 0);
                        GGML_ASSERT(nb2 % n0 == 0);
                        GGML_ASSERT(nb3 % n0 == 0);
                        offset = (offset / n0) * ng;
                        nb1 = (nb1 / n0) * ng;
                        nb2 = (nb2 / n0) * ng;
                        nb3 = (nb3 / n0) * ng;
                    }

                    src0->grad = ggml_add_at_impl(ctx, src0->grad, tensor->grad, offset, nb1, nb2, nb3, inplace);
                }
            } break;
        case GGML_OP_PERMUTE:
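The backward case above rescales offset and the nb strides when the gradient's element size differs from src0's. A minimal sketch of that conversion, under the same assumption the asserts enforce (the value is a whole multiple of src0's element size); rescale_stride is a made-up name, not a ggml function.

#include <stddef.h>

// Illustrative only: convert a byte offset or stride expressed in multiples
// of src0's element size n0 into multiples of the gradient's element size ng,
// as ggml_compute_backward does above for offset, nb1, nb2 and nb3.
static size_t rescale_stride(size_t nb, size_t n0, size_t ng) {
    // caller must ensure nb % n0 == 0 (asserted in ggml_compute_backward)
    return (nb / n0) * ng;
}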
ggml.h (12 lines changed)
@@ -491,19 +491,25 @@ extern "C" {
    GGML_API struct ggml_tensor * ggml_add1(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b);

    GGML_API struct ggml_tensor * ggml_add_at(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                offset);
            size_t                offset,
            size_t                nb1,
            size_t                nb2,
            size_t                nb3);

    GGML_API struct ggml_tensor * ggml_add_at_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            struct ggml_tensor  * b,
            size_t                offset);
            size_t                offset,
            size_t                nb1,
            size_t                nb2,
            size_t                nb3);

    GGML_API struct ggml_tensor * ggml_sub(
            struct ggml_context * ctx,
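A minimal usage sketch of the extended API declared above; the wrapper below is hypothetical and omits the element-size rescaling that the backward pass applies when the parent is not F32.

#include "ggml.h"

// Illustrative only: accumulate the gradient flowing into a view back into
// the parent's gradient, reusing the view's own byte offset and strides as
// the backward pass does. All names here are made up for the example.
static struct ggml_tensor * add_view_grad(
        struct ggml_context * ctx,
        struct ggml_tensor  * parent_grad, // gradient of the viewed parent
        struct ggml_tensor  * view,        // the view tensor itself
        struct ggml_tensor  * view_grad,   // gradient arriving at the view
        size_t                offset) {    // byte offset stored for the view
    return ggml_add_at(ctx, parent_grad, view_grad,
                       offset, view->nb[1], view->nb[2], view->nb[3]);
}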