diff --git a/README.md b/README.md index 085f19e03..c6fb427e2 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,11 @@ [![Actions Status](https://github.com/ggerganov/llama.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/llama.cpp/actions) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) -Inference of [Facebook's LLaMA](https://github.com/facebookresearch/llama) model in pure C/C++ +Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ **Hot topics:** +- RMSNorm implementation / fixes: https://github.com/ggerganov/llama.cpp/issues/173 - Cache input prompts for faster initialization: https://github.com/ggerganov/llama.cpp/issues/64 - Create a `llama.cpp` logo: https://github.com/ggerganov/llama.cpp/issues/105 @@ -177,20 +178,38 @@ Note the use of `--color` to distinguish between user input and generated text. ![image](https://user-images.githubusercontent.com/1991296/224575029-2af3c7dc-5a65-4f64-a6bb-517a532aea38.png) +### Android + +You can easily run `llama.cpp` on Android device with [termux](https://play.google.com/store/apps/details?id=com.termux). +First, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake: +``` +$ mkdir build-android +$ cd build-android +$ export NDK= +$ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod .. +$ make +``` +Install [termux](https://play.google.com/store/apps/details?id=com.termux) on your device and run `termux-setup-storage` to get access to your SD card. +Finally, copy the `llama` binary and the model files to your device storage. Here is a demo of an interactive session running on Pixel 5 phone: + +https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4 + + ## Limitations - We don't know yet how much the quantization affects the quality of the generated text - Probably the token sampling can be improved - The Accelerate framework is actually currently unused since I found that for tensor shapes typical for the Decoder, - there is no benefit compared to the ARM_NEON intrinsics implementation. Of course, it's possible that I simlpy don't + there is no benefit compared to the ARM_NEON intrinsics implementation. Of course, it's possible that I simply don't know how to utilize it properly. But in any case, you can even disable it with `LLAMA_NO_ACCELERATE=1 make` and the performance will be the same, since no BLAS calls are invoked by the current implementation ### Contributing - Contributors can open PRs -- Collaborators can push to branches in the `llama.cpp` repo +- Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch - Collaborators will be invited based on contributions +- Any help with managing issues and PRs is very appreciated! ### Coding guidelines diff --git a/convert-pth-to-ggml.py b/convert-pth-to-ggml.py index d2557500a..5c36e9c09 100644 --- a/convert-pth-to-ggml.py +++ b/convert-pth-to-ggml.py @@ -99,7 +99,7 @@ for p in range(n_parts): fout.write(struct.pack("i", ftype)) # Is this correct?? - for i in range(32000): + for i in range(tokenizer.vocab_size()): if tokenizer.is_unknown(i): # "" token (translated as ??) text = " \u2047 ".encode("utf-8") diff --git a/ggml.c b/ggml.c index 58a4c9b6d..535c7b7d2 100644 --- a/ggml.c +++ b/ggml.c @@ -364,7 +364,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); #if __AVX2__ // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -inline __m256i bytesFromNibbles( const uint8_t* rsi ) +static inline __m256i bytesFromNibbles( const uint8_t* rsi ) { // Load 16 bytes from memory __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi ); @@ -381,7 +381,7 @@ inline __m256i bytesFromNibbles( const uint8_t* rsi ) return bytes; } -inline __m128i packNibbles( __m256i bytes ) +static inline __m128i packNibbles( __m256i bytes ) { // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh const __m256i lowByte = _mm256_set1_epi16( 0xFF ); @@ -1359,8 +1359,8 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); +#if defined(__ARM_FEATURE_DOTPROD) // dot product into int16x8_t - // assume that vdotq_s32 is always available, if not, should check for __ARM_FEATURE_DOTPROD int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls); int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls); @@ -1374,6 +1374,37 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void #else sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3)); sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3)); +#endif +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + + const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); + const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); + + const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); + const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); + + const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); + const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); + + // scalar +#if defined(__ARM_FEATURE_QRDMX) + sum0 += d0_0*d1_0*vaddvq_s16(p_0); + sum1 += d0_1*d1_1*vaddvq_s16(p_1); +#else + sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); + sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); +#endif #endif } @@ -2038,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "GELU", "SILU", "NORM", + "RMS_NORM", "MUL_MAT", @@ -2058,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; -static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34"); +static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2081,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gelu(x)", "silu(x)", "norm(x)", + "rms_norm(x)", "X*Y", @@ -2101,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; -static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34"); +static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); // // ggml object @@ -3587,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace( return ggml_norm_impl(ctx, a, true); } +struct ggml_tensor * ggml_rms_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_RMS_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_rms_norm_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_rms_norm_impl(ctx, a, true); +} + // ggml_mul_mat struct ggml_tensor * ggml_mul_mat( @@ -5375,6 +5441,87 @@ static void ggml_compute_forward_norm( } } +static void ggml_compute_forward_rms_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const ggml_float eps = 1e-5f; // TODO: make this a parameter + + // TODO: optimize + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float mean = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + mean += x[i00] * x[i00]; + } + + mean /= ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0/sqrt(mean + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_rms_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, src0, dst); + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + + // ggml_compute_forward_mul_mat #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) @@ -8491,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_norm(params, tensor->src0, tensor); } break; + case GGML_OP_RMS_NORM: + { + ggml_compute_forward_rms_norm(params, tensor->src0, tensor); + } break; case GGML_OP_MUL_MAT: { ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); @@ -8733,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_MUL_MAT: { if (src0->grad) { @@ -9159,6 +9314,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) node->n_tasks = n_threads; } break; case GGML_OP_NORM: + case GGML_OP_RMS_NORM: { node->n_tasks = n_threads; } break; diff --git a/ggml.h b/ggml.h index 7ce655c1b..bac4fe65c 100644 --- a/ggml.h +++ b/ggml.h @@ -230,6 +230,7 @@ enum ggml_op { GGML_OP_GELU, GGML_OP_SILU, GGML_OP_NORM, // normalize + GGML_OP_RMS_NORM, GGML_OP_MUL_MAT, @@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm( struct ggml_context * ctx, struct ggml_tensor * a); +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a); + // A: m rows, n columns // B: p rows, n columns (i.e. we transpose it internally) // result is m columns, p rows diff --git a/main.cpp b/main.cpp index 6dc9ae980..ca0fca8b3 100644 --- a/main.cpp +++ b/main.cpp @@ -14,6 +14,8 @@ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include #include +#elif defined (_WIN32) +#include #endif #define ANSI_COLOR_RED "\x1b[31m" @@ -547,6 +549,8 @@ bool llama_eval( const int d_key = n_embd/n_head; + // TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case + // static size_t buf_size = hparams.n_ctx*1024*1024; static size_t buf_size = 512u*1024*1024; static void * buf = malloc(buf_size); @@ -584,7 +588,7 @@ bool llama_eval( // norm { - cur = ggml_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL); // cur = attention_norm*cur cur = ggml_mul(ctx0, @@ -674,7 +678,7 @@ bool llama_eval( { // norm { - cur = ggml_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF); // cur = ffn_norm*cur cur = ggml_mul(ctx0, @@ -709,7 +713,7 @@ bool llama_eval( // norm { - inpL = ggml_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL); // inpL = norm*inpL inpL = ggml_mul(ctx0, @@ -753,8 +757,9 @@ bool llama_eval( static bool is_interacting = false; -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { + printf(ANSI_COLOR_RESET); if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; @@ -818,8 +823,7 @@ int main(int argc, char ** argv) { // load the model { const int64_t t_start_us = ggml_time_us(); - - if (!llama_model_load(params.model, model, vocab, 512)) { // TODO: set context from user input ?? + if (!llama_model_load(params.model, model, vocab, params.n_ctx)) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return 1; } @@ -863,6 +867,8 @@ int main(int argc, char ** argv) { sigemptyset (&sigint_action.sa_mask); sigint_action.sa_flags = 0; sigaction(SIGINT, &sigint_action, NULL); +#elif defined (_WIN32) + signal(SIGINT, sigint_handler); #endif fprintf(stderr, "%s: interactive mode on.\n", __func__); @@ -892,7 +898,7 @@ int main(int argc, char ** argv) { if (params.interactive) { fprintf(stderr, "== Running in interactive mode. ==\n" -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif " - Press Return to return control to LLaMa.\n" @@ -1037,6 +1043,9 @@ int main(int argc, char ** argv) { } } +#if defined (_WIN32) + signal(SIGINT, SIG_DFL); +#endif // report timing { @@ -1052,5 +1061,9 @@ int main(int argc, char ** argv) { ggml_free(model.ctx); + if (params.use_color) { + printf(ANSI_COLOR_RESET); + } + return 0; } diff --git a/models/.gitignore b/models/.gitignore deleted file mode 100644 index e69de29bb..000000000 diff --git a/utils.cpp b/utils.cpp index 54217f02f..aa3ad1053 100644 --- a/utils.cpp +++ b/utils.cpp @@ -37,6 +37,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.n_predict = std::stoi(argv[++i]); } else if (arg == "--top_k") { params.top_k = std::stoi(argv[++i]); + } else if (arg == "-c" || arg == "--ctx_size") { + params.n_ctx = std::stoi(argv[++i]); } else if (arg == "--top_p") { params.top_p = std::stof(argv[++i]); } else if (arg == "--temp") { @@ -92,6 +94,7 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); + fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " -m FNAME, --model FNAME\n"); diff --git a/utils.h b/utils.h index 4f98011cf..021120b05 100644 --- a/utils.h +++ b/utils.h @@ -17,7 +17,8 @@ struct gpt_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_predict = 128; // new tokens to predict int32_t repeat_last_n = 64; // last n tokens to penalize - + int32_t n_ctx = 512; //context size + // sampling parameters int32_t top_k = 40; float top_p = 0.95f;