From 465263d0cf1e8f8bc41948332dbd009d27a68590 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Wed, 8 May 2024 14:29:23 +0000 Subject: [PATCH 1/2] sgemm : AVX Q4_0 and Q8_0 (#6891) * basic avx implementation * style * combine denibble with load * reduce 256 to 128 (and back!) conversions * sse load * Update sgemm.cpp * oops oops --- sgemm.cpp | 77 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 21 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index 4e0159804..40ba9d7e9 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -1,6 +1,3 @@ -// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- -// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi -// // Copyright 2024 Mozilla Foundation // // Permission is hereby granted, free of charge, to any person obtaining @@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM { }; #endif // __ARM_FEATURE_DOTPROD -#if defined(__AVX2__) || defined(__AVX512F__) +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) template -class tinyBLAS_Q0_AVX2 { +class tinyBLAS_Q0_AVX { public: - tinyBLAS_Q0_AVX2(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, - int ith, int nth) + tinyBLAS_Q0_AVX(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 { __m256 Cv[RN][RM] = {}; for (int64_t l = 0; l < k; ++l) for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) + for (int64_t i = 0; i < RM; ++i) { +#if defined(__AVX2__) + __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))); +#else + __m128i ali0 = load0(A + lda * (ii + i) + l); + __m128i ali1 = load1(A + lda * (ii + i) + l); + __m128i blj0 = load0(B + ldb * (jj + j) + l); + __m128i blj1 = load1(B + ldb * (jj + j) + l); + + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); + __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); + __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); + __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); + + // updot + const __m128i oneFill = _mm_set1_epi16(1); + __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); + __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); + __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); +#endif Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))), - Cv[j][i]); + udTmp, + Cv[j][i]); + } for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); @@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 { return _mm256_loadu_si256((const __m256i *)b->qs); } + inline __m128i load0(const block_q8_0 *b) { + return _mm_loadu_si128((const __m128i *)b->qs); + } + + inline __m128i load1(const block_q8_0 *b) { + return _mm_loadu_si128(((const __m128i *)b->qs) + 1); + } + inline __m256i load(const block_q4_0 *b) { return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); } + inline __m128i load0(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); + } + + inline __m128i load1(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); + } + inline __m256 updot(__m256i u, __m256i s) { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) @@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 { const int ith; const int nth; }; -#endif // __AVX2__ +#endif // __AVX__ } // namespace @@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda case GGML_TYPE_Q8_0: { if (Btype != GGML_TYPE_Q8_0) return false; -#if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda case GGML_TYPE_Q4_0: { if (Btype != GGML_TYPE_Q8_0) return false; -#if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, From 83330d8cd6491e53e1aca4c5dfc47e039b3c04ff Mon Sep 17 00:00:00 2001 From: Dawid Potocki Date: Thu, 9 May 2024 02:32:32 +1200 Subject: [PATCH 2/2] main : add --conversation / -cnv flag (#7108) --- common/common.cpp | 5 +++++ common/common.h | 1 + examples/main/main.cpp | 11 +++++++---- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 467fb014e..4a9da284e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -911,6 +911,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.instruct = true; return true; } + if (arg == "-cnv" || arg == "--conversation") { + params.conversation = true; + return true; + } if (arg == "-cml" || arg == "--chatml") { params.chatml = true; return true; @@ -1417,6 +1421,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --version show version and build info\n"); printf(" -i, --interactive run in interactive mode\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n"); + printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); diff --git a/common/common.h b/common/common.h index 9252a4b63..6f00a2cca 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,7 @@ struct gpt_params { bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode + bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool chatml = false; // chatml mode (used for models trained on chatml syntax) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it diff --git a/examples/main/main.cpp b/examples/main/main.cpp index f676ea1ba..49acd6bab 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -362,6 +362,9 @@ int main(int argc, char ** argv) { params.interactive_first = true; params.antiprompt.emplace_back("<|im_start|>user\n"); } + else if (params.conversation) { + params.interactive_first = true; + } // enable interactive mode if interactive start is specified if (params.interactive_first) { @@ -733,7 +736,7 @@ int main(int argc, char ** argv) { // display text if (input_echo && display) { for (auto id : embd) { - const std::string token_str = llama_token_to_piece(ctx, id); + const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation); printf("%s", token_str.c_str()); if (embd.size() > 1) { @@ -816,7 +819,7 @@ int main(int argc, char ** argv) { if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); - if (params.instruct || params.chatml) { + if (params.conversation || params.instruct || params.chatml) { printf("\n> "); } @@ -826,7 +829,7 @@ int main(int argc, char ** argv) { } std::string buffer; - if (!params.input_prefix.empty()) { + if (!params.input_prefix.empty() && !params.conversation) { LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); printf("%s", params.input_prefix.c_str()); } @@ -850,7 +853,7 @@ int main(int argc, char ** argv) { // Entering a empty line lets the user pass control back if (buffer.length() > 1) { // append input suffix if any - if (!params.input_suffix.empty()) { + if (!params.input_suffix.empty() && !params.conversation) { LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); printf("%s", params.input_suffix.c_str()); }