Merge branch 'ggerganov:master' into fix-cmake-pthread

This commit is contained in:
mmyjona 2023-03-16 10:25:05 +08:00 committed by GitHub
commit 610719ecc8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 226 additions and 22 deletions

View file

@ -3,14 +3,19 @@
[![Actions Status](https://github.com/ggerganov/llama.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/llama.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
Inference of [Facebook's LLaMA](https://github.com/facebookresearch/llama) model in pure C/C++
Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
**Hot topics:**
- Cache input prompts for faster initialization: https://github.com/ggerganov/llama.cpp/issues/64
- Create a `llama.cpp` logo: https://github.com/ggerganov/llama.cpp/issues/105
## Description
The main goal is to run the model using 4-bit quantization on a MacBook
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
- Apple silicon first-class citizen - optimized via ARM NEON
- AVX2 support for x86 architectures
- Mixed F16 / F32 precision
- 4-bit quantization support
@ -172,12 +177,29 @@ Note the use of `--color` to distinguish between user input and generated text.
![image](https://user-images.githubusercontent.com/1991296/224575029-2af3c7dc-5a65-4f64-a6bb-517a532aea38.png)
### Android
You can easily run `llama.cpp` on Android device with [termux](https://play.google.com/store/apps/details?id=com.termux).
First, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
```
$ mkdir build-android
$ cd build-android
$ export NDK=<your_ndk_directory>
$ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
$ make
```
Install [termux](https://play.google.com/store/apps/details?id=com.termux) on your device and run `termux-setup-storage` to get access to your SD card.
Finally, copy the `llama` binary and the model files to your device storage. Here is a demo of an interactive session running on Pixel 5 phone:
https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
## Limitations
- I don't know yet how much the quantization affects the quality of the generated text
- We don't know yet how much the quantization affects the quality of the generated text
- Probably the token sampling can be improved
- The Accelerate framework is actually currently unused since I found that for tensor shapes typical for the Decoder,
there is no benefit compared to the ARM_NEON intrinsics implementation. Of course, it's possible that I simlpy don't
there is no benefit compared to the ARM_NEON intrinsics implementation. Of course, it's possible that I simply don't
know how to utilize it properly. But in any case, you can even disable it with `LLAMA_NO_ACCELERATE=1 make` and the
performance will be the same, since no BLAS calls are invoked by the current implementation
@ -187,11 +209,15 @@ Note the use of `--color` to distinguish between user input and generated text.
- Collaborators can push to branches in the `llama.cpp` repo
- Collaborators will be invited based on contributions
### Coding guide-lines
### Coding guidelines
- Avoid adding third-party dependencies, extra files, extra headers, etc.
- Always consider cross-compatibility with other operating systems and architectures
- Avoid fancy looking modern STL constructs, use basic for loops, avoid templates, keep it simple
- Avoid fancy looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple
- There are no strict rules for the code style, but try to follow the patterns in the code (indentation, spaces, etc.). Vertical alignment makes things more readable and easier to batch edit
- Clean-up any tailing whitespaces, use 4 spaces indentation, brackets on same line, `int * var`
- Look at the [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks
- Clean-up any trailing whitespaces, use 4 spaces indentation, brackets on same line, `void * ptr`, `int & a`
- See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions
### Misc
- Practice your C++ typing skills: https://typing-battles.ggerganov.com

View file

@ -99,7 +99,7 @@ for p in range(n_parts):
fout.write(struct.pack("i", ftype))
# Is this correct??
for i in range(32000):
for i in range(tokenizer.vocab_size()):
if tokenizer.is_unknown(i):
# "<unk>" token (translated as ??)
text = " \u2047 ".encode("utf-8")

166
ggml.c
View file

@ -364,7 +364,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
#if __AVX2__
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
inline __m256i bytesFromNibbles( const uint8_t* rsi )
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
{
// Load 16 bytes from memory
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@ -381,7 +381,7 @@ inline __m256i bytesFromNibbles( const uint8_t* rsi )
return bytes;
}
inline __m128i packNibbles( __m256i bytes )
static inline __m128i packNibbles( __m256i bytes )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
@ -1359,8 +1359,8 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int16x8_t
// assume that vdotq_s32 is always available, if not, should check for __ARM_FEATURE_DOTPROD
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
@ -1374,6 +1374,37 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
#endif
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
// scalar
#if defined(__ARM_FEATURE_QRDMX)
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
#else
sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
#endif
#endif
}
@ -2038,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"GELU",
"SILU",
"NORM",
"RMS_NORM",
"MUL_MAT",
@ -2058,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"FLASH_FF",
};
static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -2081,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"gelu(x)",
"silu(x)",
"norm(x)",
"rms_norm(x)",
"X*Y",
@ -2101,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_ff(x)",
};
static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
//
// ggml object
@ -3587,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace(
return ggml_norm_impl(ctx, a, true);
}
struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL; // TODO: maybe store epsilon here?
return result;
}
struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_rms_norm_impl(ctx, a, false);
}
struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_rms_norm_impl(ctx, a, true);
}
// ggml_mul_mat
struct ggml_tensor * ggml_mul_mat(
@ -5375,6 +5441,87 @@ static void ggml_compute_forward_norm(
}
}
static void ggml_compute_forward_rms_norm_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2];
const size_t nb03 = src0->nb[3];
const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];
const ggml_float eps = 1e-5f; // TODO: make this a parameter
// TODO: optimize
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
for (int i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float mean = 0.0;
for (int i00 = 0; i00 < ne00; i00++) {
mean += x[i00] * x[i00];
}
mean /= ne00;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }
const float scale = 1.0/sqrt(mean + eps);
ggml_vec_scale_f32(ne00, y, scale);
}
}
}
}
static void ggml_compute_forward_rms_norm(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_rms_norm_f32(params, src0, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_F16:
case GGML_TYPE_COUNT:
{
GGML_ASSERT(false);
} break;
}
}
// ggml_compute_forward_mul_mat
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@ -8491,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_norm(params, tensor->src0, tensor);
} break;
case GGML_OP_RMS_NORM:
{
ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@ -8733,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_MUL_MAT:
{
if (src0->grad) {
@ -9159,6 +9314,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
node->n_tasks = n_threads;
} break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
{
node->n_tasks = n_threads;
} break;

5
ggml.h
View file

@ -230,6 +230,7 @@ enum ggml_op {
GGML_OP_GELU,
GGML_OP_SILU,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
GGML_OP_MUL_MAT,
@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows

View file

@ -14,6 +14,8 @@
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#include <signal.h>
#endif
#define ANSI_COLOR_RED "\x1b[31m"
@ -547,6 +549,8 @@ bool llama_eval(
const int d_key = n_embd/n_head;
// TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
// static size_t buf_size = hparams.n_ctx*1024*1024;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
@ -584,7 +588,7 @@ bool llama_eval(
// norm
{
cur = ggml_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@ -674,7 +678,7 @@ bool llama_eval(
{
// norm
{
cur = ggml_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF);
// cur = ffn_norm*cur
cur = ggml_mul(ctx0,
@ -709,7 +713,7 @@ bool llama_eval(
// norm
{
inpL = ggml_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL);
// inpL = norm*inpL
inpL = ggml_mul(ctx0,
@ -753,8 +757,9 @@ bool llama_eval(
static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
printf(ANSI_COLOR_RESET);
if (signo == SIGINT) {
if (!is_interacting) {
is_interacting=true;
@ -818,8 +823,7 @@ int main(int argc, char ** argv) {
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!llama_model_load(params.model, model, vocab, 512)) { // TODO: set context from user input ??
if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
@ -863,6 +867,8 @@ int main(int argc, char ** argv) {
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
signal(SIGINT, sigint_handler);
#endif
fprintf(stderr, "%s: interactive mode on.\n", __func__);
@ -892,7 +898,7 @@ int main(int argc, char ** argv) {
if (params.interactive) {
fprintf(stderr, "== Running in interactive mode. ==\n"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
" - Press Ctrl+C to interject at any time.\n"
#endif
" - Press Return to return control to LLaMa.\n"
@ -1037,6 +1043,9 @@ int main(int argc, char ** argv) {
}
}
#if defined (_WIN32)
signal(SIGINT, SIG_DFL);
#endif
// report timing
{
@ -1052,5 +1061,9 @@ int main(int argc, char ** argv) {
ggml_free(model.ctx);
if (params.use_color) {
printf(ANSI_COLOR_RESET);
}
return 0;
}

0
models/.gitignore vendored
View file

View file

@ -37,6 +37,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.n_predict = std::stoi(argv[++i]);
} else if (arg == "--top_k") {
params.top_k = std::stoi(argv[++i]);
} else if (arg == "-c" || arg == "--ctx_size") {
params.n_ctx = std::stoi(argv[++i]);
} else if (arg == "--top_p") {
params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
@ -92,6 +94,7 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");

View file

@ -17,6 +17,7 @@ struct gpt_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 128; // new tokens to predict
int32_t repeat_last_n = 64; // last n tokens to penalize
int32_t n_ctx = 512; //context size
// sampling parameters
int32_t top_k = 40;