tune: extract ggml_mulmat_tune_bench_wrapper

This commit is contained in:
mqy 2023-06-19 13:54:20 +08:00
parent 65fd65e0c1
commit 44b831dc59
3 changed files with 54 additions and 51 deletions

View file

@ -935,3 +935,48 @@ bool ggml_mulmat_tune_bench(struct ggml_mulmat_tune *tune,
return true; return true;
} }
// Wrapper intended to be called by llama, etc.
// Three modes: bench and run; bench (save) then exit; load and run.
// Returns true on success, false on any failure (no BLAS, I/O error,
// bad tune data, or validation mismatch).
bool ggml_mulmat_tune_bench_wrapper(struct ggml_mulmat_tune *mulmat_tune,
                                    struct ggml_mulmat_tune_params *params,
                                    bool run_bench) {
    printf("\n");

    // An empty or missing fname means "bench only, nothing to load/save".
    bool empty_fname = !params->fname || strcmp(params->fname, "") == 0;

    if (!ggml_cpu_has_blas()) {
        fprintf(stderr, "[tune] this program is not built with BLAS, abort.\n");
        // BUG FIX: this function returns bool; the previous `return 1;`
        // evaluated to `true`, reporting success to callers despite the
        // abort. The pre-refactor code returned false here.
        return false;
    }

    if (run_bench) {
        return ggml_mulmat_tune_bench(mulmat_tune, params);
    }

    if (!empty_fname) {
        FILE *fp = fopen(params->fname, "r");
        if (!fp) {
            fprintf(stderr, "[tune] failed to open file %s.\n", params->fname);
            return false;
        }

        int rc = ggml_mulmat_tune_read_data(mulmat_tune, fp);
        fclose(fp);
        if (rc != 0) {
            fprintf(stderr,
                    "[tune] failed to read data from %s, error code: %d\n",
                    params->fname, rc);
            return false;
        }
        fprintf(stderr, "[tune] loaded data from %s\n", params->fname);

        // Verify the loaded tune data matches the current model/ftype/threads.
        bool ok = ggml_mulmat_tune_validate(mulmat_tune, mulmat_tune->model,
                                            params->model.ftype,
                                            params->n_threads);
        if (!ok) {
            return false;
        }
    }
    return true;
}

View file

@ -132,6 +132,12 @@ void ggml_mulmat_tune_estimate_time(const struct ggml_mulmat_tune_shape *shape,
bool ggml_mulmat_tune_bench(struct ggml_mulmat_tune *tune, bool ggml_mulmat_tune_bench(struct ggml_mulmat_tune *tune,
struct ggml_mulmat_tune_params *params); struct ggml_mulmat_tune_params *params);
// This API is intended to be called by llama, etc.
// Three modes: bench and run; bench(save) then exit; load and run
bool ggml_mulmat_tune_bench_wrapper(struct ggml_mulmat_tune *mulmat_tune,
struct ggml_mulmat_tune_params *params,
bool run_bench);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View file

@ -2748,8 +2748,6 @@ bool llama_mulmat_tune(struct llama_context *ctx, int n_threads, bool tune,
const char *fname) { const char *fname) {
GGML_ASSERT(ctx->model.n_gpu_layers == 0); GGML_ASSERT(ctx->model.n_gpu_layers == 0);
printf("\n");
const char *model_name = llama_model_type_name(ctx->model.type); const char *model_name = llama_model_type_name(ctx->model.type);
llama_hparams *hparams = &ctx->model.hparams; llama_hparams *hparams = &ctx->model.hparams;
@ -2820,71 +2818,25 @@ bool llama_mulmat_tune(struct llama_context *ctx, int n_threads, bool tune,
/* .m_num =*/8, /* .m_num =*/8,
/* .n_pass =*/1, /* .n_pass =*/1,
/* .n_threads =*/n_threads, /* .n_threads =*/n_threads,
/* .prrogress =*/true, /* .progress =*/true,
/* .output_console =*/false, /* .output_console =*/false,
/* .fname =*/fname, /* .fname =*/fname,
}; };
bool empty_fname = !fname || strcmp(fname, "") == 0;
ctx->tune = new (struct ggml_mulmat_tune); ctx->tune = new (struct ggml_mulmat_tune);
if (!ctx->tune) { if (!ctx->tune) {
fprintf(stderr, "[tune] failed to allocate memory for tune\n"); fprintf(stderr, "[tune] failed to allocate memory for tune\n");
return false; return false;
} }
if (!ggml_cpu_has_blas()) { return ggml_mulmat_tune_bench_wrapper(ctx->tune, &params, tune);
fprintf(stderr, "[tune] this program is not built with BLAS, abort.\n");
return false;
}
if (tune) {
bool ok = ggml_mulmat_tune_bench(ctx->tune, &params);
if (!ok) {
ggml_mulmat_tune_free(ctx->tune);
return false;
}
if (!empty_fname) {
ggml_mulmat_tune_free(ctx->tune);
return true;
}
} else if (empty_fname) {
return false;
}
if (!empty_fname) {
FILE *fp = fopen(fname, "r");
if (!fp) {
fprintf(stderr, "[tune] failed to open file %s.\n", fname);
return false;
} else {
int rc = ggml_mulmat_tune_read_data(ctx->tune, fp);
fclose(fp);
if (rc != 0) {
fprintf(stderr,
"[tune] failed to read data from %s, error code: %d\n",
fname, rc);
return false;
}
fprintf(stderr, "[tune] loaded data from %s\n", fname);
bool ok = ggml_mulmat_tune_validate(ctx->tune, model_name, ggml_ftype,
params.n_threads);
if (!ok) {
return false;
}
}
}
return true;
} }
#endif #endif
void llama_free(struct llama_context * ctx) { void llama_free(struct llama_context * ctx) {
#ifdef GGML_USE_TUNE #ifdef GGML_USE_TUNE
if (ctx->tune) { if (ctx->tune) {
ggml_mulmat_tune_free(ctx->tune);
delete(ctx->tune); delete(ctx->tune);
} }
#endif #endif