add ggml_add_cast API function
This function works like ggml_add, but accepts a data type for the resulting tensor. Currently only supported for quantized src0 input.
parent f80e245d7b
commit 9198b24e4e
2 changed files with 51 additions and 2 deletions
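A minimal usage sketch of the new API (not part of the commit): the context size, tensor shapes, and graph calls below are illustrative assumptions, written against the ggml API of this period (ggml_build_forward returning the graph by value). src0 is quantized, src1 is F32, and the result is kept in F32 instead of being re-quantized to src0's type as plain ggml_add would do.

// usage_sketch.c -- hypothetical example, not from this commit
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /* .mem_size   = */ 64*1024*1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // quantized src0 and an F32 src1 of the same shape
    // (row size must be a multiple of the quantization block size)
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 1024, 256);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  1024, 256);

    // like ggml_add(ctx, a, b), but the result tensor is created with the
    // requested type (F32 here) instead of a->type
    struct ggml_tensor * c = ggml_add_cast(ctx, a, b, GGML_TYPE_F32);

    struct ggml_cgraph gf = ggml_build_forward(c);
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads=*/ 1);

    ggml_free(ctx);
    return 0;
}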
ggml.c | 45

@@ -5115,6 +5115,44 @@ struct ggml_tensor * ggml_add_inplace(
     return ggml_add_impl(ctx, a, b, true);
 }
 
+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_type type) {
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne);
+
+    result->op   = GGML_OP_ADD;
+    result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        enum ggml_type type) {
+    return ggml_add_cast_impl(ctx, a, b, type);
+}
+
 // ggml_add1
 
 static struct ggml_tensor * ggml_add1_impl(
@@ -8317,8 +8355,9 @@ static void ggml_compute_forward_add_q_f32(
     const int nth = params->nth;
 
     const enum ggml_type type = src0->type;
+    const enum ggml_type dtype = dst->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
-    ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
+    ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -8368,7 +8407,11 @@ static void ggml_compute_forward_add_q_f32(
         // add src1
         ggml_vec_acc_f32(ne00, wdata, src1_row);
         // quantize row to dst
-        quantize_row_q(wdata, dst_row, ne00);
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
     }
 }
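For context on the two hunks above: the destination row is now written with the from_float handler of dst's type rather than src0's, and that handler can be NULL for a non-quantized destination such as F32, in which case the accumulated float row is copied out unchanged. A standalone sketch of that store step (plain C, not ggml code; the from_float_t typedef and store_row helper are made up for illustration):

#include <string.h>

// Stand-in for ggml's per-type "from_float" conversion handler.
typedef void (*from_float_t)(const float * x, void * y, int k);

// Write one result row into dst: re-quantize with the destination type's
// handler when one exists, otherwise copy the float work row verbatim
// (the fallback this commit adds for F32 destinations).
void store_row(const float * wdata, void * dst_row,
               int ne00, size_t row_size, from_float_t quantize_row_q) {
    if (quantize_row_q != NULL) {
        quantize_row_q(wdata, dst_row, ne00); // quantize row to dst
    } else {
        memcpy(dst_row, wdata, row_size);     // dst row is already float
    }
}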
ggml.h | 6

@@ -670,6 +670,12 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * b);
 
+    GGML_API struct ggml_tensor * ggml_add_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            enum ggml_type type);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor * a,