apply suggestions

This commit is contained in:
FSSRepo 2024-01-19 20:18:18 -05:00
parent 09db1a7cf3
commit fded2e6a11
2 changed files with 83 additions and 34 deletions

View file

@ -23,8 +23,9 @@ struct test_model {
struct ggml_tensor * k;
struct ggml_tensor * v;
ggml_backend_t backend = NULL;
ggml_backend_buffer_t buffer;
struct ggml_context * ctx;
ggml_backend_buffer_t buffer = NULL;
struct ggml_context * ctx = NULL;
bool naive_attn = false;
};
static std::vector<float> tensor_to_float(const ggml_tensor * t) {
@ -216,8 +217,16 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
ggml_build_forward_expand(gf, result);
if(!model.naive_attn) {
struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
ggml_build_forward_expand(gf, result);
} else {
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
kq = ggml_soft_max(ctx0, kq);
kq = ggml_mul_mat(ctx0, model.v, kq);
ggml_build_forward_expand(gf, kq);
}
// delete the temporally context used to build the graph
ggml_free(ctx0);
@ -330,15 +339,18 @@ struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backe
int main(int argc, char ** argv)
{
bool compare_backend = false;
test_model model;
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "comp") == 0) {
compare_backend = true;
} else if (strcmp(argv[i], "naive") == 0) {
model.naive_attn = true;
}
}
ggml_time_init();
test_model model;
load_model(model, true);
ggml_backend_buffer_t buf_compute; // for compute
@ -359,9 +371,11 @@ int main(int argc, char ** argv)
}
ggml_backend_t backend_cpu = ggml_backend_cpu_init();
uint64_t compute_time_us__ = ggml_time_us();
struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend);
if(!compare_backend) {
ggml_backend_synchronize(model.backend);
printf("computing time: %.4f ms\n", (ggml_time_us() - compute_time_us__) / 1000.0);
float* data = new float[ggml_nelements(result)];
ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result));