tests : remove benchmarks

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-19 17:38:28 +03:00
parent 29f6ad8d95
commit 52945429eb
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -15,6 +15,9 @@
#include <thread>
#include <vector>
// TODO: remove before merging
//#define TMP_ATTN_BENCH
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
// static RNG initialization (revisit if n_threads stops being constant)
static const size_t n_threads = std::thread::hardware_concurrency();
@ -571,7 +574,7 @@ struct test_case {
// duplicate the op
size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
#if 0
#ifndef TMP_ATTN_BENCH
for (int i = 1; i < n_runs; i++) {
gf->nodes[gf->n_nodes++] = out;
}
@ -1513,8 +1516,8 @@ struct test_flash_attn_ext : public test_case {
}
};
#ifdef TMP_ATTN_BENCH
// ATTN
// TODO: this is temporary until the FA branch is merged
struct test_attn : public test_case {
const int64_t hs; // head size
const int64_t nh; // num heads
@ -1555,6 +1558,7 @@ struct test_attn : public test_case {
return cur;
}
};
#endif
enum llm_norm_type {
LLM_NORM,
@ -2220,7 +2224,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
#if 1
#ifdef TMP_ATTN_BENCH
for (int hs : { 128, 256, 64, 80, }) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, 2048, 4096, }) {
@ -2232,11 +2236,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
#else
for (int hs : { 128, }) {
for (int hs : { 64, 80, 128, 256, }) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, 512 }) {
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
for (int nb : { 1, 2, 4, 8, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
}
}