ggml : sync (abort callback, mul / add broadcast, fix alibi) (#2183)

This commit is contained in:
Georgi Gerganov 2023-07-11 22:53:34 +03:00 committed by GitHub
parent 5bf2a27718
commit 20d7740a9b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 72 deletions

115
ggml.c
View file

@ -25,6 +25,7 @@
#include <float.h>
#include <limits.h>
#include <stdarg.h>
#include <signal.h>
#ifdef GGML_USE_METAL
#include <unistd.h>
@ -49,23 +50,23 @@
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
static void atomic_store(atomic_int* ptr, LONG val) {
static void atomic_store(atomic_int * ptr, LONG val) {
InterlockedExchange(ptr, val);
}
static LONG atomic_load(atomic_int* ptr) {
static LONG atomic_load(atomic_int * ptr) {
return InterlockedCompareExchange(ptr, 0, 0);
}
static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
return InterlockedExchangeAdd(ptr, inc);
}
static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec));
}
typedef HANDLE pthread_t;
typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
(void) unused;
HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL)
@ -77,7 +78,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
return 0;
}
static int pthread_join(pthread_t thread, void* unused) {
static int pthread_join(pthread_t thread, void * unused) {
(void) unused;
return (int) WaitForSingleObject(thread, INFINITE);
}
@ -90,7 +91,7 @@ static int sched_yield (void) {
#include <pthread.h>
#include <stdatomic.h>
typedef void* thread_ret_t;
typedef void * thread_ret_t;
#include <sys/types.h>
#include <sys/stat.h>
@ -4723,7 +4724,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{
assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) {
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
}
} break;
case GGML_TYPE_F32:
@ -4775,7 +4776,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{
assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) {
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
}
} break;
case GGML_TYPE_F32:
@ -5035,11 +5036,15 @@ struct ggml_tensor * ggml_add_impl(
struct ggml_tensor * a,
struct ggml_tensor * b,
bool inplace) {
GGML_ASSERT(ggml_are_same_shape(a, b));
// TODO: support less-strict constraint
// GGML_ASSERT(ggml_can_repeat(b, a));
GGML_ASSERT(ggml_can_repeat_rows(b, a));
bool is_node = false;
if (a->grad || b->grad) {
if (!inplace && (a->grad || b->grad)) {
// TODO: support backward pass for broadcasting
GGML_ASSERT(ggml_are_same_shape(a, b));
is_node = true;
}
@ -8297,7 +8302,7 @@ static void ggml_compute_forward_add_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
@ -8322,23 +8327,23 @@ static void ggml_compute_forward_add_f32(
if (nb10 == sizeof(float)) {
for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
#ifdef GGML_USE_ACCELERATE
vDSP_vadd(
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
ne0);
vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
#else
ggml_vec_add_f32(ne0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
#endif
// }
// }
@ -8346,15 +8351,20 @@ static void ggml_compute_forward_add_f32(
} else {
// src1 is not contiguous
for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i0 = 0; i0 < ne0; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
}
@ -11717,7 +11727,7 @@ static void ggml_compute_forward_alibi_f32(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
//const int ne2 = src0->ne[2]; // n_head -> this is k
const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0);
@ -11728,8 +11738,9 @@ static void ggml_compute_forward_alibi_f32(
const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
assert(nb0 == sizeof(float));
assert(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(ne1 + n_past == ne0);
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@ -11753,7 +11764,7 @@ static void ggml_compute_forward_alibi_f32(
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}
pdst[0] = (i-ne0+1) * m_k + src[0];
pdst[0] = i * m_k + src[0];
}
}
@ -11782,7 +11793,7 @@ static void ggml_compute_forward_alibi_f16(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
//const int ne2 = src0->ne[2]; // n_head -> this is k
const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0);
@ -11793,8 +11804,9 @@ static void ggml_compute_forward_alibi_f16(
const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
assert(nb0 == sizeof(ggml_fp16_t));
assert(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@ -11819,7 +11831,7 @@ static void ggml_compute_forward_alibi_f16(
}
// we return F32
pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
}
}
}
@ -15944,6 +15956,9 @@ struct ggml_compute_state_shared {
// synchronization primitives
atomic_int n_active; // num active threads
atomic_int node_n; // active graph node
bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
void * abort_callback_data;
};
struct ggml_compute_state {
@ -15975,6 +15990,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
int node_n = -1;
while (true) {
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->node_n += 1;
return (thread_ret_t) GGML_EXIT_ABORTED;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning
// do finalize and init here so we don't have synchronize again
@ -16028,6 +16047,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} else {
break;
}
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
break;
}
}
atomic_store(&state->shared->n_active, n_threads);
@ -16061,7 +16084,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}
return 0;
return GGML_EXIT_SUCCESS;
}
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
@ -16401,7 +16424,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
return cplan;
}
void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);
@ -16427,6 +16450,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
/*.n_threads =*/ n_threads,
/*.n_active =*/ n_threads,
/*.node_n =*/ -1,
/*.abort_callback =*/ NULL,
/*.abort_callback_data =*/ NULL,
};
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
@ -16450,12 +16475,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
const int64_t perf_start_time_us = ggml_perf_time_us();
// this is a work thread too
ggml_graph_compute_thread(&workers[0]);
int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
// don't leave affinity set on the main thread
clear_numa_thread_affinity();
// join thread pool
// join or kill thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
@ -16479,6 +16504,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
(double) perf_time_us_cur / 1000.0,
(double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
}
return compute_status;
}
void ggml_graph_reset(struct ggml_cgraph * cgraph) {