cuda: add flash attention + test
This commit is contained in:
parent 4f4bf35f46
commit f7bcfb0566

3 changed files with 526 additions and 1 deletion

ggml-cuda.cu (140 changed lines)
@@ -5987,6 +5987,88 @@ static __global__ void im2col_f32_f16(
    }
}

#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256

template<int block_size>
static __global__ void flash_attn_f32(const float* q, const float* k, const float* v, float* dst, float kq_scale,
                                      int d_head, int seq_len, int num_heads) {
    const int head      = blockIdx.x / seq_len;
    const int head_size = d_head * seq_len;
    const int s         = blockIdx.x % seq_len;
    const int tid       = threadIdx.x;

    extern __shared__ char work_data[];
    float* S = (float*)work_data; // scores for one query row; theoretical max sequence length 12848, limited by shared memory per block
    float* warp_data = (float*)(work_data + seq_len * sizeof(float));

    // QK^T
    for (int is = tid; is < seq_len; is += block_size) {
        S[is] = 0.0f;
        int key_offset   = is * d_head + head * head_size;
        int query_offset = s  * d_head + head * head_size;
        for (int d = 0; d < d_head; d++) {
            S[is] += k[key_offset + d] * q[query_offset + d];
        }
        S[is] *= kq_scale;
    }

    __syncthreads();

    float max_val = -INFINITY;
    // get the max
    for (int is = tid; is < seq_len; is += block_size) {
        max_val = fmaxf(max_val, S[is]);
    }

    max_val = warp_reduce_max(max_val);
    { // get max from all threads
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            warp_data[warp_id] = max_val;
        }
        __syncthreads();
        max_val = warp_data[lane_id];
        max_val = warp_reduce_max(max_val);
    }

    // softmax(QK^T)
    float sum = 0.0f;
    for (int is = tid; is < seq_len; is += block_size) {
        const float val = expf(S[is] - max_val);
        S[is] = val;
        sum += val;
    }

    sum = warp_reduce_sum(sum);
    { // sum partials
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            warp_data[warp_id] = sum;
        }
        __syncthreads();
        sum = warp_data[lane_id];
        sum = warp_reduce_sum(sum);
    }

    float inv_sum = 1.0f / sum;
    for (int is = tid; is < seq_len; is += block_size) {
        S[is] *= inv_sum;
    }

    __syncthreads();
    // softmax(QK^T)V
    for (int d = tid; d < d_head; d += block_size) {
        int dst_index    = d + s * d_head + head * head_size;
        int value_offset = d * seq_len + head * head_size;
        dst[dst_index] = 0.0f;
        for (int ic = 0; ic < seq_len; ic++) {
            dst[dst_index] += v[value_offset + ic] * S[ic];
        }
    }
}

template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                          const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
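For reference, the kernel above assigns one CUDA block per (head, query row) pair: blockIdx.x encodes both, the scores S live in dynamic shared memory, and warp_data holds per-warp partials for the max/sum reductions. A plain CPU sketch of the same computation (illustrative only, not part of this commit; attention_reference is a hypothetical name) looks like this, with q, k and dst laid out as [num_heads][seq_len][d_head] and v transposed per head as [num_heads][d_head][seq_len]:

#include <math.h>
#include <vector>

// Reference sketch: what flash_attn_f32 computes, one (head, query row) pair at a time.
static void attention_reference(const float* q, const float* k, const float* v, float* dst,
                                float kq_scale, int d_head, int seq_len, int num_heads) {
    std::vector<float> S(seq_len);
    for (int head = 0; head < num_heads; head++) {
        const int head_size = d_head * seq_len;
        for (int s = 0; s < seq_len; s++) {
            // scaled scores: S[is] = kq_scale * (q_s . k_is)
            float max_val = -INFINITY;
            for (int is = 0; is < seq_len; is++) {
                float dot = 0.0f;
                for (int d = 0; d < d_head; d++) {
                    dot += k[is * d_head + head * head_size + d] * q[s * d_head + head * head_size + d];
                }
                S[is]   = dot * kq_scale;
                max_val = fmaxf(max_val, S[is]);
            }
            // numerically stable softmax
            float sum = 0.0f;
            for (int is = 0; is < seq_len; is++) {
                S[is] = expf(S[is] - max_val);
                sum  += S[is];
            }
            // weighted sum over the (transposed) value rows
            for (int d = 0; d < d_head; d++) {
                float acc = 0.0f;
                for (int is = 0; is < seq_len; is++) {
                    acc += v[d * seq_len + head * head_size + is] * (S[is] / sum);
                }
                dst[d + s * d_head + head * head_size] = acc;
            }
        }
    }
}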
@@ -7377,6 +7459,13 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}

static void flash_attn_f32_cuda(const float* q, const float* k, const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
    int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
    int num_blocks = num_heads * seq_len;
    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
        q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
}

// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256
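flash_attn_f32_cuda launches one block per (head, query row) and requests seq_len + WARP_SIZE floats of dynamic shared memory, so the usable sequence length is bounded by the device's shared-memory-per-block limit. A pre-launch check along these lines would make that bound explicit (illustrative sketch, not part of the commit; flash_attn_smem_fits is a hypothetical helper):

#include <cuda_runtime.h>
#include <cstdio>

// Check that the dynamic shared memory requested by flash_attn_f32_cuda fits the device limit.
static bool flash_attn_smem_fits(int seq_len, int warp_size, int device_id) {
    int max_smem = 0;
    cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlock, device_id);
    size_t required = (size_t) seq_len * sizeof(float) + (size_t) warp_size * sizeof(float);
    if (required > (size_t) max_smem) {
        fprintf(stderr, "flash_attn_f32: seq_len=%d needs %zu bytes of shared memory, device allows %d\n",
                seq_len, required, max_smem);
        return false;
    }
    return true;
}

With the common 48 KB per-block limit this allows on the order of 12K score floats per row, which is the limit the kernel comment refers to.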
@@ -9900,6 +9989,51 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
    }
}

inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) {
    GGML_ASSERT(Q->type   == GGML_TYPE_F32);
    GGML_ASSERT(K->type   == GGML_TYPE_F32);
    GGML_ASSERT(V->type   == GGML_TYPE_F32);
    GGML_ASSERT(KQV->type == GGML_TYPE_F32);

    GGML_ASSERT(Q->backend   == GGML_BACKEND_GPU);
    GGML_ASSERT(K->backend   == GGML_BACKEND_GPU);
    GGML_ASSERT(V->backend   == GGML_BACKEND_GPU);
    GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);

    ggml_cuda_set_device(g_main_device);
    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
    ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;

    const int64_t d_head          = Q->ne[0];
    const int64_t sequence_length = Q->ne[1];
    const int64_t num_heads       = Q->ne[2];

    GGML_ASSERT(Q->ne[0] == d_head);
    GGML_ASSERT(K->ne[0] == d_head);
    GGML_ASSERT(V->ne[1] == d_head);

    GGML_ASSERT(Q->ne[1] == sequence_length);
    GGML_ASSERT(K->ne[1] == sequence_length);
    GGML_ASSERT(V->ne[0] == sequence_length);

    GGML_ASSERT(Q->ne[2] == num_heads);
    GGML_ASSERT(K->ne[2] == num_heads);
    GGML_ASSERT(V->ne[2] == num_heads);

    float KQ_scale = 1.0f / sqrtf((float)d_head);

    flash_attn_f32_cuda(
        (float *) src0_extra->data_device[g_main_device], // Query
        (float *) src1_extra->data_device[g_main_device], // Key
        (float *) src2_extra->data_device[g_main_device], // Value
        (float *) dst_extra->data_device[g_main_device],  // dst
        KQ_scale, d_head, sequence_length, num_heads, main_stream);
}

static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
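The asserts above pin down the expected layouts: Q and K are (d_head, seq_len, n_head), V is stored transposed per head as (seq_len, d_head, n_head), and all four tensors are F32 and resident on the GPU. A minimal shape sketch (make_qkv is a hypothetical helper; the sizes are taken from the test added below):

#include "ggml.h"

// Shape sketch only: Q/K are (d_head, seq_len, n_head), V is transposed per head.
static void make_qkv(struct ggml_context * ctx,
                     struct ggml_tensor ** Q, struct ggml_tensor ** K, struct ggml_tensor ** V) {
    const int d_head  = 3;
    const int seq_len = 4;
    const int n_head  = 2;
    *Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head,  seq_len, n_head);
    *K = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head,  seq_len, n_head);
    *V = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, seq_len, d_head,  n_head);
}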
@@ -10168,6 +10302,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
        case GGML_OP_ARGSORT:
            func = ggml_cuda_argsort;
            break;
        case GGML_OP_FLASH_ATTN:
            break;
        default:
            return false;
    }
@@ -10182,7 +10318,11 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return true;
    }
    if (tensor->op == GGML_OP_FLASH_ATTN) {
        ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
    } else {
        func(tensor->src[0], tensor->src[1], tensor);
    }
    return true;
}
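GGML_OP_FLASH_ATTN is special-cased in the dispatch above because the op carries three sources, while the generic func path only forwards two. A node that takes this path is created graph-side with ggml_flash_attn, exactly as the new test does; a minimal sketch (build_flash_attn_node is a hypothetical wrapper name):

#include "ggml.h"

// Build a graph node that reaches the GGML_OP_FLASH_ATTN branch above.
static struct ggml_tensor * build_flash_attn_node(struct ggml_context * ctx,
                                                  struct ggml_tensor * q,
                                                  struct ggml_tensor * k,
                                                  struct ggml_tensor * v) {
    // result->op == GGML_OP_FLASH_ATTN and result->src[0..2] == {q, k, v},
    // which is what ggml_cuda_compute_forward passes to ggml_cuda_flash_attn.
    return ggml_flash_attn(ctx, q, k, v, /*masked=*/false);
}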
tests/CMakeLists.txt

@@ -52,6 +52,8 @@ llama_build_and_test_executable(test-backend-ops.cpp)

llama_build_and_test_executable(test-rope.cpp)

llama_build_executable(test-flash-attention.cpp)

# dummy executable - not installed
get_filename_component(TEST_TARGET test-c.c NAME_WE)
add_executable(${TEST_TARGET} test-c.c)
tests/test-flash-attention.cpp (new file, 383 lines)

@@ -0,0 +1,383 @@
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#define GGML_USE_CUBLAS

#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#endif

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>

struct test_model {
    struct ggml_tensor * q;
    struct ggml_tensor * k;
    struct ggml_tensor * v;
    ggml_backend_t backend = NULL;
    ggml_backend_buffer_t buffer;
    struct ggml_context * ctx;
};

static std::vector<float> tensor_to_float(const ggml_tensor * t) {
    std::vector<float> tv;
    tv.reserve(ggml_nelements(t));

    std::vector<uint8_t> buf(ggml_nbytes(t));
    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));

    ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
    size_t bs = ggml_blck_size(t->type);
    std::vector<float> vq(ggml_blck_size(t->type));
    bool quantized = ggml_is_quantized(t->type);

    // access elements by index to avoid gaps in views
    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
            for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
                    if (t->type == GGML_TYPE_F16) {
                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
                    } else if (t->type == GGML_TYPE_F32) {
                        tv.push_back(*(float *) &buf[i]);
                    } else if (t->type == GGML_TYPE_I32) {
                        tv.push_back((float)*(int32_t *) &buf[i]);
                    } else if (t->type == GGML_TYPE_I16) {
                        tv.push_back((float)*(int16_t *) &buf[i]);
                    } else if (t->type == GGML_TYPE_I8) {
                        tv.push_back((float)*(int8_t *) &buf[i]);
                    } else if (quantized) {
                        std::vector<float> vq(ggml_blck_size(t->type));
                        tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
                        tv.insert(tv.end(), vq.begin(), vq.end());
                    } else {
                        GGML_ASSERT(false);
                    }
                }
            }
        }
    }

    return tv;
}

// accept FLT_MAX as infinity
static bool isinf_or_max(float f) {
    return std::isinf(f) || f == FLT_MAX || f == -FLT_MAX;
}

// normalized mean squared error = mse(a, b) / mse(a, 0)
static double nmse(const float * a, const float * b, size_t n) {
    double mse_a_b = 0.0;
    double mse_a_0 = 0.0;

    for (size_t i = 0; i < n; i++) {
        float a_i = a[i];
        float b_i = b[i];

        mse_a_b += (a_i - b_i) * (a_i - b_i);
        mse_a_0 += a_i * a_i;
    }

    return mse_a_b / mse_a_0;
}

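nmse() normalizes the squared error by the energy of the reference values, so the threshold used later in the backend comparison is scale-invariant. A small illustrative check (the values here are chosen only for the example):

// Illustrative check of nmse(): identical arrays give 0, a small perturbation gives a small ratio.
static void nmse_example() {
    const float a[3] = {1.0f, 2.0f, 3.0f};
    const float b[3] = {1.0f, 2.0f, 3.1f};
    const double err = nmse(a, b, 3);   // (0.1 * 0.1) / (1 + 4 + 9) ≈ 7.1e-4
    (void) err;
}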
void ggml_tensor_set_f32(struct ggml_tensor* tensor, float value, int l, int k = 0, int j = 0, int i = 0) {
    GGML_ASSERT(tensor->nb[0] == sizeof(float));
    *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]) = value;
}

float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) {
    GGML_ASSERT(tensor->nb[0] == sizeof(float));
    return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
}

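These helpers index through the tensor's byte strides nb[], so they work for any contiguous F32 tensor regardless of shape. A worked example for the [3, 4, 2] tensors used in this test (assuming a contiguous GGML_TYPE_F32 tensor, where ggml sets nb[i+1] = nb[i] * ne[i]):

// Worked example (assumes a contiguous GGML_TYPE_F32 tensor t with ne = {3, 4, 2, 1}):
//   nb[0] = 4 (sizeof(float)),  nb[1] = 12,  nb[2] = 48,  nb[3] = 96
// so element (l=1, k=2, j=1, i=0) lives at byte offset 1*4 + 2*12 + 1*48 = 76.
static float example_read(const ggml_tensor * t) {
    return ggml_tensor_get_f32(t, /*l=*/1, /*k=*/2, /*j=*/1);
}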
void load_model(test_model & model, bool use_gpu = false) {
    float Query[30] = { // [3, 4, 2]
        // z0
        2, 4, 2,
        4, 2, 1,
        4, 1, 3,
        4, 2, 2,

        // z1
        2, 1, 1,
        4, 2, 1,
        1, 1, 3,
        4, 2, 1
    };

    float Key[24] = { // [3, 4, 2]
        // z0
        2, 4, 2,
        4, 2, 1,
        4, 2, 3,
        1, 2, 1,

        // z1
        3, 1, 3,
        4, 2, 1,
        1, 1, 2,
        4, 3, 1
    };

    float Value[24] = { // [4, 3, 2]
        // z0
        2, 4, 2, 1,
        2, 1, 4, 2,
        1, 4, 2, 3,

        // z1
        1, 4, 2, 1,
        2, 1, 1, 2,
        1, 4, 3, 3,
    };

    size_t buffer_size = 0;
    {
        buffer_size += 30 * ggml_type_sizef(GGML_TYPE_F32); // tensor q
        buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor k
        buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor v
        buffer_size += 1024;
    }

    printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
    printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f));

    int num_tensors = 3;
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    // initialize the backend
#ifdef GGML_USE_CUBLAS
    if (use_gpu) {
        fprintf(stderr, "%s: using CUDA backend\n", __func__);
        model.backend = ggml_backend_cuda_init(0);
        if (!model.backend) {
            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
        }
    }
#endif

    if (!model.backend) {
        // fallback to CPU backend
        model.backend = ggml_backend_cpu_init();
    }

    model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);

    // create context
    model.ctx = ggml_init(params);

    // create tensors
    model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2);
    model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2);
    model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 4, 3, 2);

    // create an allocator
    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);

    // allocate memory
    ggml_allocr_alloc(alloc, model.q);
    ggml_allocr_alloc(alloc, model.k);
    ggml_allocr_alloc(alloc, model.v);

    ggml_backend_tensor_set(model.q, Query, 0, ggml_nbytes(model.q));
    ggml_backend_tensor_set(model.k, Key,   0, ggml_nbytes(model.k));
    ggml_backend_tensor_set(model.v, Value, 0, ggml_nbytes(model.v));

    ggml_allocr_free(alloc);
}

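The flat Query and Key arrays above follow the same layout the CUDA kernel assumes: the fastest-moving index runs over d_head, then the sequence position, then the head. A small indexing sketch (query_at is a hypothetical helper; the values are checked against the Query data above):

// For the [d_head=3, seq_len=4, n_head=2] Query tensor, element (d, s, h) sits at
// flat index d + s*d_head + h*d_head*seq_len in the Query array.
static float query_at(const float * Query, int d, int s, int h) {
    const int d_head = 3, seq_len = 4;
    return Query[d + s * d_head + h * d_head * seq_len];
}
// e.g. query_at(Query, 0, 2, 1) == 1 and query_at(Query, 2, 2, 1) == 3 (head 1, row 2 is {1, 1, 3}).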
struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) {
    static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
    static std::vector<uint8_t> buf(buf_size);

    struct ggml_init_params params0 = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf.data(),
        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
    };

    // create a temporary context to build the graph
    struct ggml_context * ctx0 = ggml_init(params0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx0);

    struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
    ggml_build_forward_expand(gf, result);

    // delete the temporary context used to build the graph
    ggml_free(ctx0);
    return gf;
}

struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backend_cpu, struct ggml_allocr * allocr, bool compare_backends) {
    // reset the allocator to free all the memory allocated during the previous inference
    ggml_allocr_reset(allocr);

    struct ggml_cgraph * gf = build_graph(model, allocr);

    // allocate tensors
    ggml_allocr_alloc_graph(allocr, gf);
    int n_threads = 1;

    if (ggml_backend_is_cpu(model.backend)) {
        ggml_backend_cpu_set_n_threads(model.backend, n_threads);
    }

    if (!compare_backends) {
        ggml_backend_graph_compute(model.backend, gf);

        // in this case, the output tensor is the last one in the graph
        return gf->nodes[gf->n_nodes - 1];
    }

    struct callback_userdata {
        bool   ok;
        double max_err;
        ggml_backend_t backend1;
        ggml_backend_t backend2;
    };

    callback_userdata ud {
        true,
        1e-7,
        model.backend,
        backend_cpu
    };

    auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
        callback_userdata * ud = (callback_userdata *) user_data;
        const char * bn1 = ggml_backend_name(ud->backend1);
        const char * bn2 = ggml_backend_name(ud->backend2);

        if (t1->op == GGML_OP_NONE) {
            // sentinels must be unchanged
            std::vector<uint8_t> t1_data(ggml_nbytes(t1));
            std::vector<uint8_t> t2_data(ggml_nbytes(t2));
            ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
            ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));

            if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
                printf("sentinel mismatch: %s ", t1->name);
                ud->ok = false;
                return true;
            }
        }

        std::vector<float> f1 = tensor_to_float(t1);
        std::vector<float> f2 = tensor_to_float(t2);

        for (size_t i = 0; i < f1.size(); i++) {
            // check for nans
            if (std::isnan(f1[i]) || std::isnan(f2[i])) {
                printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]);
                ud->ok = false;
                return true;
            }
            // check for infs: both must be inf of the same sign, or both must be finite
            if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                    if (std::signbit(f1[i]) != std::signbit(f2[i])) {
                        printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
                        ud->ok = false;
                        return true;
                    }
                } else {
                    printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
                    ud->ok = false;
                    return true;
                }
            }
        }

        double err = nmse(f1.data(), f2.data(), f1.size());
        if (err > ud->max_err) {
            printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
            ud->ok = false;
        }

        return true;

        GGML_UNUSED(index);
    };

    printf("\nTesting Flash Attention - comparing backends: ");

    const bool cmp_ok = ggml_backend_compare_graph_backend(model.backend, backend_cpu, gf, callback, &ud);
    if (ud.ok && cmp_ok) {
        printf("\033[1;32mOK\033[0m\n");
        return NULL;
    }

    printf("\033[1;31mFAIL\033[0m\n");
    return NULL;
}

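compute_graph either runs the graph once on the model's backend and returns the output node, or, when compare_backends is set, replays every node on both the CUDA and CPU backends and checks them with the NMSE callback above. A minimal usage sketch of the first mode (run_once is a hypothetical helper; it assumes model, backend_cpu and allocr are set up as in main below):

// Run the graph once and read the result back to host memory.
static void run_once(const test_model & model, ggml_backend_t backend_cpu, struct ggml_allocr * allocr) {
    struct ggml_tensor * out = compute_graph(model, backend_cpu, allocr, /*compare_backends=*/false);
    std::vector<float> out_data(ggml_nelements(out));
    ggml_backend_tensor_get(out, out_data.data(), 0, ggml_nbytes(out));
    printf("first value: %f\n", out_data.empty() ? 0.0f : out_data[0]);
}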
int main(int argc, char ** argv)
{
    bool compare_backend = false;
    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "comp") == 0) {
            compare_backend = true;
        }
    }

    ggml_time_init();

    test_model model;
    load_model(model, true);

    ggml_backend_buffer_t buf_compute; // for compute
    struct ggml_allocr * allocr = NULL;

    {
        allocr = ggml_allocr_new_measure_from_backend(model.backend);

        // create the worst case graph for memory usage estimation
        struct ggml_cgraph * gf = build_graph(model, allocr);
        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
        ggml_allocr_free(allocr);

        // compute the required memory
        buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
        allocr = ggml_allocr_new_from_buffer(buf_compute);
        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
    }

    ggml_backend_t backend_cpu = ggml_backend_cpu_init();

    struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend);
    if (!compare_backend) {
        float* data = new float[ggml_nelements(result)];

        ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result));
        printf("\nPerforming test:\n");

        for (int i = 0; i < ggml_nelements(result); i++) {
            if (i > 0 && (i % result->ne[0] == 0)) {
                printf("\n");
            }
            printf("%2.6f ", data[i]);
        }
    }

    ggml_free(model.ctx);

    ggml_backend_buffer_free(model.buffer);
    ggml_backend_buffer_free(buf_compute);
    ggml_backend_free(model.backend);
    return 0;
}