Reenable BLAS for quantized mul_mat

This commit is contained in:
Georgi Gerganov 2023-03-24 22:03:56 +02:00
parent ea60d2193a
commit 3634c312bc
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 13 additions and 11 deletions

10
ggml.c
View file

@@ -5858,11 +5858,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 if (ggml_is_contiguous(src0) &&
 ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
-// disable BLAS for Q4_0 and Q4_1
-// looks like there is no benefit and we only waste a lot of memory
-if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) {
-return false;
-}
+//// disable BLAS for Q4_0 and Q4_1
+//// looks like there is no benefit and we only waste a lot of memory
+//if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) {
+// return false;
+//}
 //printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);
 return true;

14
llama.cpp
View file

@@ -44,7 +44,7 @@ enum e_model {
 static const size_t MB = 1024*1024;
 // computed for n_ctx == 2048
-// TODO: dynamically determine thess sizes
+// TODO: dynamically determine these sizes
 // needs modifications in ggml
 static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
@@ -69,11 +69,13 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
 { MODEL_65B, 5120ull*MB },
 };
+// this is mostly needed for temporary mul_mat buffers to dequantize the data
+// not actually needed if BLAS is disabled
 static const std::map<e_model, size_t> MEM_REQ_EVAL = {
-{ MODEL_7B, 128ull*MB },
-{ MODEL_13B, 128ull*MB },
-{ MODEL_30B, 128ull*MB },
-{ MODEL_65B, 128ull*MB },
+{ MODEL_7B, 768ull*MB },
+{ MODEL_13B, 1024ull*MB },
+{ MODEL_30B, 1280ull*MB },
+{ MODEL_65B, 1536ull*MB },
 };
 // default hparams (LLaMA 7B)
@@ -1034,7 +1036,7 @@ static bool llama_eval_internal(
 }
 #if 0
-printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB, %.3f MB %.3f MB %.3f %.3f %.3f MB\n", __func__,
+printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
 ggml_used_mem(ctx0)/1024.0/1024.0,
 lctx.get_buf_max_mem(0)/1024.0/1024.0,
 lctx.get_buf_max_mem(1)/1024.0/1024.0);