Fuse matrix multiplication + SiLU
This commit is contained in:
parent
a6e514a85f
commit
3cf123e8f6
3 changed files with 39 additions and 2 deletions
36
ggml.c
36
ggml.c
|
@ -1817,6 +1817,7 @@ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
|||
|
||||
|
||||
static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"IDENTITY",
|
||||
"ABS",
|
||||
"SGN",
|
||||
"NEG",
|
||||
|
@ -1831,7 +1832,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
|||
"HARDSIGMOID",
|
||||
};
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
|
||||
static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
|
||||
|
||||
|
||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||
|
@ -9992,6 +9993,7 @@ static void ggml_compute_forward_mul_mat(
|
|||
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||
enum ggml_unary_op const activation = (enum ggml_unary_op) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
|
@ -10197,7 +10199,20 @@ static void ggml_compute_forward_mul_mat(
|
|||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
|
||||
}
|
||||
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
|
||||
float * dst_ptr = &dst_col[iir0];
|
||||
const int64_t n = MIN(iir0 + blck_0, ir011) - iir0;
|
||||
switch (activation) {
|
||||
case GGML_UNARY_OP_IDENTITY:
|
||||
memcpy(dst_ptr, tmp, n*sizeof(float));
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
ggml_vec_silu_f32(n, dst_ptr, tmp);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14220,6 +14235,8 @@ static void ggml_compute_forward_unary(
|
|||
const enum ggml_unary_op op = ggml_get_unary_op(dst);
|
||||
|
||||
switch (op) {
|
||||
case GGML_UNARY_OP_IDENTITY:
|
||||
break; // nothing to do
|
||||
case GGML_UNARY_OP_ABS:
|
||||
{
|
||||
ggml_compute_forward_abs(params, src0, dst);
|
||||
|
@ -16517,6 +16534,20 @@ void ggml_graph_clear(struct ggml_cgraph * cgraph) {
|
|||
memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *));
|
||||
}
|
||||
|
||||
void ggml_graph_optimize(struct ggml_cgraph * cgraph) {
|
||||
for (int i = 1; i < cgraph->n_nodes; ++i) {
|
||||
struct ggml_tensor * node_current = cgraph->nodes[i-0];
|
||||
struct ggml_tensor * node_previous = cgraph->nodes[i-1];
|
||||
|
||||
if (node_current->op == GGML_OP_UNARY && ggml_get_unary_op(node_current) == GGML_UNARY_OP_SILU
|
||||
&& node_previous->op == GGML_OP_MUL_MAT) {
|
||||
|
||||
ggml_set_op_params_i32(node_previous, 0, ggml_get_op_params_i32(node_current, 0));
|
||||
ggml_set_op_params_i32(node_current, 0, (int32_t) GGML_UNARY_OP_IDENTITY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// thread data
|
||||
//
|
||||
|
@ -16696,6 +16727,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
} break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(node)) {
|
||||
case GGML_UNARY_OP_IDENTITY:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_NEG:
|
||||
|
|
3
ggml.h
3
ggml.h
|
@ -481,6 +481,8 @@ extern "C" {
|
|||
};
|
||||
|
||||
enum ggml_unary_op {
|
||||
GGML_UNARY_OP_IDENTITY,
|
||||
|
||||
GGML_UNARY_OP_ABS,
|
||||
GGML_UNARY_OP_SGN,
|
||||
GGML_UNARY_OP_NEG,
|
||||
|
@ -1877,6 +1879,7 @@ extern "C" {
|
|||
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
|
||||
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
|
||||
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
|
||||
GGML_API void ggml_graph_optimize (struct ggml_cgraph * cgraph);
|
||||
|
||||
GGML_API size_t ggml_graph_overhead(void);
|
||||
GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
|
||||
|
|
|
@ -7162,6 +7162,8 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
|
||||
llm.free();
|
||||
|
||||
ggml_graph_optimize(result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue