partially working, but the blas matmul is broken

Concedo 2023-05-13 11:35:38 +08:00
parent b335f73a60
commit 05cf5f7d6e
8 changed files with 53 additions and 21 deletions

ggml.c (32 lines changed)

@@ -1429,6 +1429,20 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
     return quantize_fns[i];
 }
 
+bool quants_unshuffled = false; //new GGJT_2 is unshuffled, all old ones are shuffled
+static const quantize_fns_t quantize_fns_v2[GGML_TYPE_COUNT]; //forward decl
+static inline quantize_fns_t get_quantize_fn(size_t i)
+{
+    if(quants_unshuffled)
+    {
+        return quantize_fns[i];
+    }
+    else
+    {
+        return quantize_fns_v2[i];
+    }
+}
+
 //
 // simd mappings

@@ -5637,7 +5651,7 @@ static void ggml_compute_forward_dup_f16(
                }
            }
        } else if (ggml_is_quantized(dst->type)) {
-           quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+           quantize_row_q_t const quantize_row_q = get_quantize_fn(dst->type).quantize_row_q;
            float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
            size_t id = 0;

@@ -5936,7 +5950,7 @@ static void ggml_compute_forward_dup_f32(
                }
            }
        } else if (ggml_is_quantized(dst->type)) {
-           quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+           quantize_row_q_t const quantize_row_q = get_quantize_fn(dst->type).quantize_row_q;
 
            size_t id = 0;
            size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);

@@ -6346,8 +6360,8 @@ static void ggml_compute_forward_add_q_f32(
     GGML_ASSERT(ne3 == ne13);
 
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+    dequantize_row_q_t const dequantize_row_q = get_quantize_fn(type).dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = get_quantize_fn(type).quantize_row_q;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);

@@ -7809,9 +7823,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3 == ne13);
 
     const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
-    vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
-    enum ggml_type   const vec_dot_type       = quantize_fns[type].vec_dot_type;
+    quantize_row_q_t const quantize_row_q_dot = get_quantize_fn(type).quantize_row_q_dot;
+    vec_dot_q_t      const vec_dot_q          = get_quantize_fn(type).vec_dot_q;
+    enum ggml_type   const vec_dot_type       = get_quantize_fn(type).vec_dot_type;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);

@@ -8138,7 +8152,7 @@ static void ggml_compute_forward_get_rows_q(
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    dequantize_row_q_t const dequantize_row_q = get_quantize_fn(type).dequantize_row_q;
 
     assert( dst->ne[0] == nc);
     assert( dst->ne[1] == nr);

@@ -10923,7 +10937,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                    } else
 #endif
                    {
-                       const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
+                       const enum ggml_type type_q = get_quantize_fn(node->src0->type).vec_dot_type;
                        cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
                    }
                } else {
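Note on the change above: the intent is that the loader flips the quants_unshuffled global exactly once, before any graph is evaluated, and every quantized kernel then resolves its function table through get_quantize_fn(). A minimal caller-side sketch of that ordering, assuming it; check_file_format, modelname and the surrounding variables are placeholder names here (the real call site is in the loader diff further down), only SetQuantsUnshuffled() and FileFormat::GGJT_2 come from this commit:

    // Sketch only: select the quant dispatch table once, at model load time.
    FileFormat file_format = check_file_format(modelname);   // assumed detection helper
    SetQuantsUnshuffled(file_format == FileFormat::GGJT_2);  // new files use unshuffled quants
    // From here on, ggml's quantized ops go through get_quantize_fn(type) instead of
    // indexing quantize_fns[type] directly, so older shuffled-quant models fall back
    // to the quantize_fns_v2 table.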

ggml.h (2 lines changed)

@@ -895,6 +895,8 @@ extern "C" {
     // system info
     //
 
+    void SetQuantsUnshuffled(bool unshuffled);
+
     GGML_API int ggml_cpu_has_avx        (void);
     GGML_API int ggml_cpu_has_avx2       (void);
     GGML_API int ggml_cpu_has_avx512     (void);

View file

@@ -1571,6 +1571,11 @@ static void ggml_vec_dot_q5_0_q8_0_v2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q5_1_q8_1_v2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q8_0_q8_0_v2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
+void SetQuantsUnshuffled(bool unshuffle)
+{
+    quants_unshuffled = unshuffle;
+}
+
 //TODO: integrate backwards compat
 static const quantize_fns_t quantize_fns_v2[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
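The v2 table mirrors the layout of the existing quantize_fns table but points at the old shuffled-layout kernels. Judging from the fields referenced by the ggml.c call sites earlier in this commit, one entry presumably looks roughly like the sketch below; the member set and the *_v2 kernel names are assumptions patterned on the forward declarations above, not copied from the commit:

    // Illustrative shape of a quantize_fns_v2 entry (field names taken from the
    // ggml_compute_forward_* call sites; kernel names and exact members assumed).
    [GGML_TYPE_Q4_0] = {
        .dequantize_row_q   = dequantize_row_q4_0_v2,
        .quantize_row_q     = quantize_row_q4_0_v2,
        .quantize_row_q_dot = quantize_row_q8_0_v2,
        .vec_dot_q          = ggml_vec_dot_q4_0_q8_0_v2,
        .vec_dot_type       = GGML_TYPE_Q8_0,
    },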

View file

@@ -225,8 +225,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     printf("System Info: %s\n", llama_print_system_info());
 
-    if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+    if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
     {
+        //newer format has bit unshuffling
+        SetQuantsUnshuffled(file_format== FileFormat::GGJT_2);
+
         llama_ctx_params = llama_context_default_params();
         llama_ctx_params.n_ctx = inputs.max_context_length;
         llama_ctx_params.n_parts = -1;

@@ -243,7 +246,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
             fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, modelname.c_str());
             return ModelLoadResult::FAIL;
         }
-        if (file_format < FileFormat::GGJT)
+        if (file_format < FileFormat::GGJT_2)
         {
             printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
         }

@@ -484,7 +487,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     // tokenize the prompt
     std::vector<int> embd_inp;
-    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
     {
         params.prompt.insert(0, 1, ' ');
         if (file_format == FileFormat::GGML)

@@ -543,7 +546,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         //for non llama, limit to 256
         int bbs = blasbatchsize;
-        if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT)
+        if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT && file_format != FileFormat::GGJT_2)
         {
             bbs = (blasbatchsize > 256 ? 256 : blasbatchsize);
         }

@@ -573,7 +576,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     double time1 = 0, time2 = 0;
     int32_t n_vocab = 0;
 
-    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
     {
         n_vocab = llama_n_vocab(llama_ctx_v1);
     }

@@ -624,7 +627,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     if(debugmode)
     {
         printf("\n[Debug: Dump Input Tokens]\n");
-        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
         {
             for (auto id : embd_inp)
             {

@@ -661,7 +664,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         bool evalres = false;
 
-        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
         {
             evalres = (llama_eval(llama_ctx_v1, embd.data(), embdsize, n_past, params.n_threads)==0);
         }

@@ -722,7 +725,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 printf("\n");
             }
 
-            if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+            if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
             {
                 auto logits = llama_get_logits(llama_ctx_v1);

@@ -772,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             // decrement remaining sampling budget
             --remaining_tokens;
 
-            if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+            if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
            {
                 concat_output += llama_token_to_str(llama_ctx_v1, id);
                 if(unbanTokens && id==llama_token_eos())
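Aside: the same four-way format check is repeated at every llama call site in this file. A small helper like the following would express the condition once; it is purely illustrative and not part of the commit:

    // Not in this commit; shown only to make the repeated condition explicit.
    static bool is_llama_file_format(FileFormat f)
    {
        return f == FileFormat::GGML || f == FileFormat::GGHF ||
               f == FileFormat::GGJT || f == FileFormat::GGJT_2;
    }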

View file

@@ -937,7 +937,8 @@ static void llama_model_load_internal(
         if (hparams.ftype != LLAMA_FTYPE_ALL_F32 &&
             hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 &&
             hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) {
-            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1305)");
+            printf("\nLegacy LLAMA GGJT compatability changes triggered.\n");
+            //throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1305)");
         }
     }
} }

View file

@@ -145,7 +145,13 @@ void print_tok_vec(std::vector<float> &embd)
     }
     else if(magic == 0x67676a74) //v3 format ggjt
     {
-        fileformat = FileFormat::GGJT; //ggjt by default
+        fileformat = FileFormat::GGJT_2; //ggjt by default
+        uint32_t temp;
+        fin.read((char *)&temp, sizeof(temp)); //file version
+        if(temp==1)
+        {
+            fileformat = FileFormat::GGJT;
+        }
     }
     fin.close();
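In other words, files carrying the 'ggjt' magic (0x67676a74) are now treated as the new GGJT_2 format unless the 32-bit version field right after the magic reads 1. A standalone sketch of that probe, with headers, stream setup and error handling assumed rather than taken from the commit:

    // Minimal sketch of the version probe added above.
    std::ifstream fin(fname, std::ios::binary);
    uint32_t magic = 0, version = 0;
    fin.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    if (magic == 0x67676a74) { // 'ggjt'
        fin.read(reinterpret_cast<char *>(&version), sizeof(version));
        // version 1 = original GGJT (shuffled quants); anything newer = GGJT_2
        fileformat = (version == 1) ? FileFormat::GGJT : FileFormat::GGJT_2;
    }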

View file

@@ -19,6 +19,7 @@ enum FileFormat
     GGML=1, // 1=(original llama ggml, alpaca, GPT4ALL, GPTJ header)
     GGHF=2, // 2=(llama ggmf)
     GGJT=3, // 3=(llama ggjt)
+    GGJT_2=4, //newer llama format
 
     GPTJ_1=100, //the very first super old GPTJ format
     GPTJ_2=101, //pygmalion, uses old ggml lib
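Note that the loader's outdated-format warning earlier, if (file_format < FileFormat::GGJT_2), relies on these numeric values: GGML=1, GGHF=2 and GGJT=3 all compare as older than GGJT_2=4, so every pre-GGJT_2 llama file now triggers the reconvert warning.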

View file

@@ -352,7 +352,7 @@ bool gpt2_eval(
     if (mem_per_token > 0 && (mem_per_token*N*2 + 48u*1024*1024) > buf_size) {
         const size_t buf_size_new = 320u*1024*1024 + 2*(mem_per_token*N); // add 10% to account for ggml object overhead
-        printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
 
         // reallocate
         if (buf_size_new > buf_size)