Fuse matrix multiplication + SiLU

JohannesGaessler 2024-02-08 11:52:13 +01:00
parent a6e514a85f
commit 3cf123e8f6
3 changed files with 39 additions and 2 deletions

ggml.c

@@ -1817,6 +1817,7 @@ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
 static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
+    "IDENTITY",
     "ABS",
     "SGN",
     "NEG",
@@ -1831,7 +1832,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -9992,6 +9993,7 @@ static void ggml_compute_forward_mul_mat(
     ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+    enum ggml_unary_op const activation = (enum ggml_unary_op) ggml_get_op_params_i32(dst, 0);
 
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
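
For context on the line added above: ggml keeps per-op metadata in each tensor's op_params byte array, and ggml_get_op_params_i32 / ggml_set_op_params_i32 read and write int32 slots inside it. Because tensors are created with op_params zeroed, putting GGML_UNARY_OP_IDENTITY first in the enum (value 0) makes "no fused activation" the default for every MUL_MAT node. A self-contained sketch of that encoding, using hypothetical toy_* names rather than the real ggml internals:

#include <stdint.h>
#include <string.h>

#define TOY_MAX_OP_PARAMS 64

struct toy_tensor {
    char op_params[TOY_MAX_OP_PARAMS]; // raw bytes, zeroed at tensor creation
};

// Read int32 slot i of op_params; memcpy avoids strict-aliasing issues.
static int32_t toy_get_op_params_i32(const struct toy_tensor * t, int i) {
    int32_t v;
    memcpy(&v, t->op_params + i*sizeof(int32_t), sizeof(v));
    return v;
}

// Write int32 slot i of op_params.
static void toy_set_op_params_i32(struct toy_tensor * t, int i, int32_t v) {
    memcpy(t->op_params + i*sizeof(int32_t), &v, sizeof(v));
}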
@@ -10197,7 +10199,20 @@ static void ggml_compute_forward_mul_mat(
                 for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                     vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
                 }
-                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+
+                float * dst_ptr = &dst_col[iir0];
+                const int64_t n = MIN(iir0 + blck_0, ir011) - iir0;
+                switch (activation) {
+                    case GGML_UNARY_OP_IDENTITY:
+                        memcpy(dst_ptr, tmp, n*sizeof(float));
+                        break;
+                    case GGML_UNARY_OP_SILU:
+                        ggml_vec_silu_f32(n, dst_ptr, tmp);
+                        break;
+                    default:
+                        GGML_ASSERT(false);
+                        break;
+                }
             }
         }
     }
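
In the fused path above, the activation is applied while the dot-product results are still in the thread-local tmp buffer, so the separate UNARY pass over dst disappears. For reference, ggml_vec_silu_f32 computes SiLU(x) = x * sigmoid(x) elementwise; a scalar sketch of that operation (the real ggml routine is vectorized):

#include <math.h>

// Scalar reference: y[i] = x[i] * sigmoid(x[i]) = x[i] / (1 + exp(-x[i])).
static void silu_f32_ref(const int n, float * y, const float * x) {
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] / (1.0f + expf(-x[i]));
    }
}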
@@ -14220,6 +14235,8 @@ static void ggml_compute_forward_unary(
     const enum ggml_unary_op op = ggml_get_unary_op(dst);
 
     switch (op) {
+        case GGML_UNARY_OP_IDENTITY:
+            break; // nothing to do
         case GGML_UNARY_OP_ABS:
             {
                 ggml_compute_forward_abs(params, src0, dst);
@@ -16517,6 +16534,20 @@ void ggml_graph_clear(struct ggml_cgraph * cgraph) {
     memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *));
 }
 
+void ggml_graph_optimize(struct ggml_cgraph * cgraph) {
+    for (int i = 1; i < cgraph->n_nodes; ++i) {
+        struct ggml_tensor * node_current  = cgraph->nodes[i - 0];
+        struct ggml_tensor * node_previous = cgraph->nodes[i - 1];
+
+        if (node_current->op == GGML_OP_UNARY && ggml_get_unary_op(node_current) == GGML_UNARY_OP_SILU
+                && node_previous->op == GGML_OP_MUL_MAT) {
+            ggml_set_op_params_i32(node_previous, 0, ggml_get_op_params_i32(node_current, 0));
+            ggml_set_op_params_i32(node_current,  0, (int32_t) GGML_UNARY_OP_IDENTITY);
+        }
+    }
+}
+
 //
 // thread data
 //
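
A minimal usage sketch of the new pass (tensor sizes, buffer size, and thread count are illustrative): building y = SiLU(W·x) yields a MUL_MAT node followed by a UNARY(SILU) node; ggml_graph_optimize then stores SILU in the MUL_MAT's op_params[0] and downgrades the UNARY node to the new IDENTITY no-op, so ggml_compute_forward_mul_mat applies the activation itself.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // illustrative scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64,  4);
    ggml_set_f32(w, 0.1f);
    ggml_set_f32(x, 1.0f);

    // creates a MUL_MAT node followed by a UNARY(SILU) node
    struct ggml_tensor * y = ggml_silu(ctx, ggml_mul_mat(ctx, w, x));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_graph_optimize(gf); // fuse: the MUL_MAT node now carries the SILU

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    ggml_free(ctx);
    return 0;
}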
@@ -16696,6 +16727,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(node)) {
+                case GGML_UNARY_OP_IDENTITY:
                 case GGML_UNARY_OP_ABS:
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_NEG:

ggml.h

@@ -481,6 +481,8 @@ extern "C" {
     };
 
     enum ggml_unary_op {
+        GGML_UNARY_OP_IDENTITY,
+
         GGML_UNARY_OP_ABS,
         GGML_UNARY_OP_SGN,
         GGML_UNARY_OP_NEG,
@@ -1877,6 +1879,7 @@ extern "C" {
     GGML_API void ggml_graph_cpy      (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void ggml_graph_reset    (struct ggml_cgraph * cgraph); // zero grads
     GGML_API void ggml_graph_clear    (struct ggml_cgraph * cgraph);
+    GGML_API void ggml_graph_optimize (struct ggml_cgraph * cgraph);
 
     GGML_API size_t ggml_graph_overhead(void);
     GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);

llama.cpp

@@ -7162,6 +7162,8 @@ static struct ggml_cgraph * llama_build_graph(
 
     llm.free();
 
+    ggml_graph_optimize(result);
+
     return result;
 }
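
To make the payoff concrete outside of ggml: the unfused version writes the matmul result to memory and then re-reads the whole output for the activation, while the fused version applies SiLU while each dot product is still in a register. A standalone sketch with naive loops and illustrative sizes (not the ggml kernel, which is blocked, threaded, and vectorized):

#include <math.h>
#include <stdio.h>

enum { M = 4, N = 3, K = 8 }; // illustrative sizes

static float silu(float v) { return v / (1.0f + expf(-v)); }

// Unfused: matmul first, then a second full pass over c for the activation.
static void matmul_then_silu(const float * a, const float * b, float * c) {
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += a[i*K + k] * b[j*K + k];
            }
            c[i*N + j] = sum;
        }
    }
    for (int i = 0; i < M*N; ++i) {
        c[i] = silu(c[i]); // extra read + write of the whole output
    }
}

// Fused: the activation is applied before the result ever leaves a register.
static void matmul_silu_fused(const float * a, const float * b, float * c) {
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += a[i*K + k] * b[j*K + k];
            }
            c[i*N + j] = silu(sum);
        }
    }
}

int main(void) {
    float a[M*K], b[N*K], c0[M*N], c1[M*N];
    for (int i = 0; i < M*K; ++i) a[i] = 0.01f*i;
    for (int i = 0; i < N*K; ++i) b[i] = 0.02f*i;

    matmul_then_silu(a, b, c0);
    matmul_silu_fused(a, b, c1);

    // Both paths perform the same operations in the same order,
    // so the outputs must agree bit-for-bit.
    for (int i = 0; i < M*N; ++i) {
        if (c0[i] != c1[i]) { printf("mismatch at %d\n", i); return 1; }
    }
    printf("fused and unfused outputs match\n");
    return 0;
}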