Passing around the state

Kunnis 2024-05-08 23:10:50 -05:00
parent 5978b6ebf0
commit e098171aa7

ggml.c (19 changed lines)

@@ -11771,7 +11771,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              struct ggml_compute_state* state) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -17481,7 +17482,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
     GGML_ASSERT(params);
 
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17579,7 +17580,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor);
+                ggml_compute_forward_mul_mat(params, tensor, state);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -19639,7 +19640,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 struct ggml_tensor * node = cgraph->nodes[node_n];
                 if (GGML_OP_HAS_FINALIZE[node->op]) {
                     params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
-                    ggml_compute_forward(&params, node);
+                    ggml_compute_forward(&params, node, state);
                 }
                 ggml_graph_compute_perf_stats_node(node, state->shared);
             }
@@ -19659,17 +19660,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             /* INIT */
             if (GGML_OP_HAS_INIT[node->op]) {
                 params.type = GGML_TASK_TYPE_INIT;
-                ggml_compute_forward(&params, node);
+                ggml_compute_forward(&params, node, state);
             }
 
             // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
             // they do something more efficient than spinning (?)
             params.type = GGML_TASK_TYPE_COMPUTE;
-            ggml_compute_forward(&params, node);
+            ggml_compute_forward(&params, node, state);
 
             if (GGML_OP_HAS_FINALIZE[node->op]) {
                 params.type = GGML_TASK_TYPE_FINALIZE;
-                ggml_compute_forward(&params, node);
+                ggml_compute_forward(&params, node, state);
             }
 
             ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -19708,7 +19709,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         if (state->ith < n_tasks) {
             if (GGML_OP_HAS_INIT[node->op]) {
-                ggml_compute_forward(&params, node);
+                ggml_compute_forward(&params, node, state);
             }
         }
@@ -19729,7 +19730,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         if (state->ith < n_tasks) {
             params.type = GGML_TASK_TYPE_COMPUTE;
-            ggml_compute_forward(&params, node);
+            ggml_compute_forward(&params, node, state);
         }
 
         if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
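For reference, a minimal standalone sketch of the pattern this commit applies: the worker loop owns a per-thread state struct and threads a pointer to it through the generic dispatcher down to the one op that needs it (mul_mat), instead of widening every leaf function at once. All names below (compute_shared, compute_state, forward_mul_mat, ...) are illustrative stand-ins, not ggml's actual types.

#include <stdio.h>

/* Shared scheduler state, analogous to ggml_compute_state_shared. */
struct compute_shared { int n_threads; };

/* Per-thread state, analogous to ggml_compute_state: the worker's index
 * plus a pointer to the state shared by the whole thread pool. */
struct compute_state { int ith; struct compute_shared * shared; };

/* Per-task parameters, analogous to ggml_compute_params. */
struct compute_params { int type; int ith; int nth; };

/* Leaf op: with the extra argument it can read the thread state directly. */
static void forward_mul_mat(const struct compute_params * params,
                            struct compute_state * state) {
    printf("mul_mat: task %d/%d on thread %d (%d threads total)\n",
           params->ith, params->nth, state->ith, state->shared->n_threads);
}

/* Generic dispatcher: its signature grows a state argument, which it
 * forwards unchanged to the ops that use it. */
static void forward(struct compute_params * params, int op,
                    struct compute_state * state) {
    switch (op) {
        case 0: forward_mul_mat(params, state); break;
        default: break; /* other ops ignore the state */
    }
}

int main(void) {
    struct compute_shared shared = { .n_threads = 4 };
    struct compute_state  state  = { .ith = 0, .shared = &shared };
    struct compute_params params = { .type = 0, .ith = 0, .nth = 4 };

    /* The worker loop (cf. ggml_graph_compute_thread above) passes its
     * state down with every dispatch. */
    forward(&params, /*op=*/0, &state);
    return 0;
}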