diff --git a/Makefile b/Makefile
index 1757ce204..f0a443c0a 100644
--- a/Makefile
+++ b/Makefile
@@ -31,8 +31,8 @@ endif
 #
 # keep standard at C11 and C++11
-CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
+CFLAGS = -I. -Ofast -DNDEBUG -std=c11 -fPIC
+CXXFLAGS = -I. -I./examples -Ofast -DNDEBUG -std=c++11 -fPIC
 LDFLAGS =
 
 #lets try enabling everything
@@ -152,7 +152,7 @@ ggml_blas.o: ggml.c ggml.h
 	$(CC) $(CFLAGS) -DGGML_USE_OPENBLAS -c ggml.c -o ggml_blas.o
 ggml_v1.o: otherarch/ggml_v1.c otherarch/ggml_v1.h
-	$(CC) $(CFLAGS) -c otherarch/ggml_v1.c -o ggml_v1.o
+	$(CC) $(CFLAGS) -c otherarch/ggml_v1.c -o ggml_v1.o
 llama.o: llama.cpp llama.h
 	$(CXX) $(CXXFLAGS) -c llama.cpp -o llama.o
@@ -193,6 +193,8 @@ perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
 embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
 	$(CXX) $(CXXFLAGS) examples/embedding/embedding.cpp ggml.o llama.o common.o -o embedding $(LDFLAGS)
+gptj: ggml_v1.o
+	$(CXX) $(CXXFLAGS) otherarch/gptj_v1_main.cpp otherarch/utils.cpp ggml_v1.o -o gptj $(LDFLAGS)
 
 #
 # Tests
 #
diff --git a/expose.cpp b/expose.cpp
index 29e06954e..5dec31cee 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -24,17 +24,38 @@ extern "C"
 {
     //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
-    static FileFormat file_format = FAIL;
+    static FileFormat file_format = FileFormat::BADFORMAT;
 
     bool load_model(const load_model_inputs inputs)
     {
         std::string model = inputs.model_filename;
         file_format = check_file_format(model.c_str());
-        if(file_format==GPTJ1 || file_format==GPTJ2)
+        if(file_format==FileFormat::GPTJ1 || file_format==FileFormat::GPTJ2 || file_format==FileFormat::GPTJ3)
         {
-            printf("\n---\nIdentified as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
-            return gptj_load_model(inputs, file_format);
+            printf("\n---\nIdentified as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
+            ModelLoadResult lr = gptj_load_model(inputs, file_format);
+            if (lr == ModelLoadResult::RETRY_LOAD)
+            {
+                file_format = FileFormat::GPTJ2;
+                printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
+                lr = gptj_load_model(inputs, file_format);
+            }
+            if (lr == ModelLoadResult::RETRY_LOAD)
+            {
+                file_format = FileFormat::GPTJ3;
+                printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
+                lr = gptj_load_model(inputs, file_format);
+            }
+
+            if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
+            {
+                return false;
+            }
+            else
+            {
+                return true;
+            }
         }
         else
         {
@@ -45,7 +66,7 @@ extern "C"
 
     generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
     {
-        if (file_format == GPTJ1 || file_format == GPTJ2)
+        if (file_format == FileFormat::GPTJ1 || file_format == FileFormat::GPTJ2 || file_format==FileFormat::GPTJ3)
        {
            return gptj_generate(inputs, output);
        }
diff --git a/gptj_adapter.cpp b/gptj_adapter.cpp
index 2a6ff5b6c..65b340b2d 100644
--- a/gptj_adapter.cpp
+++ b/gptj_adapter.cpp
@@ -17,7 +17,7 @@
 #include "otherarch/gptj_v2.cpp"
 
 //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
-static FileFormat file_format = FileFormat::FAIL;
+static FileFormat file_format = FileFormat::BADFORMAT;
 static gpt_vocab vocab;
 static gptj_model_v1 model_v1;
 static gptj_model model_v2;
@@ -30,9 +30,8 @@
 static std::vector<gpt_vocab::id> current_context_tokens;
 static size_t mem_per_token = 0;
 static std::vector<float> logits;
 
-bool gptj_load_model(const load_model_inputs inputs, FileFormat in_file_format)
+ModelLoadResult gptj_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 {
-    ggml_time_init();
 
     file_format = in_file_format;
 
@@ -40,20 +39,42 @@ bool gptj_load_model(const load_model_inputs inputs, FileFormat in_file_format)
     n_batch = params.n_batch = inputs.batch_size;
     modelname = params.model = inputs.model_filename;
 
-    if (!legacy_gptj_model_load(params.model, model_v1, vocab)) {
-        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-        return false;
+    if (file_format == FileFormat::GPTJ1 || file_format == FileFormat::GPTJ2)
+    {
+        ModelLoadResult res = legacy_gptj_model_load(params.model, model_v1, vocab, file_format);
+        if(res==ModelLoadResult::FAIL)
+        {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return res;
+        }
+        else if(res==ModelLoadResult::RETRY_LOAD)
+        {
+            printf("\nTensor Transposition Detected! Retrying GPT-J model loading...");
+            return res;
+        }
+        // determine the required inference memory per token:
+        legacy_gptj_eval(model_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
+        return ModelLoadResult::SUCCESS;
     }
-
-    if (file_format != FileFormat::GPTJ2)
+    else
     {
-        printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
+        ModelLoadResult loadresult = gptj_model_load(params.model, model_v2, vocab);
+        if (loadresult == ModelLoadResult::FAIL)
+        {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return loadresult;
+        }
+        else if (loadresult == ModelLoadResult::RETRY_LOAD)
+        {
+            printf("\nTensor Transposition Detected! Retrying GPT-J model loading...");
+            return loadresult;
+        }
+
+        // determine the required inference memory per token:
+        gptj_eval(model_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+        return ModelLoadResult::SUCCESS;
     }
-
-    // determine the required inference memory per token:
-    legacy_gptj_eval(model_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
-    return true;
+
 }
@@ -82,9 +103,10 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
 
     //truncate to front of the prompt if its too long
-    if (embd_inp.size() + params.n_predict > model_v1.hparams.n_ctx)
+    auto nctx = ( (file_format == FileFormat::GPTJ1||file_format == FileFormat::GPTJ2)? model_v1.hparams.n_ctx:model_v2.hparams.n_ctx);
+    if (embd_inp.size() + params.n_predict > nctx)
     {
-        int offset = embd_inp.size() - model_v1.hparams.n_ctx + params.n_predict;
+        int offset = embd_inp.size() - nctx + params.n_predict;
         embd_inp = std::vector<gpt_vocab::id>(embd_inp.begin() + offset, embd_inp.end());
     }
 
@@ -114,7 +136,7 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
     embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
-    bool blasmode = false;// (embd_inp.size() >= 32 && ggml_cpu_has_blas());
+    bool blasmode = false; //(embd_inp.size() >= 32 && ggml_cpu_has_blas());
     int original_batch = params.n_batch;
     int original_threads = params.n_threads;
     if (blasmode)
@@ -135,7 +157,7 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
     timer_start();
     double time1 = 0, time2 = 0;
     unsigned int embd_inp_size = embd_inp.size();
-    const int n_vocab = model_v1.hparams.n_vocab;
+    const int n_vocab = ((file_format == FileFormat::GPTJ1||file_format == FileFormat::GPTJ2)? model_v1.hparams.n_vocab:model_v2.hparams.n_vocab);
 
     printf("\n");
 
@@ -156,7 +178,15 @@ generation_outputs gptj_generate(const generation_inputs inputs, generation_outp
             printf("\rGenerating (%d / %d tokens)", (1 + params.n_predict - remaining_tokens), params.n_predict);
         }
 
-        if (!legacy_gptj_eval(model_v1, params.n_threads, n_past, embd, logits, mem_per_token))
+        bool evalres = false;
+        if(file_format==FileFormat::GPTJ1 || file_format==FileFormat::GPTJ2)
+        {
+            evalres = legacy_gptj_eval(model_v1, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
+        }else
+        {
+            evalres = gptj_eval(model_v2, params.n_threads, n_past, embd, logits, mem_per_token);
+        }
+        if (!evalres)
         {
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
diff --git a/llama_adapter.cpp b/llama_adapter.cpp
index 39f79395c..9c2ef9ce8 100644
--- a/llama_adapter.cpp
+++ b/llama_adapter.cpp
@@ -16,7 +16,7 @@
 #include "llamaextra.cpp"
 
 //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
-static FileFormat file_format = FileFormat::FAIL;
+static FileFormat file_format = FileFormat::BADFORMAT;
 static llama_context_params ctx_params;
 static gpt_params params;
 static int n_past = 0;
diff --git a/llamacpp.dll b/llamacpp.dll
index 3c5b5642f..f2018ac29 100644
Binary files a/llamacpp.dll and b/llamacpp.dll differ
diff --git a/llamacpp_blas.dll b/llamacpp_blas.dll
index 91355063f..2be1b3d5d 100644
Binary files a/llamacpp_blas.dll and b/llamacpp_blas.dll differ
diff --git a/llamacpp_for_kobold.py b/llamacpp_for_kobold.py
index c64819957..063bf9bdc 100644
--- a/llamacpp_for_kobold.py
+++ b/llamacpp_for_kobold.py
@@ -37,12 +37,17 @@ use_blas = False # if true, uses OpenBLAS for acceleration. libopenblas.dll must
 
 def init_library():
     global handle, use_blas
-    dir_path = os.path.dirname(os.path.realpath(__file__))
 
+    libname = ""
     if use_blas:
-        #OpenBLAS should provide about a 2x speedup on prompt ingestion if compatible.
-        handle = ctypes.CDLL(os.path.join(dir_path, "llamacpp_blas.dll"))
+        libname = "llamacpp_blas.dll"
     else:
-        handle = ctypes.CDLL(os.path.join(dir_path, "llamacpp.dll"))
+        libname = "llamacpp.dll"
+
+    print("Initializing dynamic library: " + libname)
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+
+    #OpenBLAS should provide about a 2x speedup on prompt ingestion if compatible.
+    handle = ctypes.CDLL(os.path.join(dir_path, libname))
 
     handle.load_model.argtypes = [load_model_inputs]
     handle.load_model.restype = ctypes.c_bool
diff --git a/main.exe b/main.exe
index 8cbbdf330..b1835a8de 100644
Binary files a/main.exe and b/main.exe differ
diff --git a/model_adapter.cpp b/model_adapter.cpp
index e4edff05e..e711f0508 100644
--- a/model_adapter.cpp
+++ b/model_adapter.cpp
@@ -49,10 +49,10 @@ void print_tok_vec(std::vector<int> &embd)
     fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
     if (!fin) {
         fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
-        return FileFormat::FAIL;
+        return FileFormat::BADFORMAT;
     }
 
-    FileFormat fileformat = FileFormat::FAIL;
+    FileFormat fileformat = FileFormat::BADFORMAT;
     uint32_t magic;
     fin.read((char *) &magic, sizeof(magic));
     if (magic == 0x67676d6c) { //v1 format ggml, alpaca, old gptj and gpt2 models
diff --git a/model_adapter.h b/model_adapter.h
index 1540eee52..0814d1d29 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -13,23 +13,30 @@
 #include "expose.h"
 
-//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
 enum FileFormat
 {
-    FAIL=0,
-    GGML=1,
-    GGHF=2,
-    GGJT=3,
+    BADFORMAT=0, //unknown, uninit, or failed to load
+    GGML=1, // 1=(original llama ggml, alpaca, GPT4ALL, GPTJ header)
+    GGHF=2, // 2=(llama ggmf)
+    GGJT=3, // 3=(llama ggjt)
 
-    GPTJ1=100,
-    GPTJ2=101,
+    GPTJ1=100, //the very first super old GPTJ format
+    GPTJ2=101, //pygmalion, uses old ggml lib
+    GPTJ3=102, //uses new ggml lib
 
     GPT2=200,
 };
 
+enum ModelLoadResult
+{
+    FAIL = 0,
+    SUCCESS = 1,
+    RETRY_LOAD = 2, //used if it's suspected that the model is an older format
+};
+
 bool llama_load_model(const load_model_inputs inputs, FileFormat file_format);
 generation_outputs llama_generate(const generation_inputs inputs, generation_outputs &output);
 
-bool gptj_load_model(const load_model_inputs inputs, FileFormat in_file_format);
+ModelLoadResult gptj_load_model(const load_model_inputs inputs, FileFormat in_file_format);
 generation_outputs gptj_generate(const generation_inputs inputs, generation_outputs &output);
diff --git a/otherarch/ggml_v1.c b/otherarch/ggml_v1.c
index fd5e1d54f..39b60d04a 100644
--- a/otherarch/ggml_v1.c
+++ b/otherarch/ggml_v1.c
@@ -13,6 +13,7 @@
 #include <string.h>
 #include <stdint.h>
 #include <stdio.h>
+#include <float.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -30,9 +31,6 @@
 #include <windows.h>
 #endif
 
-// Need this to compile with Visual Studio 2017
-#define restrict __restrict
-
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
 
@@ -351,6 +349,249 @@ int64_t ggml_v1_cycles_per_ms(void) {
 
 static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 
+//
+// quantization
+//
+
+#define QK 32
+
+// method 5
+// blocks of QK elements
+// represented with a single float (delta) and QK/2 8-bit ints (i.e. QK 4-bit signed integer factors)
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+    const int nb = k / QK;
+
+    float * restrict pd = (float *) (y);
+    uint8_t * restrict pb = (uint8_t *) (pd + nb);
+
+    uint8_t pp[QK/2];
+
+#if __ARM_NEON
+#if QK == 32
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] =
vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + amax = MAX( + MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)), + MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = d; + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f)); + const int32x4_t vi = vcvtq_s32_f32(vf); + + pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4); + pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#else +#error "not implemented for QK" +#endif +#elif defined(__wasm_simd128__) +#if QK == 32 + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]); + + amax = MAX( + MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0/d : 0.0; + + pd[i] = d; + + for (int l = 0; l < 8; l++) { + const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id)); + const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf); + + pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4); + pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4); + } + + memcpy(pb + i*16, pp, sizeof(pp)); + } +#else +#error "not implemented for QK" +#endif +#else + // scalar + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0f/d : 0.0f; + + pd[i] = d; + + for (int l = 0; l < QK; l += 2) { + const float v0 = x[i*QK + l + 0]*id; + const float v1 = x[i*QK + l + 1]*id; + + const uint8_t vi0 = ((int8_t) (round(v0))) + 8; + const uint8_t vi1 = ((int8_t) (round(v1))) + 8; + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +#endif +} + +// method 4 +// blocks of QK elements +// represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) +void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + float * restrict pm = (float *) (y); + float * restrict pd = (float *) (pm + nb); + uint8_t * restrict pb = (uint8_t *) (pd + nb); + + uint8_t pp[QK/2]; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + pm[i] = min; + pd[i] = d; + + for (int l = 0; l < QK; l += 2) { + const float v0 = (x[i*QK + l + 0] - min)*id; + const float v1 = (x[i*QK + l + 1] - min)*id; + + const uint8_t vi0 = round(v0); + const uint8_t vi1 = round(v1); + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(pb + i*QK/2, pp, sizeof(pp)); + } +} + +// TODO: vectorize +void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + const float * restrict pd = (const float *) (x); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + + // scalar + for (int i = 0; i < nb; i++) { + const float d = pd[i]; + + const uint8_t * restrict pp = pb + i*QK/2; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + } + } +} + +void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { + assert(k % QK == 0); + + const int nb = k / QK; + + const float * restrict pm = (const float *) (x); + const float * restrict pd = (const float *) (pm + nb); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + + for (int i = 0; i < nb; i++) { + const float m = pm[i]; + const float d = pd[i]; + + const uint8_t * restrict pp = pb + i*QK/2; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK + l + 0] = v0; + y[i*QK + l + 1] = v1; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + } + } +} + // // simd mappings // @@ -928,6 +1169,264 @@ inline static void ggml_v1_vec_dot_f16(const int n, float * restrict s, ggml_v1_ *s = sumf; } +inline static void ggml_v1_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const int nb = n / QK; + + assert(n % QK == 0); + assert(nb % 2 == 0); + + const float * restrict pd0 = (const float *) x; + const float * restrict pd1 = (const float *) y; + + const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); + const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + + float sumf = 0.0; + +#ifdef __ARM_NEON +#if QK == 32 + float sum0 = 0.0f; + float sum1 = 0.0f; + + for (int i = 0; i < nb; i += 2) { + const float d0_0 = pd0[i + 0]; + const float d1_0 = pd1[i + 0]; + const float d0_1 = pd0[i + 1]; + const float d1_1 = pd1[i + 1]; + + //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1); + + const uint8_t * restrict p0 = pb0 + i*16; + const uint8_t * restrict p1 = pb1 + i*16; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(p0); + const uint8x16_t v1_0 = vld1q_u8(p1); + const uint8x16_t v0_1 = vld1q_u8(p0 + 16); + const uint8x16_t v1_1 = vld1q_u8(p1 + 16); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); + const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b)); + + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4)); + + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); + const int8x16_t v1_1l = 
vreinterpretq_s8_u8(vandq_u8(v1_1, m4b)); + + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); + + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); + + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); + + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); + + // dot product into int16x8_t + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + + const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); + const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); + + const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); + const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); + + const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); + const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); + + // scalar +#if defined(__ARM_FEATURE_QRDMX) + sum0 += d0_0*d1_0*vaddvq_s16(p_0); + sum1 += d0_1*d1_1*vaddvq_s16(p_1); +#else + sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); + sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); +#endif + } + + sumf = sum0 + sum1; +#else +#error "not implemented for QK" +#endif +#elif defined(__wasm_simd128__) +#if QK == 32 + // wasm simd + float sum0 = 0.0f; + float sum1 = 0.0f; + + for (int i = 0; i < nb; i += 2) { + const float d0_0 = pd0[i + 0]; + const float d0_1 = pd0[i + 1]; + const float d1_0 = pd1[i + 0]; + const float d1_1 = pd1[i + 1]; + + const uint8_t * restrict p0 = pb0 + i*16; + const uint8_t * restrict p1 = pb1 + i*16; + + const v128_t m4b = wasm_u8x16_splat(0xf); + const v128_t s8b = wasm_i8x16_splat(0x8); + + const v128_t v0_0 = wasm_v128_load(p0); + const v128_t v0_1 = wasm_v128_load(p0 + 16); + const v128_t v1_0 = wasm_v128_load(p1); + const v128_t v1_1 = wasm_v128_load(p1 + 16); + + // 4-bit -> 8-bit + const v128_t v0_0l = wasm_v128_and(v0_0, m4b); + const v128_t v1_0l = wasm_v128_and(v1_0, m4b); + + const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4); + + const v128_t v0_1l = wasm_v128_and(v0_1, m4b); + const v128_t v1_1l = wasm_v128_and(v1_1, m4b); + + const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4); + + // sub 8 + const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b); + + const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b); + + const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + const v128_t 
v1_1ls = wasm_i8x16_sub(v1_1l, s8b); + + const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b); + + // dot product into int16x8_t + const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls)); + const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls)); + + const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs)); + const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs)); + + const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls)); + const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls)); + + const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs)); + const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs)); + + const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h); + const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h); + + const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h); + const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h); + + const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0); + const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1); + + sum0 += d0_0*d1_0*( + wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) + + wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) + + wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) + + wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7)); + sum1 += d0_1*d1_1*( + wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) + + wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) + + wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) + + wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7)); + } + + sumf = sum0 + sum1; +#else +#error "not implemented for QK" +#endif +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = pd0[i]; + const float d1 = pd1[i]; + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*((int8_t) (v0 & 0xf) - 8); + const float f1 = d0*((int8_t) (v0 >> 4) - 8); + + const float f2 = d1*((int8_t) (v1 & 0xf) - 8); + const float f3 = d1*((int8_t) (v1 >> 4) - 8); + + sumf += f0*f2 + f1*f3; + } + } +#endif + + *s = sumf; +} + +inline static void ggml_v1_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const int nb = n / QK; + + const float * restrict pm0 = (const float *) x; + const float * restrict pm1 = (const float *) y; + + const float * restrict pd0 = (const float *) (pm0 + nb); + const float * restrict pd1 = (const float *) (pm1 + nb); + + const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); + const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + + float sumf = 0.0; + +#if 1 + // scalar + for (int i = 0; i < nb; i++) { + const float m0 = pm0[i]; + const float m1 = pm1[i]; + + const float d0 = pd0[i]; + const float d1 = pd1[i]; + + const uint8_t * restrict p0 = pb0 + i*QK/2; + const uint8_t * restrict p1 = pb1 + i*QK/2; + + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + const uint8_t v1 = p1[j]; + + const float f0 = d0*(v0 & 0xf) + m0; + const 
float f1 = d0*(v0 >> 4) + m0; + + const float f2 = d1*(v1 & 0xf) + m1; + const float f3 = d1*(v1 >> 4) + m1; + + sumf += f0*f2 + f1*f3; + } + } +#endif + + *s = sumf; +} + // compute GGML_V1_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes inline static void ggml_v1_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_v1_fp16_t * restrict y) { @@ -1045,6 +1544,134 @@ inline static void ggml_v1_vec_mad_f16(const int n, ggml_v1_fp16_t * restrict y, #endif } +inline static void ggml_v1_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) { + assert(n % QK == 0); + + const int nb = n / QK; + + const float * restrict pd = (const float *) (x); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + +#if __ARM_NEON +#if QK == 32 + for (int i = 0; i < nb; ++i) { + const float d0 = pd[i]*v; + + const uint8_t * restrict pp = pb + i*16; + + const uint8x8_t m4b = vdup_n_u8(0xf); + const int8x8_t s8b = vdup_n_s8(0x8); + + const float32x4_t vd = vdupq_n_f32(d0); + + for (int j = 0; j < 2; j++) { + const uint8x8_t vx = vld1_u8(pp + j*8); + + const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b)); + const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4)); + + // sub 8 + const int8x8_t vxls = vsub_s8(vxl, s8b); + const int8x8_t vxhs = vsub_s8(vxh, s8b); + + //const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0]; + //const int8x8_t vxht = vzip_s8(vxls, vxhs)[1]; + const int8x8_t vxlt = vzip1_s8(vxls, vxhs); + const int8x8_t vxht = vzip2_s8(vxls, vxhs); + + const int8x16_t vxq = vcombine_s8(vxlt, vxht); + + // convert to 2x int16x8_t + const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq)); + const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq)); + + // convert to 4x float32x4_t + const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0))); + const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0))); + const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1))); + const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1))); + + const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0); + const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4); + const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8); + const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12); + + const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd); + const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd); + const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd); + const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd); + + vst1q_f32(y + i*32 + j*16 + 0, vr0); + vst1q_f32(y + i*32 + j*16 + 4, vr1); + vst1q_f32(y + i*32 + j*16 + 8, vr2); + vst1q_f32(y + i*32 + j*16 + 12, vr3); + } + } +#endif +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d = pd[i]; + + const uint8_t * restrict pp = pb + i*QK/2; + + for (int l = 0; l < QK; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + y[i*QK + l + 0] += v0*v; + y[i*QK + l + 1] += v1*v; + + assert(!isnan(y[i*QK + l + 0])); + assert(!isnan(y[i*QK + l + 1])); + assert(!isinf(y[i*QK + l + 0])); + assert(!isinf(y[i*QK + l + 1])); + } + } +#endif +} + +inline static void ggml_v1_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) { + assert(n % QK == 0); + + const int nb = n / QK; + + const float * restrict pm = (const float *) (x); + const float * restrict pd = (const float *) (pm + nb); + const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + + for 
(int i = 0; i < nb; i++) {
+        const float m = pm[i];
+        const float d = pd[i];
+
+        const uint8_t * restrict pp = pb + i*QK/2;
+
+        for (int l = 0; l < QK; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const uint8_t vi0 = vi & 0xf;
+            const uint8_t vi1 = vi >> 4;
+
+            const float v0 = d*vi0 + m;
+            const float v1 = d*vi1 + m;
+
+            y[i*QK + l + 0] += v0*v;
+            y[i*QK + l + 1] += v1*v;
+
+            assert(!isnan(y[i*QK + l + 0]));
+            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isinf(y[i*QK + l + 0]));
+            assert(!isinf(y[i*QK + l + 1]));
+            //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
+        }
+    }
+}
+
 //inline static void ggml_v1_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
 inline static void ggml_v1_vec_scale_f32(const int n, float * y, const float v) {
 #if defined(GGML_V1_SIMD)
@@ -1168,7 +1795,21 @@ inline static void ggml_v1_vec_norm_inv_f32(const int n, float * s, const float
 // data types
 //
 
+static const int GGML_V1_BLCK_SIZE[GGML_V1_TYPE_COUNT] = {
+    QK,
+    QK,
+    1,
+    1,
+    1,
+    1,
+    1,
+};
+
+static_assert(GGML_V1_TYPE_COUNT == 7, "GGML_V1_TYPE_COUNT != 7");
+
 static const size_t GGML_V1_TYPE_SIZE[GGML_V1_TYPE_COUNT] = {
+    sizeof(float ) + QK/2,
+    sizeof(float )*2 + QK/2,
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -1176,6 +1817,9 @@ static const size_t GGML_V1_TYPE_SIZE[GGML_V1_TYPE_COUNT] = {
     sizeof(float ),
 };
 
+// don't forget to update the array above when adding new types
+static_assert(GGML_V1_TYPE_COUNT == 7, "GGML_V1_TYPE_COUNT != 7");
+
 static const char * GGML_V1_OP_LABEL[GGML_V1_OP_COUNT] = {
     "NONE",
 
@@ -1216,6 +1860,8 @@ static const char * GGML_V1_OP_LABEL[GGML_V1_OP_COUNT] = {
     "FLASH_FF",
 };
 
+static_assert(GGML_V1_OP_COUNT == 33, "GGML_V1_OP_COUNT != 33");
+
 static const char * GGML_V1_OP_SYMBOL[GGML_V1_OP_COUNT] = {
     "none",
 
@@ -1256,6 +1902,8 @@ static const char * GGML_V1_OP_SYMBOL[GGML_V1_OP_COUNT] = {
     "flash_ff(x)",
 };
 
+static_assert(GGML_V1_OP_COUNT == 33, "GGML_V1_OP_COUNT != 33");
+
 //
 // ggml object
 //
@@ -1383,13 +2031,21 @@ int ggml_v1_nrows(const struct ggml_v1_tensor * tensor) {
 size_t ggml_v1_nbytes(const struct ggml_v1_tensor * tensor) {
     static_assert(GGML_V1_MAX_DIMS == 4, "GGML_V1_MAX_DIMS is not 4 - update this function");
 
-    return ggml_v1_nelements(tensor)*GGML_V1_TYPE_SIZE[tensor->type];
+    return (ggml_v1_nelements(tensor)*GGML_V1_TYPE_SIZE[tensor->type])/GGML_V1_BLCK_SIZE[tensor->type];
+}
+
+int ggml_v1_blck_size(enum ggml_v1_type type) {
+    return GGML_V1_BLCK_SIZE[type];
 }
 
 size_t ggml_v1_type_size(enum ggml_v1_type type) {
     return GGML_V1_TYPE_SIZE[type];
 }
 
+float ggml_v1_type_sizef(enum ggml_v1_type type) {
+    return ((float)(GGML_V1_TYPE_SIZE[type]))/GGML_V1_BLCK_SIZE[type];
+}
+
 size_t ggml_v1_element_size(const struct ggml_v1_tensor * tensor) {
     return GGML_V1_TYPE_SIZE[tensor->type];
 }
@@ -1426,7 +2082,7 @@ static inline bool ggml_v1_is_contiguous(const struct ggml_v1_tensor * tensor) {
     return
         tensor->nb[0] == GGML_V1_TYPE_SIZE[tensor->type] &&
-        tensor->nb[1] == tensor->nb[0]*tensor->ne[0] &&
+        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_V1_BLCK_SIZE[tensor->type] &&
         tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
@@ -1488,9 +2144,6 @@ struct ggml_v1_context * ggml_v1_init(struct ggml_v1_init_params params) {
     static bool is_first_call = true;
 
     if (is_first_call) {
-        // initialize time system (required on Windows)
-        ggml_v1_time_init();
-
         // initialize GELU, EXP and F32 tables
         {
             const uint64_t
t_start = ggml_v1_time_us(); UNUSED(t_start); @@ -1629,8 +2282,8 @@ struct ggml_v1_tensor * ggml_v1_new_tensor_impl( size_t size_needed = 0; if (data == NULL) { - size_needed += GGML_V1_TYPE_SIZE[type]; - for (int i = 0; i < n_dims; i++) { + size_needed += GGML_V1_TYPE_SIZE[type]*(ne[0]/GGML_V1_BLCK_SIZE[type]); + for (int i = 1; i < n_dims; i++) { size_needed *= ne[i]; } // align to GGML_V1_MEM_ALIGN @@ -1723,7 +2376,8 @@ struct ggml_v1_tensor * ggml_v1_new_tensor_impl( } result->nb[0] = GGML_V1_TYPE_SIZE[type]; - for (int i = 1; i < GGML_V1_MAX_DIMS; i++) { + result->nb[1] = result->nb[0]*(result->ne[0]/GGML_V1_BLCK_SIZE[type]); + for (int i = 2; i < GGML_V1_MAX_DIMS; i++) { result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; } @@ -1820,6 +2474,14 @@ struct ggml_v1_tensor * ggml_v1_set_i32 (struct ggml_v1_tensor * tensor, int32_t char * const data = tensor->data; switch (tensor->type) { + case GGML_V1_TYPE_Q4_0: + { + GGML_V1_ASSERT(false); + } break; + case GGML_V1_TYPE_Q4_1: + { + GGML_V1_ASSERT(false); + } break; case GGML_V1_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -1857,7 +2519,7 @@ struct ggml_v1_tensor * ggml_v1_set_i32 (struct ggml_v1_tensor * tensor, int32_t } break; case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } @@ -1872,6 +2534,14 @@ struct ggml_v1_tensor * ggml_v1_set_f32(struct ggml_v1_tensor * tensor, float va char * const data = tensor->data; switch (tensor->type) { + case GGML_V1_TYPE_Q4_0: + { + GGML_V1_ASSERT(false); + } break; + case GGML_V1_TYPE_Q4_1: + { + GGML_V1_ASSERT(false); + } break; case GGML_V1_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -1909,7 +2579,7 @@ struct ggml_v1_tensor * ggml_v1_set_f32(struct ggml_v1_tensor * tensor, float va } break; case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } @@ -1918,6 +2588,14 @@ struct ggml_v1_tensor * ggml_v1_set_f32(struct ggml_v1_tensor * tensor, float va int32_t ggml_v1_get_i32_1d(const struct ggml_v1_tensor * tensor, int i) { switch (tensor->type) { + case GGML_V1_TYPE_Q4_0: + { + GGML_V1_ASSERT(false); + } break; + case GGML_V1_TYPE_Q4_1: + { + GGML_V1_ASSERT(false); + } break; case GGML_V1_TYPE_I8: { GGML_V1_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -1954,6 +2632,14 @@ int32_t ggml_v1_get_i32_1d(const struct ggml_v1_tensor * tensor, int i) { void ggml_v1_set_i32_1d(const struct ggml_v1_tensor * tensor, int i, int32_t value) { switch (tensor->type) { + case GGML_V1_TYPE_Q4_0: + { + GGML_V1_ASSERT(false); + } break; + case GGML_V1_TYPE_Q4_1: + { + GGML_V1_ASSERT(false); + } break; case GGML_V1_TYPE_I8: { GGML_V1_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -1988,6 +2674,14 @@ void ggml_v1_set_i32_1d(const struct ggml_v1_tensor * tensor, int i, int32_t val float ggml_v1_get_f32_1d(const struct ggml_v1_tensor * tensor, int i) { switch (tensor->type) { + case GGML_V1_TYPE_Q4_0: + { + GGML_V1_ASSERT(false); + } break; + case GGML_V1_TYPE_Q4_1: + { + GGML_V1_ASSERT(false); + } break; case GGML_V1_TYPE_I8: { GGML_V1_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -2024,6 +2718,14 @@ float ggml_v1_get_f32_1d(const struct ggml_v1_tensor * tensor, int i) { void ggml_v1_set_f32_1d(const struct ggml_v1_tensor * tensor, int i, float value) { switch (tensor->type) { + case GGML_V1_TYPE_Q4_0: + { + GGML_V1_ASSERT(false); + } break; + case GGML_V1_TYPE_Q4_1: + { + GGML_V1_ASSERT(false); + } break; case GGML_V1_TYPE_I8: { GGML_V1_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -2114,7 +2816,7 @@ struct ggml_v1_tensor * ggml_v1_add_impl( struct 
ggml_v1_tensor * a, struct ggml_v1_tensor * b, bool inplace) { - assert(ggml_v1_are_same_shape(a, b)); + GGML_V1_ASSERT(ggml_v1_are_same_shape(a, b)); bool is_node = false; @@ -2153,7 +2855,7 @@ struct ggml_v1_tensor * ggml_v1_sub_impl( struct ggml_v1_tensor * a, struct ggml_v1_tensor * b, bool inplace) { - assert(ggml_v1_are_same_shape(a, b)); + GGML_V1_ASSERT(ggml_v1_are_same_shape(a, b)); bool is_node = false; @@ -2192,7 +2894,7 @@ struct ggml_v1_tensor * ggml_v1_mul_impl( struct ggml_v1_tensor * a, struct ggml_v1_tensor * b, bool inplace) { - assert(ggml_v1_are_same_shape(a, b)); + GGML_V1_ASSERT(ggml_v1_are_same_shape(a, b)); bool is_node = false; @@ -2201,7 +2903,7 @@ struct ggml_v1_tensor * ggml_v1_mul_impl( } if (inplace) { - assert(is_node == false); + GGML_V1_ASSERT(is_node == false); } struct ggml_v1_tensor * result = inplace ? ggml_v1_view_tensor(ctx, a) : ggml_v1_dup_tensor(ctx, a); @@ -2235,7 +2937,7 @@ struct ggml_v1_tensor * ggml_v1_div_impl( struct ggml_v1_tensor * a, struct ggml_v1_tensor * b, bool inplace) { - assert(ggml_v1_are_same_shape(a, b)); + GGML_V1_ASSERT(ggml_v1_are_same_shape(a, b)); bool is_node = false; @@ -2244,7 +2946,7 @@ struct ggml_v1_tensor * ggml_v1_div_impl( } if (inplace) { - assert(is_node == false); + GGML_V1_ASSERT(is_node == false); } struct ggml_v1_tensor * result = inplace ? ggml_v1_view_tensor(ctx, a) : ggml_v1_dup_tensor(ctx, a); @@ -2368,7 +3070,7 @@ struct ggml_v1_tensor * ggml_v1_mean( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement + GGML_V1_ASSERT(false); // TODO: implement is_node = true; } @@ -2389,7 +3091,7 @@ struct ggml_v1_tensor * ggml_v1_repeat( struct ggml_v1_context * ctx, struct ggml_v1_tensor * a, struct ggml_v1_tensor * b) { - assert(ggml_v1_can_repeat(a, b)); + GGML_V1_ASSERT(ggml_v1_can_repeat(a, b)); bool is_node = false; @@ -2625,7 +3327,7 @@ struct ggml_v1_tensor * ggml_v1_norm_impl( bool is_node = false; if (!inplace && (a->grad)) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2657,7 +3359,7 @@ struct ggml_v1_tensor * ggml_v1_mul_mat( struct ggml_v1_context * ctx, struct ggml_v1_tensor * a, struct ggml_v1_tensor * b) { - assert(ggml_v1_can_mul_mat(a, b)); + GGML_V1_ASSERT(ggml_v1_can_mul_mat(a, b)); bool is_node = false; @@ -2683,13 +3385,13 @@ struct ggml_v1_tensor * ggml_v1_scale_impl( struct ggml_v1_tensor * a, struct ggml_v1_tensor * b, bool inplace) { - assert(ggml_v1_is_scalar(b)); - assert(ggml_v1_is_padded_1d(a)); + GGML_V1_ASSERT(ggml_v1_is_scalar(b)); + GGML_V1_ASSERT(ggml_v1_is_padded_1d(a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2726,12 +3428,12 @@ struct ggml_v1_tensor * ggml_v1_cpy_impl( struct ggml_v1_tensor * a, struct ggml_v1_tensor * b, bool inplace) { - assert(ggml_v1_nelements(a) == ggml_v1_nelements(b)); + GGML_V1_ASSERT(ggml_v1_nelements(a) == ggml_v1_nelements(b)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2766,14 +3468,14 @@ struct ggml_v1_tensor * ggml_v1_reshape( struct ggml_v1_context * ctx, struct ggml_v1_tensor * a, struct ggml_v1_tensor * b) { - assert(ggml_v1_is_contiguous(a)); - assert(ggml_v1_is_contiguous(b)); - assert(ggml_v1_nelements(a) == ggml_v1_nelements(b)); + GGML_V1_ASSERT(ggml_v1_is_contiguous(a)); 
+ GGML_V1_ASSERT(ggml_v1_is_contiguous(b)); + GGML_V1_ASSERT(ggml_v1_nelements(a) == ggml_v1_nelements(b)); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2792,13 +3494,13 @@ struct ggml_v1_tensor * ggml_v1_reshape_2d( struct ggml_v1_tensor * a, int ne0, int ne1) { - assert(ggml_v1_is_contiguous(a)); - assert(ggml_v1_nelements(a) == ne0*ne1); + GGML_V1_ASSERT(ggml_v1_is_contiguous(a)); + GGML_V1_ASSERT(ggml_v1_nelements(a) == ne0*ne1); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2819,13 +3521,13 @@ struct ggml_v1_tensor * ggml_v1_reshape_3d( int ne0, int ne1, int ne2) { - assert(ggml_v1_is_contiguous(a)); - assert(ggml_v1_nelements(a) == ne0*ne1*ne2); + GGML_V1_ASSERT(ggml_v1_is_contiguous(a)); + GGML_V1_ASSERT(ggml_v1_nelements(a) == ne0*ne1*ne2); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2848,7 +3550,7 @@ struct ggml_v1_tensor * ggml_v1_view_1d( int ne0, size_t offset) { if (a->grad) { - assert(false); // gradient propagation is not supported + GGML_V1_ASSERT(false); // gradient propagation is not supported } struct ggml_v1_tensor * result = ggml_v1_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); @@ -2871,7 +3573,7 @@ struct ggml_v1_tensor * ggml_v1_view_2d( size_t nb1, size_t offset) { if (a->grad) { - assert(false); // gradient propagation is not supported + GGML_V1_ASSERT(false); // gradient propagation is not supported } const int ne[GGML_V1_MAX_DIMS] = { ne0, ne1, 1, 1 }; @@ -2899,22 +3601,22 @@ struct ggml_v1_tensor * ggml_v1_permute( int axis1, int axis2, int axis3) { - assert(axis0 >= 0 && axis0 < GGML_V1_MAX_DIMS); - assert(axis1 >= 0 && axis1 < GGML_V1_MAX_DIMS); - assert(axis2 >= 0 && axis2 < GGML_V1_MAX_DIMS); - assert(axis3 >= 0 && axis3 < GGML_V1_MAX_DIMS); + GGML_V1_ASSERT(axis0 >= 0 && axis0 < GGML_V1_MAX_DIMS); + GGML_V1_ASSERT(axis1 >= 0 && axis1 < GGML_V1_MAX_DIMS); + GGML_V1_ASSERT(axis2 >= 0 && axis2 < GGML_V1_MAX_DIMS); + GGML_V1_ASSERT(axis3 >= 0 && axis3 < GGML_V1_MAX_DIMS); - assert(axis0 != axis1); - assert(axis0 != axis2); - assert(axis0 != axis3); - assert(axis1 != axis2); - assert(axis1 != axis3); - assert(axis2 != axis3); + GGML_V1_ASSERT(axis0 != axis1); + GGML_V1_ASSERT(axis0 != axis2); + GGML_V1_ASSERT(axis0 != axis3); + GGML_V1_ASSERT(axis1 != axis2); + GGML_V1_ASSERT(axis1 != axis3); + GGML_V1_ASSERT(axis2 != axis3); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2959,7 +3661,7 @@ struct ggml_v1_tensor * ggml_v1_transpose( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -2985,12 +3687,12 @@ struct ggml_v1_tensor * ggml_v1_get_rows( struct ggml_v1_context * ctx, struct ggml_v1_tensor * a, struct ggml_v1_tensor * b) { - assert(ggml_v1_is_matrix(a) && ggml_v1_is_vector(b) && b->type == GGML_V1_TYPE_I32); + GGML_V1_ASSERT(ggml_v1_is_matrix(a) && ggml_v1_is_vector(b) && b->type == GGML_V1_TYPE_I32); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3015,7 +3717,7 @@ 
struct ggml_v1_tensor * ggml_v1_diag_mask_inf( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3040,7 +3742,7 @@ struct ggml_v1_tensor * ggml_v1_soft_max( bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3064,11 +3766,11 @@ struct ggml_v1_tensor * ggml_v1_rope( int n_past, int n_dims, int mode) { - assert(n_past >= 0); + GGML_V1_ASSERT(n_past >= 0); bool is_node = false; if (a->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3095,13 +3797,13 @@ struct ggml_v1_tensor * ggml_v1_conv_1d_1s( struct ggml_v1_context * ctx, struct ggml_v1_tensor * a, struct ggml_v1_tensor * b) { - assert(ggml_v1_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); + GGML_V1_ASSERT(ggml_v1_is_matrix(b)); + GGML_V1_ASSERT(a->ne[1] == b->ne[1]); + GGML_V1_ASSERT(a->ne[3] == 1); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3122,13 +3824,13 @@ struct ggml_v1_tensor * ggml_v1_conv_1d_2s( struct ggml_v1_context * ctx, struct ggml_v1_tensor * a, struct ggml_v1_tensor * b) { - assert(ggml_v1_is_matrix(b)); - assert(a->ne[1] == b->ne[1]); - assert(a->ne[3] == 1); + GGML_V1_ASSERT(ggml_v1_is_matrix(b)); + GGML_V1_ASSERT(a->ne[1] == b->ne[1]); + GGML_V1_ASSERT(a->ne[3] == 1); bool is_node = false; if (a->grad || b->grad) { - assert(false); // TODO: implement backward + GGML_V1_ASSERT(false); // TODO: implement backward is_node = true; } @@ -3151,7 +3853,7 @@ struct ggml_v1_tensor * ggml_v1_flash_attn( struct ggml_v1_tensor * k, struct ggml_v1_tensor * v, bool masked) { - assert(ggml_v1_can_mul_mat(k, q)); + GGML_V1_ASSERT(ggml_v1_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) bool is_node = false; @@ -3183,7 +3885,7 @@ struct ggml_v1_tensor * ggml_v1_flash_ff( struct ggml_v1_tensor * b1, struct ggml_v1_tensor * c0, struct ggml_v1_tensor * c1) { - assert(ggml_v1_can_mul_mat(b0, a)); + GGML_V1_ASSERT(ggml_v1_can_mul_mat(b0, a)); // TODO: more checks bool is_node = false; @@ -3214,7 +3916,7 @@ void ggml_v1_set_param( struct ggml_v1_tensor * tensor) { tensor->is_param = true; - assert(tensor->grad == NULL); + GGML_V1_ASSERT(tensor->grad == NULL); tensor->grad = ggml_v1_dup_tensor(ctx, tensor); } @@ -3224,9 +3926,9 @@ static void ggml_v1_compute_forward_dup_f16( const struct ggml_v1_compute_params * params, const struct ggml_v1_tensor * src0, struct ggml_v1_tensor * dst) { - assert(params->ith == 0); - assert(ggml_v1_is_contiguous(dst)); - assert(ggml_v1_nelements(dst) == ggml_v1_nelements(src0)); + GGML_V1_ASSERT(params->ith == 0); + GGML_V1_ASSERT(ggml_v1_is_contiguous(dst)); + GGML_V1_ASSERT(ggml_v1_nelements(dst) == ggml_v1_nelements(src0)); if (params->type == GGML_V1_TASK_INIT || params->type == GGML_V1_TASK_FINALIZE) { return; @@ -3441,6 +4143,8 @@ static void ggml_v1_compute_forward_dup( { ggml_v1_compute_forward_dup_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: @@ -3516,13 +4220,15 @@ static void ggml_v1_compute_forward_add( { ggml_v1_compute_forward_add_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case 
GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3566,13 +4272,15 @@ static void ggml_v1_compute_forward_sub( { ggml_v1_compute_forward_sub_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3616,13 +4324,15 @@ static void ggml_v1_compute_forward_mul( { ggml_v1_compute_forward_mul_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3666,13 +4376,15 @@ static void ggml_v1_compute_forward_div( { ggml_v1_compute_forward_div_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3712,13 +4424,15 @@ static void ggml_v1_compute_forward_sqr( { ggml_v1_compute_forward_sqr_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3758,13 +4472,15 @@ static void ggml_v1_compute_forward_sqrt( { ggml_v1_compute_forward_sqrt_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3814,13 +4530,15 @@ static void ggml_v1_compute_forward_sum( { ggml_v1_compute_forward_sum_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3889,13 +4607,15 @@ static void ggml_v1_compute_forward_mean( { ggml_v1_compute_forward_mean_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3951,13 +4671,15 @@ static void ggml_v1_compute_forward_repeat( { ggml_v1_compute_forward_repeat_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -3997,13 +4719,15 @@ static void ggml_v1_compute_forward_abs( { ggml_v1_compute_forward_abs_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -4043,13 +4767,15 @@ static void ggml_v1_compute_forward_sgn( { ggml_v1_compute_forward_sgn_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: 
case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -4089,13 +4815,15 @@ static void ggml_v1_compute_forward_neg( { ggml_v1_compute_forward_neg_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -4135,13 +4863,15 @@ static void ggml_v1_compute_forward_step( { ggml_v1_compute_forward_step_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -4181,13 +4911,15 @@ static void ggml_v1_compute_forward_relu( { ggml_v1_compute_forward_relu_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -4244,15 +4976,19 @@ static void ggml_v1_compute_forward_gelu( { ggml_v1_compute_forward_gelu_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } + + //printf("XXXXXXXX gelu\n"); } // ggml_v1_compute_forward_norm @@ -4326,13 +5062,15 @@ static void ggml_v1_compute_forward_norm( { ggml_v1_compute_forward_norm_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -4354,9 +5092,8 @@ static bool ggml_v1_compute_forward_mul_mat_use_blas( const int ne1 = dst->ne[1]; // TODO: find the optimal values for these - if (ggml_v1_is_contiguous(src0) && ggml_v1_is_contiguous(src1) && ( - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32) - )) { + if (ggml_v1_is_contiguous(src0) && + ggml_v1_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); return true; } @@ -4749,7 +5486,7 @@ static void ggml_v1_compute_forward_mul_mat_f16_f32( } } - //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_v1_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_v1_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ return; } @@ -4916,12 +5653,620 @@ static void ggml_v1_compute_forward_mul_mat_f16_f32( //} } +static void ggml_v1_compute_forward_mul_mat_q4_0_f32( + const struct ggml_v1_compute_params * params, + const struct ggml_v1_tensor * src0, + const struct ggml_v1_tensor * src1, + struct ggml_v1_tensor * dst) { + int64_t t0 = ggml_v1_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = 
src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_V1_ASSERT(ne02 == ne12); + GGML_V1_ASSERT(ne03 == ne13); + GGML_V1_ASSERT(ne2 == ne12); + GGML_V1_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_V1_ASSERT(nb00 == (int) GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0] || nb01 == (int) GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0]); + + // dst cannot be transposed or permuted + GGML_V1_ASSERT(nb0 == sizeof(float)); + GGML_V1_ASSERT(nb0 <= nb1); + GGML_V1_ASSERT(nb1 <= nb2); + GGML_V1_ASSERT(nb2 <= nb3); + + GGML_V1_ASSERT(ne0 == ne01); + GGML_V1_ASSERT(ne1 == ne11); + GGML_V1_ASSERT(ne2 == ne02); + GGML_V1_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_V1_USE_ACCELERATE) || defined(GGML_V1_USE_OPENBLAS) + if (ggml_v1_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_V1_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_V1_TASK_INIT) { + return; + } + + if (params->type == GGML_V1_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + //for (int i00 = 0; i00 < ne00; ++i00) { + // wdata[id++] = GGML_V1_FP16_TO_FP32(*(ggml_v1_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + //} + dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + id += ne00; + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } + //} + + { +#if 1 + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne00, + x, ne00, + 0.0f, d, ne01); +#else + // zT = (xT * y)T + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, + ne01, ne11, ne10, + 1.0f, x, ne00, + y, ne00, + 0.0f, d, ne01); +#endif + } + } + } + + /*printf("CBLAS Q4_0 = %f ms, %d x %d x %d x %d\n", (ggml_v1_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ + + return; + } +#endif + + if (params->type == GGML_V1_TASK_INIT) { + //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth); + if (nb01 >= nb00) { + char * wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + //for (int i10 = 0; i10 < ne10; ++i10) { + // wdata[id++] = GGML_V1_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + //} + quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += 
(ne10*GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0])/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_0]; + } + } + } + + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_V1_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + float * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + ggml_v1_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); + + for (int k = 1; k < nth; k++) { + ggml_v1_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); + } + + return; + } + + if (nb01 >= nb00) { + // TODO: do not support transposed src1 + + // parallelize by src0 rows using ggml_v1_vec_dot_q4_0 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + void * wdata = params->wdata; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0])/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_0]); + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int ic = 0; ic < ne11; ++ic) { + ggml_v1_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0])/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_0]))); + } + } + } else { + //printf("AAAAA ith = %d, nth = %d\n", ith, nth); + // parallelize by src1 columns using ggml_v1_vec_mad_q4_0 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + float * const wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + + ggml_v1_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val); + } + } + } + } + } + + //int64_t t1 = ggml_v1_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", 
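The non-BLAS path above splits work by src0 rows; the dr/ir0/ir1 computation is the standard partition used throughout this file, condensed here for reference:

// Split nr rows into nth nearly-equal chunks; thread ith handles [ir0, ir1).
static inline void thread_row_range(int nr, int nth, int ith, int *ir0, int *ir1) {
    const int dr = (nr + nth - 1) / nth;        // rows per thread, rounded up
    *ir0 = dr * ith;
    *ir1 = (*ir0 + dr < nr) ? *ir0 + dr : nr;   // MIN(ir0 + dr, nr)
}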
ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_v1_compute_forward_mul_mat_q4_1_f32( + const struct ggml_v1_compute_params * params, + const struct ggml_v1_tensor * src0, + const struct ggml_v1_tensor * src1, + struct ggml_v1_tensor * dst) { + int64_t t0 = ggml_v1_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_V1_ASSERT(ne02 == ne12); + GGML_V1_ASSERT(ne03 == ne13); + GGML_V1_ASSERT(ne2 == ne12); + GGML_V1_ASSERT(ne3 == ne13); + + // TODO: we don't support permuted src0 + GGML_V1_ASSERT(nb00 == (int) GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1] || nb01 == (int) GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1]); + + // dst cannot be transposed or permuted + GGML_V1_ASSERT(nb0 == sizeof(float)); + GGML_V1_ASSERT(nb0 <= nb1); + GGML_V1_ASSERT(nb1 <= nb2); + GGML_V1_ASSERT(nb2 <= nb3); + + GGML_V1_ASSERT(ne0 == ne01); + GGML_V1_ASSERT(ne1 == ne11); + GGML_V1_ASSERT(ne2 == ne02); + GGML_V1_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + // + // nb00 < nb01 - src0 is transposed + // compute by src0 columns + +#if defined(GGML_V1_USE_ACCELERATE) || defined(GGML_V1_USE_OPENBLAS) + if (ggml_v1_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_V1_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) { + return; + } + + if (params->type == GGML_V1_TASK_INIT) { + return; + } + + if (params->type == GGML_V1_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + //for (int i00 = 0; i00 < ne00; ++i00) { + // wdata[id++] = GGML_V1_FP16_TO_FP32(*(ggml_v1_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + //} + dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + id += ne00; + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } 
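As in the Q4_0 variant, this BLAS branch dequantizes a slice of src0 to f32 once and then performs the whole product with the single SGEMM just below. Condensed as a standalone sketch (assumes a linked CBLAS such as OpenBLAS or Accelerate):

#include <cblas.h>

// d = y * x^T in row-major storage:
// (ne11 x ne00) * (ne00 x ne01) -> (ne11 x ne01), matching the call below.
static void matmul_dequant_f32(const float *x,  // [ne01 x ne00] dequantized src0
                               const float *y,  // [ne11 x ne00] src1
                               float       *d,  // [ne11 x ne01] output
                               int ne00, int ne01, int ne11) {
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                ne11, ne01, ne00,
                1.0f, y, ne00,
                x, ne00,
                0.0f, d, ne01);
}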
+ //} + + { +#if 1 + // zT = y * xT + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne00, + x, ne00, + 0.0f, d, ne01); +#else + // zT = (xT * y)T + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, + ne01, ne11, ne10, + 1.0f, x, ne00, + y, ne00, + 0.0f, d, ne01); +#endif + } + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_v1_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_V1_TASK_INIT) { + //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth); + if (nb01 >= nb00) { + char * wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + //for (int i10 = 0; i10 < ne10; ++i10) { + // wdata[id++] = GGML_V1_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + //} + quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += (ne10*GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1])/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_1]; + } + } + } + + return; + } + + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + return; + } + + if (params->type == GGML_V1_TASK_FINALIZE) { + if (nb01 >= nb00) { + return; + } + + float * const wdata = params->wdata; + + // cols per thread + const int dc = (ne + nth - 1)/nth; + + // col range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, ne); + + ggml_v1_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); + + for (int k = 1; k < nth; k++) { + ggml_v1_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); + } + + return; + } + + if (nb01 >= nb00) { + // TODO: do not support transposed src1 + + // parallelize by src0 rows using ggml_v1_vec_dot_q4_1 + + // total rows in src0 + const int nr = ne01*ne02*ne03; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + void * wdata = params->wdata; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int i13 = i03; + const int i12 = i02; + + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1])/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_1]); + + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + for (int ic = 0; ic < ne11; ++ic) { + ggml_v1_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1])/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_1]))); + } + } + } else { + //printf("AAAAA ith = %d, nth = %d\n", ith, nth); + // parallelize by src1 columns using ggml_v1_vec_mad_q4_1 + // each thread has its own work data + // during FINALIZE we accumulate all work data into dst + + // total columns in src1 + const int nc = ne10; + + // columns per thread + const int dc = (nc + nth - 1)/nth; + + // column range for this thread + const int ic0 = dc*ith; + const int ic1 = MIN(ic0 + dc, nc); + + // work data for thread + const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; + float * const wdata = 
params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + // dst indices + const int i1 = i11; + const int i2 = i12; + const int i3 = i13; + + float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + + for (int ic = ic0; ic < ic1; ++ic) { + // src1 indices + const int i10 = ic; + + // src0 indices + const int i03 = i13; + const int i02 = i12; + const int i00 = ic; + + assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + + void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); + float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + + ggml_v1_vec_mad_q4_1(ne01, dst_row, src0_col, src1_val); + } + } + } + } + } + + //int64_t t1 = ggml_v1_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + static void ggml_v1_compute_forward_mul_mat( const struct ggml_v1_compute_params * params, const struct ggml_v1_tensor * src0, const struct ggml_v1_tensor * src1, struct ggml_v1_tensor * dst) { switch (src0->type) { + case GGML_V1_TYPE_Q4_0: + { + ggml_v1_compute_forward_mul_mat_q4_0_f32(params, src0, src1, dst); + } break; + case GGML_V1_TYPE_Q4_1: + { + ggml_v1_compute_forward_mul_mat_q4_1_f32(params, src0, src1, dst); + } break; case GGML_V1_TYPE_F16: { ggml_v1_compute_forward_mul_mat_f16_f32(params, src0, src1, dst); @@ -4935,9 +6280,37 @@ static void ggml_v1_compute_forward_mul_mat( case GGML_V1_TYPE_I32: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } + +#if 0 + if (src0->type == GGML_V1_TYPE_F16 || src0->type == GGML_V1_TYPE_Q4_1) { + static int first = 8; + printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); + printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); + printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + if (first) { + --first; + } else { + for (int k = 0; k < dst->ne[1]; ++k) { + for (int j = 0; j < dst->ne[0]/16; ++j) { + for (int i = 0; i < 16; ++i) { + printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + } + printf("\n"); + } + printf("\n"); + } + printf("\n"); + exit(0); + } + } else { + printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]); + printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]); + printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + } +#endif } // ggml_v1_compute_forward_scale @@ -4987,13 +6360,15 @@ static void ggml_v1_compute_forward_scale( { ggml_v1_compute_forward_scale_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -5051,6 +6426,60 @@ static void ggml_v1_compute_forward_transpose( // 
ggml_v1_compute_forward_get_rows +static void ggml_v1_compute_forward_get_rows_q4_0( + const struct ggml_v1_compute_params * params, + const struct ggml_v1_tensor * src0, + const struct ggml_v1_tensor * src1, + struct ggml_v1_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_V1_TASK_INIT || params->type == GGML_V1_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_v1_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0]); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + dequantize_row_q4_0( + (const void *) ((char *) src0->data + r*src0->nb[1]), + (float *) ((char *) dst->data + i*dst->nb[1]), nc); + } +} + +static void ggml_v1_compute_forward_get_rows_q4_1( + const struct ggml_v1_compute_params * params, + const struct ggml_v1_tensor * src0, + const struct ggml_v1_tensor * src1, + struct ggml_v1_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_V1_TASK_INIT || params->type == GGML_V1_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_v1_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1]); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + dequantize_row_q4_1( + (const void *) ((char *) src0->data + r*src0->nb[1]), + (float *) ((char *) dst->data + i*dst->nb[1]), nc); + } +} + static void ggml_v1_compute_forward_get_rows_f16( const struct ggml_v1_compute_params * params, const struct ggml_v1_tensor * src0, @@ -5112,6 +6541,14 @@ static void ggml_v1_compute_forward_get_rows( const struct ggml_v1_tensor * src1, struct ggml_v1_tensor * dst) { switch (src0->type) { + case GGML_V1_TYPE_Q4_0: + { + ggml_v1_compute_forward_get_rows_q4_0(params, src0, src1, dst); + } break; + case GGML_V1_TYPE_Q4_1: + { + ggml_v1_compute_forward_get_rows_q4_1(params, src0, src1, dst); + } break; case GGML_V1_TYPE_F16: { ggml_v1_compute_forward_get_rows_f16(params, src0, src1, dst); @@ -5125,9 +6562,27 @@ static void ggml_v1_compute_forward_get_rows( case GGML_V1_TYPE_I32: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} } // ggml_v1_compute_forward_diag_mask_inf @@ -5178,13 +6633,15 @@ static void ggml_v1_compute_forward_diag_mask_inf( { ggml_v1_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -5223,6 +6680,7 @@ static void ggml_v1_compute_forward_soft_max_f32( #ifndef NDEBUG for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); assert(!isnan(p[i])); } #endif @@ -5269,13 +6727,15 @@ static void ggml_v1_compute_forward_soft_max( { ggml_v1_compute_forward_soft_max_f32(params, src0, dst); } break; + case GGML_V1_TYPE_Q4_0: + case 
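Both get_rows kernels above follow the same pattern: src1 holds int32 row indices, and each selected quantized row of src0 is dequantized straight into the destination. A condensed sketch (the function-pointer indirection is illustrative only; the dispatch here is written as a switch):

#include <stddef.h>
#include <stdint.h>

typedef void (*dequantize_row_fn)(const void *x, float *y, int k);

static void get_rows_quantized(const char *src0_data, size_t src0_row_stride,
                               const int32_t *rows, int nr,
                               float *dst, int nc,
                               dequantize_row_fn dequantize_row) {
    for (int i = 0; i < nr; ++i) {
        // row r of src0 -> row i of dst, dequantized to f32
        dequantize_row(src0_data + rows[i] * src0_row_stride,
                       dst + (size_t)i * nc, nc);
    }
}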
GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -5339,23 +6799,84 @@ static void ggml_v1_compute_forward_rope_f32( } } +static void ggml_v1_compute_forward_rope_f16( + const struct ggml_v1_compute_params * params, + const struct ggml_v1_tensor * src0, + const struct ggml_v1_tensor * src1, + struct ggml_v1_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_V1_TYPE_I32); + assert(ggml_v1_nelements(src1) == 3); + + if (params->type == GGML_V1_TASK_INIT || params->type == GGML_V1_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) src1->data)[0]; + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + //const int ne0 = src0->ne[0]; + const int ne1 = src0->ne[1]; + const int ne2 = src0->ne[2]; + const int ne3 = src0->ne[3]; + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + const int nb3 = src0->nb[3]; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(ggml_v1_fp16_t)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int p = (mode == 0 ? n_past + i2 : i2); + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < n_dims; i0 += 2) { + const double theta = pow(10000.0, ((double)-i0)/n_dims); + + const double cos_theta = cos(p*theta); + const double sin_theta = sin(p*theta); + + const ggml_v1_fp16_t * const src = (ggml_v1_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + ggml_v1_fp16_t * dst_data = (ggml_v1_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + double x0 = ggml_v1_fp16_to_fp32(src[0]); + double x1 = ggml_v1_fp16_to_fp32(src[1]); + + dst_data[0] = ggml_v1_fp32_to_fp16(x0*cos_theta - x1*sin_theta); + dst_data[1] = ggml_v1_fp32_to_fp16(x0*sin_theta + x1*cos_theta); + } + } + } + } +} + static void ggml_v1_compute_forward_rope( const struct ggml_v1_compute_params * params, const struct ggml_v1_tensor * src0, const struct ggml_v1_tensor * src1, struct ggml_v1_tensor * dst) { switch (src0->type) { + case GGML_V1_TYPE_F16: + { + ggml_v1_compute_forward_rope_f16(params, src0, src1, dst); + } break; case GGML_V1_TYPE_F32: { ggml_v1_compute_forward_rope_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: - case GGML_V1_TYPE_F16: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -5616,6 +7137,8 @@ static void ggml_v1_compute_forward_conv_1d_1s( { ggml_v1_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: @@ -5882,6 +7405,8 @@ static void ggml_v1_compute_forward_conv_1d_2s( { ggml_v1_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: @@ -6365,12 +7890,14 @@ static void ggml_v1_compute_forward_flash_attn( { ggml_v1_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case 
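The new f16 RoPE path above applies the same rotation as the f32 kernel, converting each half-precision pair through f32. The per-pair math, isolated (this mirrors the theta/cos/sin expressions in the kernel):

#include <math.h>

// Rotate one (x0, x1) pair at position p for dimension index i0 of n_dims;
// theta decays geometrically with the pair index.
static void rope_pair(float *x0, float *x1, int p, int i0, int n_dims) {
    const double theta     = p * pow(10000.0, -(double)i0 / n_dims);
    const double cos_theta = cos(theta);
    const double sin_theta = sin(theta);

    const double v0 = *x0, v1 = *x1;
    *x0 = (float)(v0*cos_theta - v1*sin_theta);
    *x1 = (float)(v0*sin_theta + v1*cos_theta);
}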
GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -6574,12 +8101,14 @@ static void ggml_v1_compute_forward_flash_ff( { GGML_V1_ASSERT(false); // TODO } break; + case GGML_V1_TYPE_Q4_0: + case GGML_V1_TYPE_Q4_1: case GGML_V1_TYPE_I8: case GGML_V1_TYPE_I16: case GGML_V1_TYPE_I32: case GGML_V1_TYPE_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } @@ -6587,7 +8116,7 @@ static void ggml_v1_compute_forward_flash_ff( ///////////////////////////////// static void ggml_v1_compute_forward(struct ggml_v1_compute_params * params, struct ggml_v1_tensor * tensor) { - assert(params); + GGML_V1_ASSERT(params); switch (tensor->op) { case GGML_V1_OP_DUP: @@ -6835,7 +8364,7 @@ static void ggml_v1_compute_backward(struct ggml_v1_context * ctx, struct ggml_v } break; case GGML_V1_OP_MEAN: { - assert(false); // TODO: implement + GGML_V1_ASSERT(false); // TODO: implement } break; case GGML_V1_OP_REPEAT: { @@ -6890,17 +8419,17 @@ static void ggml_v1_compute_backward(struct ggml_v1_context * ctx, struct ggml_v } break; case GGML_V1_OP_GELU: { - assert(false); // TODO: not implemented + GGML_V1_ASSERT(false); // TODO: not implemented } break; case GGML_V1_OP_NORM: { - assert(false); // TODO: not implemented + GGML_V1_ASSERT(false); // TODO: not implemented } break; case GGML_V1_OP_MUL_MAT: { if (src0->grad) { // TODO: this requires outer product - ggml_v1_out_prod(ctx, src1, tensor->grad); - assert(false); + GGML_V1_ASSERT(false); } if (src1->grad) { src1->grad = @@ -7016,12 +8545,12 @@ static void ggml_v1_visit_parents(struct ggml_v1_cgraph * cgraph, struct ggml_v1 if (node->op == GGML_V1_OP_NONE && node->grad == NULL) { // reached a leaf node, not part of the gradient graph (e.g. a constant) - assert(cgraph->n_leafs < GGML_V1_MAX_NODES); + GGML_V1_ASSERT(cgraph->n_leafs < GGML_V1_MAX_NODES); cgraph->leafs[cgraph->n_leafs] = node; cgraph->n_leafs++; } else { - assert(cgraph->n_nodes < GGML_V1_MAX_NODES); + GGML_V1_ASSERT(cgraph->n_nodes < GGML_V1_MAX_NODES); cgraph->nodes[cgraph->n_nodes] = node; cgraph->grads[cgraph->n_nodes] = node->grad; @@ -7045,7 +8574,7 @@ static void ggml_v1_build_forward_impl(struct ggml_v1_cgraph * cgraph, struct gg if (n_new > 0) { // the last added node should always be starting point - assert(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + GGML_V1_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); } } @@ -7076,7 +8605,7 @@ struct ggml_v1_cgraph ggml_v1_build_forward(struct ggml_v1_tensor * tensor) { struct ggml_v1_cgraph ggml_v1_build_backward(struct ggml_v1_context * ctx, struct ggml_v1_cgraph * gf, bool keep) { struct ggml_v1_cgraph result = *gf; - assert(gf->n_nodes > 0); + GGML_V1_ASSERT(gf->n_nodes > 0); // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph if (keep) { @@ -7275,7 +8804,7 @@ void ggml_v1_graph_compute(struct ggml_v1_context * ctx, struct ggml_v1_cgraph * }; int rc = ggml_v1_thread_create(&workers[j].thrd, NULL, ggml_v1_graph_compute_thread, &workers[j]); - assert(rc == 0); + GGML_V1_ASSERT(rc == 0); UNUSED(rc); } } @@ -7337,6 +8866,7 @@ void ggml_v1_graph_compute(struct ggml_v1_context * ctx, struct ggml_v1_cgraph * // TODO: better way to determine if the matrix is transposed if (node->src0->nb[1] < node->src0->nb[0]) { cur = ggml_v1_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) + // TODO: overestimated by factor of x2 for FP16 } else { if (node->src0->type == GGML_V1_TYPE_F16 && node->src1->type == GGML_V1_TYPE_F32) { @@ -7344,19 +8874,43 @@ void 
ggml_v1_graph_compute(struct ggml_v1_context * ctx, struct ggml_v1_cgraph * if (ggml_v1_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning - cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); + cur = GGML_V1_TYPE_SIZE[GGML_V1_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); //printf("cur = %zu\n", cur); } else { - cur = sizeof(ggml_v1_fp16_t)*ggml_v1_nelements(node->src1); + cur = GGML_V1_TYPE_SIZE[GGML_V1_TYPE_F16]*ggml_v1_nelements(node->src1); } #else - cur = sizeof(ggml_v1_fp16_t)*ggml_v1_nelements(node->src1); + cur = GGML_V1_TYPE_SIZE[GGML_V1_TYPE_F16]*ggml_v1_nelements(node->src1); #endif } else if (node->src0->type == GGML_V1_TYPE_F32 && node->src1->type == GGML_V1_TYPE_F32) { cur = 0; + } else if (node->src0->type == GGML_V1_TYPE_Q4_0 && + node->src1->type == GGML_V1_TYPE_F32) { +#if defined(GGML_V1_USE_ACCELERATE) || defined(GGML_V1_USE_OPENBLAS) + if (ggml_v1_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = GGML_V1_TYPE_SIZE[GGML_V1_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = (GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0]*ggml_v1_nelements(node->src1))/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_0]; + } +#else + cur = (GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_0]*ggml_v1_nelements(node->src1))/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_0]; +#endif + } else if (node->src0->type == GGML_V1_TYPE_Q4_1 && + node->src1->type == GGML_V1_TYPE_F32) { +#if defined(GGML_V1_USE_ACCELERATE) || defined(GGML_V1_USE_OPENBLAS) + if (ggml_v1_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = GGML_V1_TYPE_SIZE[GGML_V1_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = (GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1]*ggml_v1_nelements(node->src1))/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_1]; + } +#else + cur = (GGML_V1_TYPE_SIZE[GGML_V1_TYPE_Q4_1]*ggml_v1_nelements(node->src1))/GGML_V1_BLCK_SIZE[GGML_V1_TYPE_Q4_1]; +#endif } else { GGML_V1_ASSERT(false); } @@ -7460,13 +9014,13 @@ void ggml_v1_graph_compute(struct ggml_v1_context * ctx, struct ggml_v1_cgraph * } break; case GGML_V1_OP_COUNT: { - assert(false); + GGML_V1_ASSERT(false); } break; } } if (cgraph->work != NULL && work_size > cgraph->work_size) { - assert(false); // TODO: better handling + GGML_V1_ASSERT(false); // TODO: better handling } if (work_size > 0 && cgraph->work == NULL) { @@ -7632,7 +9186,7 @@ void ggml_v1_graph_compute(struct ggml_v1_context * ctx, struct ggml_v1_cgraph * for (int j = 0; j < n_threads - 1; j++) { int rc = ggml_v1_thread_join(workers[j].thrd, NULL); - assert(rc == 0); + GGML_V1_ASSERT(rc == 0); UNUSED(rc); } @@ -7739,7 +9293,7 @@ void ggml_v1_graph_dump_dot(const struct ggml_v1_cgraph * gb, const struct ggml_ char color[16]; FILE * fp = fopen(filename, "w"); - assert(fp); + GGML_V1_ASSERT(fp); fprintf(fp, "digraph G {\n"); fprintf(fp, " newrank = true;\n"); @@ -7897,7 +9451,7 @@ static enum ggml_v1_opt_result ggml_v1_opt_adam( struct ggml_v1_tensor * f, struct ggml_v1_cgraph * gf, struct ggml_v1_cgraph * gb) { - assert(ggml_v1_is_scalar(f)); + GGML_V1_ASSERT(ggml_v1_is_scalar(f)); gf->n_threads = params.n_threads; gb->n_threads = params.n_threads; @@ -7911,7 +9465,7 @@ static 
enum ggml_v1_opt_result ggml_v1_opt_adam( if (gf->nodes[i]->is_param) { GGML_V1_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - assert(np < GGML_V1_MAX_PARAMS); + GGML_V1_ASSERT(np < GGML_V1_MAX_PARAMS); ps[np++] = gf->nodes[i]; nx += ggml_v1_nelements(gf->nodes[i]); @@ -8211,7 +9765,7 @@ static enum ggml_v1_opt_result ggml_v1_opt_lbfgs( if (gf->nodes[i]->is_param) { GGML_V1_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - assert(np < GGML_V1_MAX_PARAMS); + GGML_V1_ASSERT(np < GGML_V1_MAX_PARAMS); ps[np++] = gf->nodes[i]; nx += ggml_v1_nelements(gf->nodes[i]); diff --git a/otherarch/ggml_v1.h b/otherarch/ggml_v1.h index 0debe0925..f333b580e 100644 --- a/otherarch/ggml_v1.h +++ b/otherarch/ggml_v1.h @@ -198,6 +198,8 @@ struct ggml_v1_object; struct ggml_v1_context; enum ggml_v1_type { + GGML_V1_TYPE_Q4_0, + GGML_V1_TYPE_Q4_1, GGML_V1_TYPE_I8, GGML_V1_TYPE_I16, GGML_V1_TYPE_I32, @@ -326,7 +328,10 @@ void ggml_v1_print_objects(const struct ggml_v1_context * ctx); int ggml_v1_nelements(const struct ggml_v1_tensor * tensor); size_t ggml_v1_nbytes (const struct ggml_v1_tensor * tensor); -size_t ggml_v1_type_size (enum ggml_v1_type type); +int ggml_v1_blck_size (enum ggml_v1_type type); +size_t ggml_v1_type_size (enum ggml_v1_type type); // size in bytes for all elements in a block +float ggml_v1_type_sizef(enum ggml_v1_type type); // ggml_v1_type_size()/ggml_v1_blck_size() as float + size_t ggml_v1_element_size(const struct ggml_v1_tensor * tensor); struct ggml_v1_context * ggml_v1_init(struct ggml_v1_init_params params); diff --git a/otherarch/gptj_v1.cpp b/otherarch/gptj_v1.cpp index 93e7d0684..805d807bd 100644 --- a/otherarch/gptj_v1.cpp +++ b/otherarch/gptj_v1.cpp @@ -17,13 +17,15 @@ // load the model's weights from a file -bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gpt_vocab & vocab) { +ModelLoadResult legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gpt_vocab & vocab, FileFormat file_format) { printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + bool super_old_format = (file_format==FileFormat::GPTJ1); + auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); - return false; + return ModelLoadResult::FAIL; } // verify magic @@ -32,7 +34,7 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp fin.read((char *) &magic, sizeof(magic)); if (magic != 0x67676d6c) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); - return false; + return ModelLoadResult::FAIL; } } @@ -65,7 +67,7 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp if (n_vocab != model.hparams.n_vocab) { fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); - return false; + return ModelLoadResult::FAIL; } std::string word; @@ -81,9 +83,23 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp } } - // for the big tensors, we have the option to store the data in 16-bit floats + // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - const ggml_v1_type wtype = model.hparams.f16 ? 
GGML_V1_TYPE_F16 : GGML_V1_TYPE_F32; + ggml_v1_type wtype = GGML_V1_TYPE_COUNT; + switch (model.hparams.f16) { + case 0: wtype = GGML_V1_TYPE_F32; break; + case 1: wtype = GGML_V1_TYPE_F16; break; + case 2: wtype = GGML_V1_TYPE_Q4_0; break; + case 3: wtype = GGML_V1_TYPE_Q4_1; break; + default: + { + fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", + __func__, fname.c_str(), model.hparams.f16); + return ModelLoadResult::FAIL; + } + } + + const ggml_v1_type wtype2 = GGML_V1_TYPE_F32; auto & ctx = model.ctx; @@ -97,31 +113,31 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; - ctx_size += n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32); // ln_f_g - ctx_size += n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32); // ln_f_b + ctx_size += n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32); // ln_f_g + ctx_size += n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32); // ln_f_b - ctx_size += n_embd*n_vocab*ggml_v1_type_size(wtype); // wte + ctx_size += n_embd*n_vocab*ggml_v1_type_sizef(wtype); // wte - ctx_size += n_embd*n_vocab*ggml_v1_type_size(wtype); // lmh_g - ctx_size += n_vocab*ggml_v1_type_size(GGML_V1_TYPE_F32); // lmh_b + ctx_size += n_embd*n_vocab*ggml_v1_type_sizef(wtype); // lmh_g + ctx_size += n_vocab*ggml_v1_type_sizef(GGML_V1_TYPE_F32); // lmh_b - ctx_size += n_layer*(n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32)); // ln_1_g - ctx_size += n_layer*(n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32)); // ln_1_b + ctx_size += n_layer*(n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32)); // ln_1_g + ctx_size += n_layer*(n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32)); // ln_1_b - ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_size(wtype)); // c_attn_q_proj_w - ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_size(wtype)); // c_attn_k_proj_w - ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_size(wtype)); // c_attn_v_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_sizef(wtype)); // c_attn_q_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_sizef(wtype)); // c_attn_k_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_sizef(wtype)); // c_attn_v_proj_w - ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_size(wtype)); // c_attn_proj_w + ctx_size += n_layer*(n_embd*n_embd*ggml_v1_type_sizef(wtype)); // c_attn_proj_w - ctx_size += n_layer*(4*n_embd*n_embd*ggml_v1_type_size(wtype)); // c_mlp_fc_w - ctx_size += n_layer*( 4*n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32)); // c_mlp_fc_b + ctx_size += n_layer*(4*n_embd*n_embd*ggml_v1_type_sizef(wtype)); // c_mlp_fc_w + ctx_size += n_layer*( 4*n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32)); // c_mlp_fc_b - ctx_size += n_layer*(4*n_embd*n_embd*ggml_v1_type_size(wtype)); // c_mlp_proj_w_trans - ctx_size += n_layer*( n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32)); // c_mlp_proj_b + ctx_size += n_layer*(4*n_embd*n_embd*ggml_v1_type_sizef(wtype)); // c_mlp_proj_w_trans + ctx_size += n_layer*( n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32)); // c_mlp_proj_b - ctx_size += n_ctx*n_layer*n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32); // memory_k - ctx_size += n_ctx*n_layer*n_embd*ggml_v1_type_size(GGML_V1_TYPE_F32); // memory_v + ctx_size += n_ctx*n_layer*n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32); // memory_k + ctx_size += n_ctx*n_layer*n_embd*ggml_v1_type_sizef(GGML_V1_TYPE_F32); // memory_v ctx_size += (5 + 10*n_layer)*256; // object overhead @@ -138,7 +154,7 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp model.ctx = ggml_v1_init(params); if 
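The ctx_size accounting above switches from ggml_v1_type_size() to the new ggml_v1_type_sizef(), which is the average bytes per element (type_size / blck_size as a float), so the budget stays exact for block-quantized types. A sketch with worked numbers (the 20- and 24-byte block sizes are assumptions consistent with the Q4_0/Q4_1 layouts of this era):

// Average bytes per element, matching the new ggml_v1_type_sizef():
static float type_sizef(size_t type_size, size_t blck_size) {
    return (float) type_size / (float) blck_size;
}
// F32 -> 4.0, F16 -> 2.0, Q4_0 -> 20/32 = 0.625, Q4_1 -> 24/32 = 0.75
// bytes/element, so a 4096 x 50400 GPT-J wte tensor budgets at about
// 123 MiB in Q4_0 versus about 787 MiB in f32.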
(!model.ctx) { fprintf(stderr, "%s: ggml_v1_init() failed\n", __func__); - return false; + return ModelLoadResult::FAIL; } } @@ -181,8 +197,15 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp layer.c_attn_v_proj_w = ggml_v1_new_tensor_2d(ctx, wtype, n_embd, n_embd); layer.c_attn_proj_w = ggml_v1_new_tensor_2d(ctx, wtype, n_embd, n_embd); - - layer.c_mlp_fc_w = ggml_v1_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); + + if(super_old_format) + { + layer.c_mlp_fc_w = ggml_v1_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); + } + else + { + layer.c_mlp_fc_w = ggml_v1_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd); + } layer.c_mlp_fc_b = ggml_v1_new_tensor_1d(ctx, GGML_V1_TYPE_F32, 4*n_embd); layer.c_mlp_proj_w_trans = ggml_v1_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd); @@ -257,27 +280,55 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp if (model.tensors.find(name.data()) == model.tensors.end()) { fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); - return false; + return ModelLoadResult::FAIL; } auto tensor = model.tensors[name.data()]; if (ggml_v1_nelements(tensor) != nelements) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - return false; + return ModelLoadResult::FAIL; } - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", - __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); - return false; + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) + { + //test for transposition and retry older loader + if(tensor->ne[0]==ne[1] && tensor->ne[1]==ne[0] && should_transpose_layer(name)) + { + printf("\nFound a transposed tensor. This could be an older model. Retrying load..."); + ggml_v1_free(ctx); + return ModelLoadResult::RETRY_LOAD; + } + else + { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); + return ModelLoadResult::FAIL; + } } - const size_t bpe = tensor->type == GGML_V1_TYPE_I8 ? 1 : (ftype == 0) ? 
sizeof(float) : sizeof(ggml_v1_fp16_t); + if (0) { + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_v1_nbytes(tensor)/1024.0/1024.0, ggml_v1_nbytes(tensor)); + } - if (nelements*bpe != ggml_v1_nbytes(tensor)) { + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_v1_type_size(GGML_V1_TYPE_F32); break; + case 1: bpe = ggml_v1_type_size(GGML_V1_TYPE_F16); break; + case 2: bpe = ggml_v1_type_size(GGML_V1_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_v1_type_size(GGML_V1_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return ModelLoadResult::FAIL; + } + }; + + if ((nelements*bpe)/ggml_v1_blck_size(tensor->type) != ggml_v1_nbytes(tensor)) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_v1_nbytes(tensor), nelements*bpe); - return false; + return ModelLoadResult::FAIL; } fin.read(reinterpret_cast(tensor->data), ggml_v1_nbytes(tensor)); @@ -297,7 +348,7 @@ bool legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gp fin.close(); - return true; + return ModelLoadResult::SUCCESS; } // evaluate the transformer @@ -316,7 +367,10 @@ bool legacy_gptj_eval( const int n_past, const std::vector & embd_inp, std::vector & embd_w, - size_t & mem_per_token) { + size_t & mem_per_token, + FileFormat file_format) { + + bool super_old_format = (file_format==FileFormat::GPTJ1); const int N = embd_inp.size(); const auto & hparams = model.hparams; @@ -379,9 +433,21 @@ bool legacy_gptj_eval( // self-attention { - struct ggml_v1_tensor * Qcur = ggml_v1_mul_mat(ctx0, ggml_v1_transpose(ctx0, model.layers[il].c_attn_q_proj_w), cur); - struct ggml_v1_tensor * Kcur = ggml_v1_mul_mat(ctx0, ggml_v1_transpose(ctx0, model.layers[il].c_attn_k_proj_w), cur); - struct ggml_v1_tensor * Vcur = ggml_v1_mul_mat(ctx0, ggml_v1_transpose(ctx0, model.layers[il].c_attn_v_proj_w), cur); + struct ggml_v1_tensor * Qcur; + struct ggml_v1_tensor * Kcur; + struct ggml_v1_tensor * Vcur; + if(super_old_format) + { + Qcur = ggml_v1_mul_mat(ctx0, ggml_v1_transpose(ctx0, model.layers[il].c_attn_q_proj_w), cur); + Kcur = ggml_v1_mul_mat(ctx0, ggml_v1_transpose(ctx0, model.layers[il].c_attn_k_proj_w), cur); + Vcur = ggml_v1_mul_mat(ctx0, ggml_v1_transpose(ctx0, model.layers[il].c_attn_v_proj_w), cur); + } + else + { + Qcur = ggml_v1_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur); + Kcur = ggml_v1_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur); + Vcur = ggml_v1_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur); + } // store key and value to memory if (N >= 1) { @@ -448,9 +514,18 @@ bool legacy_gptj_eval( ggml_v1_new_tensor_2d(ctx0, GGML_V1_TYPE_F32, n_embd, N)); // projection (no bias) - cur = ggml_v1_mul_mat(ctx0, - ggml_v1_transpose(ctx0, model.layers[il].c_attn_proj_w), - cur); + if(super_old_format) + { + cur = ggml_v1_mul_mat(ctx0, + ggml_v1_transpose(ctx0, model.layers[il].c_attn_proj_w), + cur); + } + else + { + cur = ggml_v1_mul_mat(ctx0, + model.layers[il].c_attn_proj_w, + cur); + } } struct ggml_v1_tensor * inpFF = cur; @@ -459,9 +534,16 @@ bool legacy_gptj_eval( // this is independent of the self-attention result, so it could be done in parallel to the self-attention { // note here we pass inpSA instead of cur - cur = ggml_v1_mul_mat(ctx0, - ggml_v1_transpose(ctx0, 
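The generalized size check above divides by the block size so it stays exact for quantized tensors. Under the same 20-byte-block assumption as earlier:

// expected bytes = nelements * bpe / blck_size
// [4096 x 4096] Q4_0: 16777216 * 20 / 32 = 10485760 bytes (10 MiB)
// [4096 x 4096] f32 : 16777216 *  4 /  1 = 67108864 bytes (64 MiB)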
model.layers[il].c_mlp_fc_w), - inpSA); + if(super_old_format) + { + cur = ggml_v1_mul_mat(ctx0, + ggml_v1_transpose(ctx0, model.layers[il].c_mlp_fc_w), + inpSA); + }else{ + cur = ggml_v1_mul_mat(ctx0, + model.layers[il].c_mlp_fc_w, + inpSA); + } cur = ggml_v1_add(ctx0, ggml_v1_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur), @@ -538,145 +620,3 @@ bool legacy_gptj_eval( return true; } -// int main(int argc, char ** argv) { -// ggml_v1_time_init(); -// const int64_t t_main_start_us = ggml_v1_time_us(); - -// gpt_params params; -// params.model = "models/gpt-j-6B/ggml-model.bin"; - -// if (utils_gpt_params_parse(argc, argv, params) == false) { -// return 1; -// } - -// if (params.seed < 0) { -// params.seed = time(NULL); -// } - -// printf("%s: seed = %d\n", __func__, params.seed); - -// std::mt19937 rng(params.seed); -// if (params.prompt.empty()) { -// if( !isatty(STDIN_FILENO) ){ -// std::string line; -// while( std::getline(std::cin, line) ){ -// params.prompt = params.prompt + "\n" + line; -// } -// } else { -// params.prompt = utils_gpt_random_prompt(rng); -// } -// } - -// int64_t t_load_us = 0; - -// gpt_vocab vocab; -// gptj_model_v1 model; - -// // load the model -// { -// const int64_t t_start_us = ggml_v1_time_us(); - -// if (!legacy_gptj_model_load(params.model, model, vocab)) { -// fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); -// return 1; -// } - -// t_load_us = ggml_v1_time_us() - t_start_us; -// } - -// int n_past = 0; - -// int64_t t_sample_us = 0; -// int64_t t_predict_us = 0; - -// std::vector logits; - -// // tokenize the prompt -// std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); - -// params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); - -// printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); -// printf("\n"); - -// std::vector embd; - -// // determine the required inference memory per token: -// size_t mem_per_token = 0; -// legacy_gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); - -// for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { -// // predict -// if (embd.size() > 0) { -// const int64_t t_start_us = ggml_v1_time_us(); - -// if (!legacy_gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { -// printf("Failed to predict\n"); -// return 1; -// } - -// t_predict_us += ggml_v1_time_us() - t_start_us; -// } - -// n_past += embd.size(); -// embd.clear(); - -// if (i >= embd_inp.size()) { -// // sample next token -// const int top_k = params.top_k; -// const float top_p = params.top_p; -// const float temp = params.temp; - -// const int n_vocab = model.hparams.n_vocab; - -// gpt_vocab::id id = 0; - -// { -// const int64_t t_start_sample_us = ggml_v1_time_us(); - -// id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); - -// t_sample_us += ggml_v1_time_us() - t_start_sample_us; -// } - -// // add it to the context -// embd.push_back(id); -// } else { -// // if here, it means we are still processing the input prompt -// for (int k = i; k < embd_inp.size(); k++) { -// embd.push_back(embd_inp[k]); -// if (embd.size() > params.n_batch) { -// break; -// } -// } -// i += embd.size() - 1; -// } - -// // display text -// for (auto id : embd) { -// printf("%s", vocab.id_to_token[id].c_str()); -// } -// fflush(stdout); - -// // end of text token -// if (embd.back() == 50256) { -// break; -// } -// } - -// // report timing -// { -// const 
int64_t t_main_end_us = ggml_v1_time_us(); - -// printf("\n\n"); -// printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); -// printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); -// printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f); -// printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past); -// printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); -// } - -// ggml_v1_free(model.ctx); - -// return 0; -// } diff --git a/otherarch/gptj_v1_main.cpp b/otherarch/gptj_v1_main.cpp new file mode 100644 index 000000000..dd7f98591 --- /dev/null +++ b/otherarch/gptj_v1_main.cpp @@ -0,0 +1,145 @@ +#include "gptj_v1.cpp" + +int main(int argc, char ** argv) { + ggml_v1_time_init(); + const int64_t t_main_start_us = ggml_v1_time_us(); + + gpt_params params; + params.model = "models/gpt-j-6B/ggml-model.bin"; + + if (utils_gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.prompt.empty()) { + if( !isatty(STDIN_FILENO) ){ + std::string line; + while( std::getline(std::cin, line) ){ + params.prompt = params.prompt + "\n" + line; + } + } else { + params.prompt = utils_gpt_random_prompt(rng); + } + } + + int64_t t_load_us = 0; + + gpt_vocab vocab; + gptj_model_v1 model; + FileFormat file_format = FileFormat::GPTJ2; + + // load the model + { + const int64_t t_start_us = ggml_v1_time_us(); + + if (legacy_gptj_model_load(params.model, model, vocab, file_format)!=ModelLoadResult::SUCCESS) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_v1_time_us() - t_start_us; + } + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); + + params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + printf("\n"); + + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + legacy_gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); + + for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ggml_v1_time_us(); + + if (!legacy_gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token,file_format)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ggml_v1_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + + const int n_vocab = model.hparams.n_vocab; + + gpt_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_v1_time_us(); + + id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + + t_sample_us += ggml_v1_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (int k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if 
(embd.size() > params.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // display text + for (auto id : embd) { + printf("%s", vocab.id_to_token[id].c_str()); + } + fflush(stdout); + + // end of text token + if (embd.back() == 50256) { + break; + } + } + + // report timing + { + const int64_t t_main_end_us = ggml_v1_time_us(); + + printf("\n\n"); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); + printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f); + printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + ggml_v1_free(model.ctx); + + return 0; +} diff --git a/otherarch/gptj_v2.cpp b/otherarch/gptj_v2.cpp index 86a03106c..e4906366a 100644 --- a/otherarch/gptj_v2.cpp +++ b/otherarch/gptj_v2.cpp @@ -14,28 +14,18 @@ #include #include -bool should_transpose_layer(std::string name) -{ - - if(name.find(".mlp.fc_in.weight")!=std::string::npos || - name.find(".attn.out_proj.weight")!=std::string::npos || - name.find(".attn.q_proj.weight")!=std::string::npos || - name.find(".attn.k_proj.weight")!=std::string::npos || - name.find(".attn.v_proj.weight")!=std::string::npos) - { - return true; - } - return false; -} +#include "model_adapter.h" + + // load the model's weights from a file -bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab) { +ModelLoadResult gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab) { printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); - return false; + return ModelLoadResult::FAIL; } // verify magic @@ -44,7 +34,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & fin.read((char *) &magic, sizeof(magic)); if (magic != 0x67676d6c) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); - return false; + return ModelLoadResult::FAIL; } } @@ -77,7 +67,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & if (n_vocab != model.hparams.n_vocab) { fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); - return false; + return ModelLoadResult::FAIL; } std::string word; @@ -105,7 +95,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & { fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n", __func__, fname.c_str(), model.hparams.f16); - return false; + return ModelLoadResult::FAIL; } } @@ -165,7 +155,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & model.ctx = ggml_init(params); if (!model.ctx) { fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; + return ModelLoadResult::FAIL; } } @@ -284,20 +274,32 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & if (model.tensors.find(name.data()) == model.tensors.end()) { fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); - return false; + return ModelLoadResult::FAIL; } auto tensor = model.tensors[name.data()]; if (ggml_nelements(tensor) != nelements) { fprintf(stderr, "%s: tensor '%s' has wrong size in 
model file\n", __func__, name.data()); - return false; + return ModelLoadResult::FAIL; } if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", - __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); - return false; + + //test for transposition and retry older loader + if(tensor->ne[0]==ne[1] && tensor->ne[1]==ne[0] && should_transpose_layer(name)) + { + printf("\nFound a transposed tensor. This could be an older model. Retrying load..."); + ggml_free(ctx); + return ModelLoadResult::RETRY_LOAD; + } + else + { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n", + __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]); + return ModelLoadResult::FAIL; + } + } if (0) { @@ -315,14 +317,14 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & default: { fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); - return false; + return ModelLoadResult::FAIL; } }; if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); - return false; + return ModelLoadResult::FAIL; } fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); @@ -342,7 +344,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & fin.close(); - return true; + return ModelLoadResult::SUCCESS; } // evaluate the transformer @@ -584,146 +586,3 @@ bool gptj_eval( return true; } - -// int main(int argc, char ** argv) { -// ggml_time_init(); -// const int64_t t_main_start_us = ggml_time_us(); - -// gpt_params params; -// params.model = "models/gpt-j-6B/ggml-model.bin"; - -// if (utils_gpt_params_parse(argc, argv, params) == false) { -// return 1; -// } - -// if (params.seed < 0) { -// params.seed = time(NULL); -// } - -// printf("%s: seed = %d\n", __func__, params.seed); - -// std::mt19937 rng(params.seed); -// if (params.prompt.empty()) { -// if( !isatty(STDIN_FILENO) ){ -// std::string line; -// while( std::getline(std::cin, line) ){ -// params.prompt = params.prompt + "\n" + line; -// } -// } else { -// params.prompt = utils_gpt_random_prompt(rng); -// } -// } - -// int64_t t_load_us = 0; - -// gpt_vocab vocab; -// gptj_model model; - -// // load the model -// { -// const int64_t t_start_us = ggml_time_us(); - -// if (!gptj_model_load(params.model, model, vocab)) { -// fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); -// return 1; -// } - -// t_load_us = ggml_time_us() - t_start_us; -// } - -// int n_past = 0; - -// int64_t t_sample_us = 0; -// int64_t t_predict_us = 0; - -// std::vector logits; - -// // tokenize the prompt -// std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); - -// params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); - -// printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); -// printf("\n"); - -// std::vector embd; - -// // determine the required inference memory per token: -// size_t mem_per_token = 0; -// gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); - -// for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { -// // predict -// if (embd.size() > 0) { -// const int64_t t_start_us = ggml_time_us(); - -// if (!gptj_eval(model, 
params.n_threads, n_past, embd, logits, mem_per_token)) { -// printf("Failed to predict\n"); -// return 1; -// } - -// t_predict_us += ggml_time_us() - t_start_us; -// } - -// n_past += embd.size(); -// embd.clear(); - -// if (i >= embd_inp.size()) { -// // sample next token -// const int top_k = params.top_k; -// const float top_p = params.top_p; -// const float temp = params.temp; - -// const int n_vocab = model.hparams.n_vocab; - -// gpt_vocab::id id = 0; - -// { -// const int64_t t_start_sample_us = ggml_time_us(); - -// id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); - -// t_sample_us += ggml_time_us() - t_start_sample_us; -// } - -// // add it to the context -// embd.push_back(id); -// } else { -// // if here, it means we are still processing the input prompt -// for (int k = i; k < embd_inp.size(); k++) { -// embd.push_back(embd_inp[k]); -// if (embd.size() > params.n_batch) { -// break; -// } -// } -// i += embd.size() - 1; -// } - -// // display text -// for (auto id : embd) { -// printf("%s", vocab.id_to_token[id].c_str()); -// } -// fflush(stdout); - -// // end of text token -// if (embd.back() == 50256) { -// break; -// } -// } - -// // report timing -// { -// const int64_t t_main_end_us = ggml_time_us(); - -// printf("\n\n"); -// printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); -// printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); -// printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f); -// printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past); -// printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); -// } - -// ggml_free(model.ctx); - -// return 0; -// } \ No newline at end of file diff --git a/otherarch/gptj_v2_main.cpp b/otherarch/gptj_v2_main.cpp new file mode 100644 index 000000000..87346476b --- /dev/null +++ b/otherarch/gptj_v2_main.cpp @@ -0,0 +1,145 @@ +#include "gptj_v2.cpp" + + +int main(int argc, char ** argv) { + ggml_time_init(); + const int64_t t_main_start_us = ggml_time_us(); + + gpt_params params; + params.model = "models/gpt-j-6B/ggml-model.bin"; + + if (utils_gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.prompt.empty()) { + if( !isatty(STDIN_FILENO) ){ + std::string line; + while( std::getline(std::cin, line) ){ + params.prompt = params.prompt + "\n" + line; + } + } else { + params.prompt = utils_gpt_random_prompt(rng); + } + } + + int64_t t_load_us = 0; + + gpt_vocab vocab; + gptj_model model; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (gptj_model_load(params.model, model, vocab)==ModelLoadResult::FAIL) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + } + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); + + params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + printf("\n"); + + std::vector embd; + + // determine the required inference memory per token: + size_t 
+    gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() > params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 50256) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
\ No newline at end of file
diff --git a/otherarch/otherarch.h b/otherarch/otherarch.h
index 21e7103bb..d07257862 100644
--- a/otherarch/otherarch.h
+++ b/otherarch/otherarch.h
@@ -12,7 +12,7 @@
 #include
 #include "utils.h"
 
-
+#include "model_adapter.h"
 
 // default hparams (GPT-J 6B)
@@ -113,7 +113,7 @@ struct gptj_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
-bool legacy_gptj_model_load(const std::string &fname, gptj_model_v1 &model, gpt_vocab &vocab);
-bool legacy_gptj_eval(const gptj_model_v1 &model, const int n_threads, const int n_past, const std::vector<gpt_vocab::id> &embd_inp, std::vector<float> &embd_w, size_t &mem_per_token);
-bool gptj_model_load(const std::string &fname, gptj_model &model, gpt_vocab &vocab);
+ModelLoadResult legacy_gptj_model_load(const std::string &fname, gptj_model_v1 &model, gpt_vocab &vocab, FileFormat file_format);
+bool legacy_gptj_eval(const gptj_model_v1 &model, const int n_threads, const int n_past, const std::vector<gpt_vocab::id> &embd_inp, std::vector<float> &embd_w, size_t &mem_per_token, FileFormat file_format);
+ModelLoadResult gptj_model_load(const std::string &fname, gptj_model &model, gpt_vocab &vocab);
 bool gptj_eval(const gptj_model &model, const int n_threads, const int n_past, const std::vector<gpt_vocab::id> &embd_inp, std::vector<float> &embd_w, size_t &mem_per_token);
diff --git a/otherarch/utils.cpp b/otherarch/utils.cpp
index 9afdaaf86..352bac182 100644
--- a/otherarch/utils.cpp
+++ b/otherarch/utils.cpp
@@ -249,6 +249,103 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
     return true;
 }
 
+void gptj_sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
+    // find the top K tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+}
+
+gpt_vocab::id gptj_sample_top_p_top_k(
+        const gpt_vocab & vocab,
+        const float * logits,
+        std::vector<int> & last_n_tokens,
+        double repeat_penalty,
+        int top_k,
+        double top_p,
+        double temp,
+        std::mt19937 & rng) {
+    int n_logits = vocab.id_to_token.size();
+
+    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const double scale = 1.0/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+                if (logits[i] < 0.0) {
+                    logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(logits[i]*scale, i));
+            }
+        }
+    }
+
+    gptj_sample_top_k(logits_id, top_k);
+
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                probs.resize(i + 1);
+                logits_id.resize(i + 1);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+    //printf("\n");
+    //for (int i = 0; i < (int) 10; i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //}
+    //printf("\n\n");
+    //exit(0);
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+}
+
 gpt_vocab::id gpt_sample_top_k_top_p(
         const gpt_vocab & vocab,
         const float * logits,
@@ -327,4 +424,18 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     int idx = dist(rng);
 
     return logits_id[idx].second;
+}
+
+static bool should_transpose_layer(std::string name)
+{
+
+    if(name.find(".mlp.fc_in.weight")!=std::string::npos ||
+       name.find(".attn.out_proj.weight")!=std::string::npos ||
+       name.find(".attn.q_proj.weight")!=std::string::npos ||
+       name.find(".attn.k_proj.weight")!=std::string::npos ||
+       name.find(".attn.v_proj.weight")!=std::string::npos)
+    {
+        return true;
+    }
+    return false;
 }
\ No newline at end of file
diff --git a/otherarch/utils.h b/otherarch/utils.h
index 9248868bf..479a364c8 100644
--- a/otherarch/utils.h
+++ b/otherarch/utils.h
@@ -62,6 +62,18 @@ gpt_vocab::id gpt_sample_top_k_top_p(
         double temp,
         std::mt19937 & rng);
 
+gpt_vocab::id gptj_sample_top_p_top_k(
+        const gpt_vocab & vocab,
+        const float * logits,
+        std::vector<int> & last_n_tokens,
+        double repeat_penalty,
+        int top_k,
+        double top_p,
+        double temp,
+        std::mt19937 & rng);
+
 bool utils_gpt_params_parse(int argc, char ** argv, gpt_params & params);
 void utils_gpt_print_usage(int argc, char ** argv, const gpt_params & params);
 std::string utils_gpt_random_prompt(std::mt19937 & rng);
+
+static bool should_transpose_layer(std::string name);
\ No newline at end of file
diff --git a/quantize.exe b/quantize.exe
index 744551ffb..5f490498b 100644
Binary files a/quantize.exe and b/quantize.exe differ
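
A note on the retry mechanism this patch introduces. It works in three parts: should_transpose_layer() flags the five GPT-J weight matrices whose storage layout changed between format versions, gptj_model_load() returns ModelLoadResult::RETRY_LOAD instead of failing hard when it finds one of them with swapped dimensions, and load_model() in expose.cpp then bumps the assumed FileFormat and invokes the loader again. The sketch below is illustrative only, not code from the patch: the try_load() stub is hypothetical, and the enums mirror what model_adapter.h is assumed to declare.

#include <cstdio>

// Mirrors of the enums assumed to live in model_adapter.h (not shown in this patch).
enum class ModelLoadResult { FAIL, SUCCESS, RETRY_LOAD };
enum class FileFormat { BADFORMAT, GPTJ1, GPTJ2, GPTJ3 };

// Hypothetical stand-in for gptj_load_model(inputs, file_format).
static ModelLoadResult try_load(FileFormat fmt) {
    // Pretend only the GPTJ3 loader matches this model's tensor layout.
    return fmt == FileFormat::GPTJ3 ? ModelLoadResult::SUCCESS
                                    : ModelLoadResult::RETRY_LOAD;
}

int main() {
    FileFormat fmt = FileFormat::GPTJ1;
    ModelLoadResult lr = try_load(fmt);
    // A transposed tensor means the format guess was wrong: advance to the
    // next version and retry instead of failing outright.
    if (lr == ModelLoadResult::RETRY_LOAD) { fmt = FileFormat::GPTJ2; lr = try_load(fmt); }
    if (lr == ModelLoadResult::RETRY_LOAD) { fmt = FileFormat::GPTJ3; lr = try_load(fmt); }
    printf("loaded: %s\n", lr == ModelLoadResult::SUCCESS ? "yes" : "no");
    return lr == ModelLoadResult::SUCCESS ? 0 : 1;
}

Note that gptj_model_load() frees the ggml context before returning RETRY_LOAD; this matters because the retried loader allocates a fresh context of its own.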
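Likewise, a hedged usage sketch for the new sampler. gptj_sample_top_p_top_k() scales logits by 1/temp, applies the CTRL-style repetition penalty to any id present in last_n_tokens (multiplying negative scores and dividing positive ones, so the score always moves toward less likely), keeps the top_k best, renormalizes, truncates at cumulative probability top_p, and draws from the result. The helper below shows one generation step under the same conventions as the loop in gptj_v2_main.cpp; the helper name and parameter values are illustrative, not defaults from the patch.

#include <random>
#include <vector>
#include "otherarch/utils.h"  // gptj_sample_top_p_top_k, gpt_vocab

// One sampling step: pick the next token id from the newest row of logits.
gpt_vocab::id sample_next(const gpt_vocab & vocab,
                          const std::vector<float> & logits, int n_vocab,
                          std::vector<int> & last_n_tokens,
                          std::mt19937 & rng) {
    // Only the final n_vocab entries belong to the last evaluated position.
    const float * tail = logits.data() + (logits.size() - n_vocab);
    gpt_vocab::id id = gptj_sample_top_p_top_k(
            vocab, tail, last_n_tokens,
            /*repeat_penalty=*/1.1, /*top_k=*/40,
            /*top_p=*/0.9, /*temp=*/0.8, rng);
    last_n_tokens.push_back(id);  // feeds the repetition penalty next step
    return id;
}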