Merge branch 'master' into compilade/mamba2

This commit is contained in:
Francis Couture-Harpin 2024-10-01 13:09:40 -04:00
commit 7d6cb36895
105 changed files with 8055 additions and 5231 deletions

View file

@ -1,6 +1,6 @@
// This file defines tests for various GGML ops and backends.
// For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
// For the backwards pass it asserts that the gradients from backpropagation are consistent
// For the backward pass it asserts that the gradients from backpropagation are consistent
// with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
// It is also possible to check the performance ("perf" mode).
//
@ -32,63 +32,52 @@
#include <stdlib.h>
#include <string>
#include <thread>
#include <future>
#include <vector>
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
// static RNG initialization (revisit if n_threads stops being constant)
static const size_t n_threads = std::thread::hardware_concurrency();
static std::vector<std::default_random_engine> generators = []() {
std::random_device rd;
std::vector<std::default_random_engine> vec;
vec.reserve(n_threads);
//for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
return vec;
}();
size_t nels = ggml_nelements(tensor);
std::vector<float> data(nels);
{
// parallel initialization
static const size_t n_threads = std::thread::hardware_concurrency();
// static RNG initialization (revisit if n_threads stops being constant)
static std::vector<std::default_random_engine> generators = []() {
std::random_device rd;
std::vector<std::default_random_engine> vec;
vec.reserve(n_threads);
//for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
return vec;
}();
size_t size = ggml_nelements(tensor);
std::vector<float> data(size);
auto init_thread = [&](size_t ith, size_t start, size_t end) {
std::uniform_real_distribution<float> distribution(min, max);
auto & gen = generators[ith];
for (size_t i = start; i < end; i++) {
data[i] = distribution(gen);
}
};
auto init_thread = [&](size_t ith, size_t start, size_t end) {
std::uniform_real_distribution<float> distribution(min, max);
for (size_t i = start; i < end; i++) {
data[i] = distribution(generators[ith]);
std::vector<std::future<void>> tasks;
tasks.reserve(n_threads);
for (size_t i = 0; i < n_threads; i++) {
size_t start = i*nels/n_threads;
size_t end = (i+1)*nels/n_threads;
tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
}
};
std::vector<std::thread> threads;
threads.reserve(n_threads);
for (size_t i = 0; i < n_threads; i++) {
size_t start = i*size/n_threads;
size_t end = (i+1)*size/n_threads;
threads.emplace_back(init_thread, i, start, end);
}
for (auto & t : threads) {
t.join();
}
#if 0
const char * val_str = getenv("GGML_TEST_EPS");
float val = 1e-9f;
if (val_str != nullptr) {
val = std::stof(val_str);
printf("GGML_TEST_EPS=%e\n", val);
}
// test quantization with very small values that may result in nan scales due to division by zero
if (ggml_is_quantized(tensor->type)) {
for (int i = 0; i < 256; i++) {
data[i] = val;
for (auto & t : tasks) {
t.get();
}
}
#endif
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float));
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0);
// dummy importance matrix
std::vector<float> imatrix(tensor->ne[0], 1.0f);
const float * im = imatrix.data();
if (!ggml_quantize_requires_imatrix(tensor->type)) {
// when the imatrix is optional, we want to test both quantization with and without imatrix
@ -98,15 +87,31 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
}
}
ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
// TODO: other cases
//#pragma omp parallel for
//for (int i = 0; i < tensor->ne[1]; i++) {
// ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
// i * tensor->ne[0], 1, tensor->ne[0], im);
//}
std::vector<uint8_t> dataq(ggml_row_size(tensor->type, nels));
{
// parallel quantization by block
size_t blck_size = ggml_blck_size(tensor->type);
size_t n_blocks = nels / blck_size;
auto quantize_thread = [&](size_t start, size_t end) {
ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
start * blck_size, end - start, blck_size, im);
};
const size_t min_blocks_per_thread = 1;
const size_t n_threads = std::min<size_t>(std::thread::hardware_concurrency()/2,
std::max<size_t>(1, n_blocks / min_blocks_per_thread));
std::vector<std::future<void>> tasks;
tasks.reserve(n_threads);
for (size_t i = 0; i < n_threads; i++) {
size_t start = i*n_blocks/n_threads;
size_t end = (i+1)*n_blocks/n_threads;
tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
}
for (auto & t : tasks) {
t.get();
}
}
ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
} else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
// This is going to create some weird integers though.
@ -160,60 +165,6 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
return tv;
}
/*
static double cosine_similarity(const float * v1, const float * v2, size_t n) {
double dot = 0.0;
double mag1 = 0.0;
double mag2 = 0.0;
for (size_t i = 0; i < n; i++) {
if (std::isnan(v1[i]) || std::isnan(v2[i])) {
return -1.0f;
}
if (std::isinf(v1[i]) && std::isinf(v2[i])) {
continue;
}
dot += v1[i]*v2[i];
mag1 += v1[i]*v1[i];
mag2 += v2[i]*v2[i];
}
return dot/sqrt(mag1*mag2);
}
static float distance(const float * v1, const float * v2, size_t n) {
double d = 0.0;
for (size_t i = 0; i < n; i++) {
if (std::isnan(v1[i]) || std::isnan(v2[i])) {
return INFINITY;
}
if (std::isinf(v1[i]) && std::isinf(v2[i])) {
continue;
}
d += (v1[i] - v2[i])*(v1[i] - v2[i]);
}
return sqrt(d);
}
static float vec_len(const float * v, size_t n) {
double d = 0.0;
for (size_t i = 0; i < n; i++) {
if (std::isnan(v[i])) {
return INFINITY;
}
if (std::isinf(v[i])) {
continue;
}
d += v[i]*v[i];
}
return sqrt(d);
}
*/
// normalized mean squared error = mse(a, b) / mse(a, 0)
static double nmse(const float * a, const float * b, size_t n) {
double mse_a_b = 0.0;
@ -264,7 +215,6 @@ static double mean_abs_asymm(const float * a, const float * b, const size_t n, c
}
// utils for printing the variables of the test cases
#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
template<typename T>
static std::string var_to_str(const T & x) {
@ -297,10 +247,6 @@ static std::string var_to_str(const std::array<T, N> & x) {
return s;
}
//static std::string var_to_str(ggml_unary_op unary_op) {
// return ggml_unary_op_name(unary_op);
//}
static std::string var_to_str(ggml_type type) {
return ggml_type_name(type);
}
@ -313,6 +259,8 @@ static std::string var_to_str(ggml_op_pool pool) {
}
}
#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
#define VARS_TO_STR1(a) VAR_TO_STR(a)
#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
@ -370,13 +318,13 @@ struct test_case {
return 1e-4;
}
virtual float grad_eps(){
virtual float grad_eps() {
return 1e-1f;
}
// If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.
// If true, estimate gradient with 4 points, neglects 5th order derivative and higher.
virtual bool grad_precise(){
virtual bool grad_precise() {
return false;
}
@ -409,6 +357,11 @@ struct test_case {
return size;
}
virtual uint64_t op_flops(ggml_tensor * t) {
GGML_UNUSED(t);
return 0;
}
ggml_cgraph * gf = nullptr;
ggml_cgraph * gb = nullptr;
@ -651,12 +604,11 @@ struct test_case {
}
// align while also leaving some margin for variations in parameters
int align = 20;
int align = 8;
int last = (len + align - 1) / align * align;
if (last - len < 5) {
last += align;
}
last = std::max(last, 60);
printf("%*s", last - len, "");
// allocate
@ -677,9 +629,25 @@ struct test_case {
// warmup run
ggml_backend_graph_compute(backend, gf);
// determine number of runs
int n_runs;
if (op_flops(out) > 0) {
// based on flops
const uint64_t GFLOP = 1000 * 1000 * 1000;
const uint64_t target_flops_cpu = 8ULL * GFLOP;
const uint64_t target_flops_gpu = 100ULL * GFLOP;
uint64_t target_flops = ggml_backend_is_cpu(backend) ? target_flops_cpu : target_flops_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
} else {
// based on memory size
const size_t GB = 1ULL << 30;
const size_t target_size_cpu = 8 * GB;
const size_t target_size_gpu = 32 * GB;
size_t target_size = ggml_backend_is_cpu(backend) ? target_size_cpu : target_size_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
}
// duplicate the op
size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
int n_runs = std::min((size_t) ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
for (int i = 1; i < n_runs; i++) {
ggml_graph_add_node(gf, out);
}
@ -706,17 +674,47 @@ struct test_case {
// run
ggml_backend_synchronize(backend);
int64_t start_time = ggml_time_us();
ggml_backend_graph_compute(backend, gf);
ggml_backend_synchronize(backend);
int64_t end_time = ggml_time_us();
double time_us = end_time - start_time;
int64_t total_time_us = 0;
int total_runs = 0;
do {
int64_t start_time = ggml_time_us();
ggml_backend_graph_compute(backend, gf);
ggml_backend_synchronize(backend);
int64_t end_time = ggml_time_us();
printf(" %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n",
n_runs,
time_us / n_runs,
op_size(out) / 1024,
mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0);
total_time_us += end_time - start_time;
total_runs += n_runs;
} while (total_time_us < 1000*1000); // run for at least 1 second
printf(" %8d runs - %8.2f us/run - ",
total_runs,
(double)total_time_us / total_runs);
if (op_flops(out) > 0) {
double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6);
auto format_flops = [](double flops) -> std::string {
char buf[256];
if (flops >= 1e12) {
snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12);
} else if (flops >= 1e9) {
snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9);
} else if (flops >= 1e6) {
snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6);
} else {
snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3);
}
return buf;
};
printf("%s/run - \033[1;34m%sS\033[0m",
format_flops(op_flops(out)).c_str(),
format_flops(flops_per_sec).c_str());
} else {
printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
op_size(out) / 1024,
mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
}
printf("\n");
ggml_backend_buffer_free(buf);
@ -742,7 +740,7 @@ struct test_case {
ggml_tensor * out = build_graph(ctx);
if (op_name != nullptr && op_desc(out) != op_name) {
if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
//printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
@ -751,11 +749,6 @@ struct test_case {
printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout);
if (out->grad == nullptr) {
printf("backwards pass not supported \n");
ggml_free(ctx);
return true;
}
if (out->type != GGML_TYPE_F32) {
ggml_free(ctx);
printf("not supported [%s->type != FP32]\n", out->name);
@ -764,18 +757,26 @@ struct test_case {
// check if the backend supports the ops
bool supported = true;
bool any_params = false;
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (!ggml_backend_supports_op(backend, t)) {
printf("not supported [%s] ", ggml_backend_name(backend));
supported = false;
break;
}
if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
printf("not supported [%s->type != FP32] ", t->name);
supported = false;
break;
if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
any_params = true;
if (t->type != GGML_TYPE_F32) {
printf("not supported [%s->type != FP32] ", t->name);
supported = false;
break;
}
}
}
if (!any_params) {
printf("not supported [%s] \n", op_name);
supported = false;
}
if (!supported) {
printf("\n");
ggml_free(ctx);
@ -799,6 +800,7 @@ struct test_case {
out = ggml_sum(ctx, out);
ggml_set_name(out, "sum_of_out");
}
ggml_set_loss(out);
ggml_build_forward_expand(gf, out);
ggml_graph_cpy(gf, gb);
@ -837,22 +839,11 @@ struct test_case {
return false;
}
// randomize tensors
initialize_tensors(ctx);
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
if (!t->grad) {
continue;
}
initialize_tensors(ctx); // Randomizes all tensors (including gradients).
ggml_graph_reset(gb); // Sets gradients to 1 if loss, 0 otherwise.
std::vector<float> tmp(ggml_nelements(t->grad));
ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad));
}
// build graphs
const float onef = 1.0f;
ggml_backend_graph_compute(backend, gf);
ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad));
ggml_backend_graph_compute(backend, gb);
bool ok = true;
@ -996,7 +987,7 @@ struct test_example : public test_case {
}
// In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
// immediately after you create the tensors.
// This is optional and only makes sense if a backwards pass has actually been implemented for the new op.
// This is optional and only makes sense if a backward pass has actually been implemented for the new op.
};
@ -1235,7 +1226,7 @@ struct test_set : public test_case {
offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
}
ggml_tensor * out = ggml_set(ctx, dst, src,
// The backwards pass requires setting a contiguous region:
// The backward pass requires setting a contiguous region:
src->nb[1], src->nb[2], src->nb[3], offset);
ggml_set_name(out, "out");
@ -1347,7 +1338,7 @@ struct test_bin_bcast : public test_case {
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
// The backwards pass supports broadcasting only for GGML_ADD:
// The backward pass supports broadcasting only for GGML_ADD:
const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
if (grad_supported) {
ggml_set_param(ctx, a);
@ -1584,6 +1575,36 @@ struct test_ssm_scan : public test_case {
}
};
// GGML_OP_RWKV_WKV
struct test_rwkv_wkv : public test_case {
const ggml_type type;
const int64_t head_count;
const int64_t head_size;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
}
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
return out;
}
};
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
@ -1602,13 +1623,9 @@ struct test_mul_mat : public test_case {
return 5e-4;
}
size_t op_size(ggml_tensor * t) override {
size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1];
size_t b = ggml_nbytes(t->src[1]) * m;
size_t c = ggml_nbytes(t);
return a + b + c;
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
}
test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
@ -1652,13 +1669,9 @@ struct test_mul_mat_id : public test_case {
return 5e-4;
}
size_t op_size(ggml_tensor * t) override {
size_t a = ggml_nbytes(t->src[2]) * n;
size_t b = ggml_nbytes(t->src[1]) * m;
size_t c = ggml_nbytes(t);
return a + b + c;
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2 * m * k * n * n_used;
}
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
@ -1712,6 +1725,50 @@ struct test_mul_mat_id : public test_case {
}
};
// GGML_OP_OUT_PROD
struct test_out_prod : public test_case {
const ggml_type type_a;
const ggml_type type_b;
const int64_t m;
const int64_t n;
const int64_t k;
const std::array<int64_t, 2> bs; // dims 3 and 4
const bool trans_b;
std::string vars() override {
return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b);
}
double max_nmse_err() override {
return 5e-4;
}
test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10},
bool trans_b = false)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
ggml_set_name(a, "a");
ggml_tensor * b;
if (trans_b) {
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]);
b = ggml_transpose(ctx, b);
} else {
b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]);
}
ggml_set_name(b, "b");
ggml_tensor * out = ggml_out_prod(ctx, a, b);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_SQR
struct test_sqr : public test_case {
const ggml_type type;
@ -1807,7 +1864,7 @@ struct test_log : public test_case {
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
// log(1) == 0, cluster values there to keep the sum low for better precision in the backwards pass:
// log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass:
init_tensor_uniform(t, 0.9f, 1.1f);
}
}
@ -2697,6 +2754,54 @@ struct test_cross_entropy_loss : public test_case {
}
};
// GGML_OP_OPT_STEP_ADAMW
struct test_opt_step_adamw : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float alpha;
const float beta1;
const float beta2;
const float eps;
const float wd;
std::string vars() override {
return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd);
}
test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3},
float alpha = 1e-3f,
float beta1 = 0.9f,
float beta2 = 0.999f,
float eps = 1e-8f,
float wd = 0.0f)
: type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
ggml_set_name(a, "a");
ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_name(grad, "grad");
ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values.
}
}
bool grad_precise() override {
return true;
}
};
enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
@ -3085,47 +3190,46 @@ struct test_falcon : public test_llm {
// ###########################################
// ## Section 3: GGML Op Test Instantiation ##
// ###########################################
static const ggml_type all_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
};
static const ggml_type base_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_K,
GGML_TYPE_IQ2_XXS
};
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
static const ggml_type other_types[] = {
GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
GGML_TYPE_BF16,
};
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
std::vector<std::unique_ptr<test_case>> test_cases;
std::default_random_engine rng(0);
const ggml_type all_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
};
const ggml_type base_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_K,
GGML_TYPE_IQ2_XXS
};
const ggml_type other_types[] = {
GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
GGML_TYPE_BF16,
};
// unary ops
for (int v : {0, 1}) {
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
@ -3190,14 +3294,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, 3}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, 3}, {1, 1, 1, 2}));
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
}
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
@ -3289,6 +3394,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@ -3309,6 +3419,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
}
}
for (ggml_type type_a : other_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
if (ggml_blck_size(type_a) != 256) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1}));
}
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
}
}
#else
// m = a rows
// n = b rows
@ -3328,15 +3446,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
#endif
for (ggml_type type_a : other_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
if (ggml_blck_size(type_a) != 256) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1, 1}, {1, 1}));
}
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
}
}
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1}));
@ -3382,6 +3491,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}, true));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
}
}
test_cases.emplace_back(new test_sqr());
test_cases.emplace_back(new test_sqrt());
test_cases.emplace_back(new test_log());
@ -3495,7 +3625,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
if (hs != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
}
@ -3508,6 +3638,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
test_cases.emplace_back(new test_cross_entropy_loss());
for (float wd : {0.0f, 1e-2f}) {
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd));
}
// these tests are disabled to save execution time, but they can be handy for debugging
#if 0
@ -3517,20 +3650,30 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_falcon(2));
#endif
// run tests
if (mode == MODE_GRAD) {
size_t n_ok = 0;
for (auto & test : test_cases) {
if (test->eval_grad(backend, op_name)) {
n_ok++;
return test_cases;
}
// Test cases for performance evaluation: should be representative of real-world use cases
static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
std::vector<std::unique_ptr<test_case>> test_cases;
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
for (int bs : {1, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1, 1}, {1, 1}));
}
}
printf(" %zu/%zu tests passed\n", n_ok, test_cases.size());
return n_ok == test_cases.size();
}
return test_cases;
}
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
if (mode == MODE_TEST) {
auto test_cases = make_test_cases_eval();
ggml_backend_t backend_cpu = ggml_backend_cpu_init();
size_t n_ok = 0;
@ -3546,7 +3689,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
return n_ok == test_cases.size();
}
if (mode == MODE_GRAD) {
auto test_cases = make_test_cases_eval();
size_t n_ok = 0;
for (auto & test : test_cases) {
if (test->eval_grad(backend, op_name)) {
n_ok++;
}
}
printf(" %zu/%zu tests passed\n", n_ok, test_cases.size());
return n_ok == test_cases.size();
}
if (mode == MODE_PERF) {
auto test_cases = make_test_cases_perf();
for (auto & test : test_cases) {
test->eval_perf(backend, op_name);
}
@ -3560,9 +3717,9 @@ static void usage(char ** argv) {
printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
printf(" valid modes:\n");
printf(" - test (default, compare with CPU backend for correctness)\n");
printf(" - perf (performance evaluation)\n");
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
printf(" op names are as given by ggml_op_desc() (e.g. GGML_ADD)\n");
printf(" - perf (performance evaluation)\n");
printf(" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
}
int main(int argc, char ** argv) {
@ -3621,6 +3778,11 @@ int main(int argc, char ** argv) {
continue;
}
if (ggml_backend_is_cpu(backend)) {
// TODO: better value for n_threads
ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
}
printf(" Backend name: %s\n", ggml_backend_name(backend));
bool ok = test_backend(backend, mode, op_name_filter);