partially working, but the blas matmul is broken

Concedo 2023-05-13 11:35:38 +08:00
parent b335f73a60
commit 05cf5f7d6e
8 changed files with 53 additions and 21 deletions

ggml.c

@@ -1429,6 +1429,20 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
return quantize_fns[i];
}
+bool quants_unshuffled = false; //new GGJT_2 is unshuffled, all old ones are shuffled
+static const quantize_fns_t quantize_fns_v2[GGML_TYPE_COUNT]; //forward decl
+static inline quantize_fns_t get_quantize_fn(size_t i)
+{
+if(quants_unshuffled)
+{
+return quantize_fns[i];
+}
+else
+{
+return quantize_fns_v2[i];
+}
+}
//
// simd mappings
@@ -5637,7 +5651,7 @@ static void ggml_compute_forward_dup_f16(
}
}
} else if (ggml_is_quantized(dst->type)) {
-quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+quantize_row_q_t const quantize_row_q = get_quantize_fn(dst->type).quantize_row_q;
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
@@ -5936,7 +5950,7 @@ static void ggml_compute_forward_dup_f32(
}
}
} else if (ggml_is_quantized(dst->type)) {
-quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+quantize_row_q_t const quantize_row_q = get_quantize_fn(dst->type).quantize_row_q;
size_t id = 0;
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
@@ -6346,8 +6360,8 @@ static void ggml_compute_forward_add_q_f32(
GGML_ASSERT(ne3 == ne13);
const enum ggml_type type = src0->type;
-dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+dequantize_row_q_t const dequantize_row_q = get_quantize_fn(type).dequantize_row_q;
+quantize_row_q_t const quantize_row_q = get_quantize_fn(type).quantize_row_q;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7809,9 +7823,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
GGML_ASSERT(ne3 == ne13);
const enum ggml_type type = src0->type;
-quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
-vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
-enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
+quantize_row_q_t const quantize_row_q_dot = get_quantize_fn(type).quantize_row_q_dot;
+vec_dot_q_t const vec_dot_q = get_quantize_fn(type).vec_dot_q;
+enum ggml_type const vec_dot_type = get_quantize_fn(type).vec_dot_type;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -8138,7 +8152,7 @@ static void ggml_compute_forward_get_rows_q(
const int nc = src0->ne[0];
const int nr = ggml_nelements(src1);
const enum ggml_type type = src0->type;
-dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+dequantize_row_q_t const dequantize_row_q = get_quantize_fn(type).dequantize_row_q;
assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr);
@@ -10923,7 +10937,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} else
#endif
{
-const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
+const enum ggml_type type_q = get_quantize_fn(node->src0->type).vec_dot_type;
cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
}
} else {
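
Taken together, the ggml.c hunks above reroute every direct quantize_fns[type] lookup through get_quantize_fn(type), which picks between the current table and the legacy quantize_fns_v2 table based on the global quants_unshuffled flag. A minimal standalone sketch of that dispatch pattern (placeholder table and function names, not the real ggml types):

// Minimal sketch of the runtime table dispatch; names and the single-entry
// tables are placeholders, not ggml's real types.
#include <cstddef>
#include <cstdio>

struct quant_fns_t {
    void (*quantize_row)(const float *src, void *dst, int n);
};

static void quantize_row_new(const float *, void *, int) { std::puts("unshuffled (GGJT_2) path"); }
static void quantize_row_old(const float *, void *, int) { std::puts("legacy shuffled path"); }

static const quant_fns_t fns[1]    = { { quantize_row_new } };  // stands in for quantize_fns
static const quant_fns_t fns_v2[1] = { { quantize_row_old } };  // stands in for quantize_fns_v2

static bool quants_unshuffled = false;  // flipped by the loader via SetQuantsUnshuffled()

static inline quant_fns_t get_quantize_fn(std::size_t i) {
    return quants_unshuffled ? fns[i] : fns_v2[i];
}

int main() {
    get_quantize_fn(0).quantize_row(nullptr, nullptr, 0);  // legacy path by default
    quants_unshuffled = true;                               // e.g. after loading a GGJT_2 file
    get_quantize_fn(0).quantize_row(nullptr, nullptr, 0);  // unshuffled path
}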

ggml.h

@@ -895,6 +895,8 @@ extern "C" {
// system info
//
+void SetQuantsUnshuffled(bool unshuffled);
GGML_API int ggml_cpu_has_avx (void);
GGML_API int ggml_cpu_has_avx2 (void);
GGML_API int ggml_cpu_has_avx512 (void);


@@ -1571,6 +1571,11 @@ static void ggml_vec_dot_q5_0_q8_0_v2(const int n, float * restrict s, const voi
static void ggml_vec_dot_q5_1_q8_1_v2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q8_0_q8_0_v2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void SetQuantsUnshuffled(bool unshuffle)
+{
+quants_unshuffled = unshuffle;
+}
//TODO: integrate backwards compat
static const quantize_fns_t quantize_fns_v2[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = {


@@ -225,8 +225,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("System Info: %s\n", llama_print_system_info());
-if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
+//newer format has bit unshuffling
+SetQuantsUnshuffled(file_format== FileFormat::GGJT_2);
llama_ctx_params = llama_context_default_params();
llama_ctx_params.n_ctx = inputs.max_context_length;
llama_ctx_params.n_parts = -1;
@@ -243,7 +246,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, modelname.c_str());
return ModelLoadResult::FAIL;
}
-if (file_format < FileFormat::GGJT)
+if (file_format < FileFormat::GGJT_2)
{
printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
}
@@ -484,7 +487,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// tokenize the prompt
std::vector<int> embd_inp;
-if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
params.prompt.insert(0, 1, ' ');
if (file_format == FileFormat::GGML)
@@ -543,7 +546,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
//for non llama, limit to 256
int bbs = blasbatchsize;
-if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT)
+if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT && file_format != FileFormat::GGJT_2)
{
bbs = (blasbatchsize > 256 ? 256 : blasbatchsize);
}
@@ -573,7 +576,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
double time1 = 0, time2 = 0;
int32_t n_vocab = 0;
-if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
n_vocab = llama_n_vocab(llama_ctx_v1);
}
@@ -624,7 +627,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if(debugmode)
{
printf("\n[Debug: Dump Input Tokens]\n");
-if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
for (auto id : embd_inp)
{
@@ -661,7 +664,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
bool evalres = false;
-if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
evalres = (llama_eval(llama_ctx_v1, embd.data(), embdsize, n_past, params.n_threads)==0);
}
@@ -722,7 +725,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
printf("\n");
}
-if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
auto logits = llama_get_logits(llama_ctx_v1);
@@ -772,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// decrement remaining sampling budget
--remaining_tokens;
-if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
concat_output += llama_token_to_str(llama_ctx_v1, id);
if(unbanTokens && id==llama_token_eos())
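
Each of the gpttype_generate hunks above widens the same llama-format check to also accept FileFormat::GGJT_2. A hypothetical helper (not part of this patch) showing the repeated condition written once, with the enum values abridged from the header diff further below:

// Hypothetical refactor, not in the patch: the "is this a llama file format?"
// condition that every hunk above extends, expressed as a single helper.
enum FileFormat { GGML = 1, GGHF = 2, GGJT = 3, GGJT_2 = 4 };  // abridged

static bool is_llama_file_format(FileFormat f) {
    return f == FileFormat::GGML || f == FileFormat::GGHF ||
           f == FileFormat::GGJT || f == FileFormat::GGJT_2;
}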


@@ -937,7 +937,8 @@ static void llama_model_load_internal(
if (hparams.ftype != LLAMA_FTYPE_ALL_F32 &&
hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 &&
hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) {
-throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1305)");
+printf("\nLegacy LLAMA GGJT compatability changes triggered.\n");
+//throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1305)");
}
}


@@ -145,7 +145,13 @@ void print_tok_vec(std::vector<float> &embd)
}
else if(magic == 0x67676a74) //v3 format ggjt
{
-fileformat = FileFormat::GGJT; //ggjt by default
+fileformat = FileFormat::GGJT_2; //ggjt by default
+uint32_t temp;
+fin.read((char *)&temp, sizeof(temp)); //file version
+if(temp==1)
+{
+fileformat = FileFormat::GGJT;
+}
}
fin.close();
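
The detection change above reads one extra 32-bit value after the 0x67676a74 ('ggjt') magic: file version 1 keeps the old GGJT classification, while any other version value is treated as the newer unshuffled GGJT_2 layout. A simplified standalone sketch of that probe (not the project's actual loader):

// Simplified sketch of the magic + version probe; error handling and the
// remaining FileFormat members are omitted.
#include <cstdint>
#include <fstream>

enum FileFormat { GGML = 1, GGHF = 2, GGJT = 3, GGJT_2 = 4 };  // abridged

static FileFormat detect_ggjt_version(std::ifstream &fin) {
    uint32_t magic = 0;
    fin.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    if (magic != 0x67676a74) {
        return FileFormat::GGML;                 // not a ggjt file at all
    }
    FileFormat fileformat = FileFormat::GGJT_2;  // assume the newer, unshuffled layout
    uint32_t temp = 0;
    fin.read(reinterpret_cast<char *>(&temp), sizeof(temp));  // file version
    if (temp == 1) {
        fileformat = FileFormat::GGJT;           // version 1 keeps the old shuffled layout
    }
    return fileformat;
}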


@@ -19,6 +19,7 @@ enum FileFormat
GGML=1, // 1=(original llama ggml, alpaca, GPT4ALL, GPTJ header)
GGHF=2, // 2=(llama ggmf)
GGJT=3, // 3=(llama ggjt)
+GGJT_2=4, //newer llama format
GPTJ_1=100, //the very first super old GPTJ format
GPTJ_2=101, //pygmalion, uses old ggml lib


@@ -352,7 +352,7 @@ bool gpt2_eval(
if (mem_per_token > 0 && (mem_per_token*N*2 + 48u*1024*1024) > buf_size) {
const size_t buf_size_new = 320u*1024*1024 + 2*(mem_per_token*N); // add 10% to account for ggml object overhead
-printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
if (buf_size_new > buf_size)