add ggml_add_cast API function

This function works like ggml_add, but accepts a data type for the resulting tensor.
Only quantized src0 input is supported.
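
A minimal usage sketch (not part of this commit; the helper name and tensor setup are hypothetical) of why the explicit result type matters: ggml_add gives the result src0's quantized type, while ggml_add_cast lets the caller keep the sum in F32, e.g. when adding an F32 delta onto a quantized weight.

#include "ggml.h"

// hypothetical helper: add an F32 delta onto a quantized weight while
// keeping the result in F32 instead of re-quantizing it
static struct ggml_tensor * add_delta_keep_f32(
        struct ggml_context * ctx,
        struct ggml_tensor  * w_quant,     // quantized weight, e.g. GGML_TYPE_Q4_0
        struct ggml_tensor  * delta_f32) { // F32 delta with repeatable rows
    // ggml_add would allocate the result with w_quant's type;
    // ggml_add_cast picks GGML_TYPE_F32 for the result explicitly
    return ggml_add_cast(ctx, w_quant, delta_f32, GGML_TYPE_F32);
}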
xaedes 2023-08-16 23:50:46 +02:00
parent f80e245d7b
commit 9198b24e4e
2 changed files with 51 additions and 2 deletions

ggml.c (+45)

@@ -5115,6 +5115,44 @@ struct ggml_tensor * ggml_add_inplace(
    return ggml_add_impl(ctx, a, b, true);
}

// ggml_add_cast

static struct ggml_tensor * ggml_add_cast_impl(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        enum ggml_type type) {
    // TODO: support less-strict constraint
    //       GGML_ASSERT(ggml_can_repeat(b, a));
    GGML_ASSERT(ggml_can_repeat_rows(b, a));
    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input

    bool is_node = false;

    if (a->grad || b->grad) {
        // TODO: support backward pass for broadcasting
        GGML_ASSERT(ggml_are_same_shape(a, b));
        is_node = true;
    }

    struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);

    result->op   = GGML_OP_ADD;
    result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

struct ggml_tensor * ggml_add_cast(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        enum ggml_type type) {
    return ggml_add_cast_impl(ctx, a, b, type);
}

// ggml_add1

static struct ggml_tensor * ggml_add1_impl(
@@ -8317,8 +8355,9 @@ static void ggml_compute_forward_add_q_f32(
    const int nth = params->nth;

    const enum ggml_type type = src0->type;
    const enum ggml_type dtype = dst->type;
    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
    ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -8368,7 +8407,11 @@ static void ggml_compute_forward_add_q_f32(
        // add src1
        ggml_vec_acc_f32(ne00, wdata, src1_row);

        // quantize row to dst
        if (quantize_row_q != NULL) {
            quantize_row_q(wdata, dst_row, ne00);
        } else {
            memcpy(dst_row, wdata, ne0*nb0);
        }
    }
}

ggml.h (+6)

@@ -670,6 +670,12 @@
            struct ggml_tensor * a,
            struct ggml_tensor * b);

    GGML_API struct ggml_tensor * ggml_add_cast(
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            struct ggml_tensor * b,
            enum ggml_type type);

    GGML_API struct ggml_tensor * ggml_add1(
            struct ggml_context * ctx,
            struct ggml_tensor * a,