Merge branch 'master' into feat/dockerize
This commit is contained in:
commit
60cf70725e
8 changed files with 214 additions and 17 deletions
25
README.md
25
README.md
|
@ -3,10 +3,11 @@
|
|||
[](https://github.com/ggerganov/llama.cpp/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
Inference of [Facebook's LLaMA](https://github.com/facebookresearch/llama) model in pure C/C++
|
||||
Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
|
||||
|
||||
**Hot topics:**
|
||||
|
||||
- RMSNorm implementation / fixes: https://github.com/ggerganov/llama.cpp/issues/173
|
||||
- Cache input prompts for faster initialization: https://github.com/ggerganov/llama.cpp/issues/64
|
||||
- Create a `llama.cpp` logo: https://github.com/ggerganov/llama.cpp/issues/105
|
||||
|
||||
|
@ -177,20 +178,38 @@ Note the use of `--color` to distinguish between user input and generated text.
|
|||
|
||||

|
||||
|
||||
### Android
|
||||
|
||||
You can easily run `llama.cpp` on Android device with [termux](https://play.google.com/store/apps/details?id=com.termux).
|
||||
First, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
|
||||
```
|
||||
$ mkdir build-android
|
||||
$ cd build-android
|
||||
$ export NDK=<your_ndk_directory>
|
||||
$ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
|
||||
$ make
|
||||
```
|
||||
Install [termux](https://play.google.com/store/apps/details?id=com.termux) on your device and run `termux-setup-storage` to get access to your SD card.
|
||||
Finally, copy the `llama` binary and the model files to your device storage. Here is a demo of an interactive session running on Pixel 5 phone:
|
||||
|
||||
https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
|
||||
|
||||
|
||||
## Limitations
|
||||
|
||||
- We don't know yet how much the quantization affects the quality of the generated text
|
||||
- Probably the token sampling can be improved
|
||||
- The Accelerate framework is actually currently unused since I found that for tensor shapes typical for the Decoder,
|
||||
there is no benefit compared to the ARM_NEON intrinsics implementation. Of course, it's possible that I simlpy don't
|
||||
there is no benefit compared to the ARM_NEON intrinsics implementation. Of course, it's possible that I simply don't
|
||||
know how to utilize it properly. But in any case, you can even disable it with `LLAMA_NO_ACCELERATE=1 make` and the
|
||||
performance will be the same, since no BLAS calls are invoked by the current implementation
|
||||
|
||||
### Contributing
|
||||
|
||||
- Contributors can open PRs
|
||||
- Collaborators can push to branches in the `llama.cpp` repo
|
||||
- Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch
|
||||
- Collaborators will be invited based on contributions
|
||||
- Any help with managing issues and PRs is very appreciated!
|
||||
|
||||
### Coding guidelines
|
||||
|
||||
|
|
|
@ -99,7 +99,7 @@ for p in range(n_parts):
|
|||
fout.write(struct.pack("i", ftype))
|
||||
|
||||
# Is this correct??
|
||||
for i in range(32000):
|
||||
for i in range(tokenizer.vocab_size()):
|
||||
if tokenizer.is_unknown(i):
|
||||
# "<unk>" token (translated as ??)
|
||||
text = " \u2047 ".encode("utf-8")
|
||||
|
|
166
ggml.c
166
ggml.c
|
@ -364,7 +364,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|||
#if __AVX2__
|
||||
// Unpack 32 4-bit fields into 32 bytes
|
||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||
inline __m256i bytesFromNibbles( const uint8_t* rsi )
|
||||
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
|
||||
{
|
||||
// Load 16 bytes from memory
|
||||
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
|
||||
|
@ -381,7 +381,7 @@ inline __m256i bytesFromNibbles( const uint8_t* rsi )
|
|||
return bytes;
|
||||
}
|
||||
|
||||
inline __m128i packNibbles( __m256i bytes )
|
||||
static inline __m128i packNibbles( __m256i bytes )
|
||||
{
|
||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
|
||||
|
@ -1359,8 +1359,8 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|||
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int16x8_t
|
||||
// assume that vdotq_s32 is always available, if not, should check for __ARM_FEATURE_DOTPROD
|
||||
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
|
||||
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
|
||||
|
||||
|
@ -1374,6 +1374,37 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|||
#else
|
||||
sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
|
||||
sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
|
||||
#endif
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
||||
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
||||
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
||||
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
||||
|
||||
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
|
||||
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
|
||||
|
||||
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
|
||||
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
|
||||
|
||||
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
|
||||
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
|
||||
|
||||
// scalar
|
||||
#if defined(__ARM_FEATURE_QRDMX)
|
||||
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
|
||||
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
|
||||
#else
|
||||
sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
|
||||
sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -2038,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|||
"GELU",
|
||||
"SILU",
|
||||
"NORM",
|
||||
"RMS_NORM",
|
||||
|
||||
"MUL_MAT",
|
||||
|
||||
|
@ -2058,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|||
"FLASH_FF",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
|
||||
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
@ -2081,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"gelu(x)",
|
||||
"silu(x)",
|
||||
"norm(x)",
|
||||
"rms_norm(x)",
|
||||
|
||||
"X*Y",
|
||||
|
||||
|
@ -2101,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"flash_ff(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
|
||||
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
|
||||
|
||||
//
|
||||
// ggml object
|
||||
|
@ -3587,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace(
|
|||
return ggml_norm_impl(ctx, a, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rms_norm_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
bool inplace) {
|
||||
bool is_node = false;
|
||||
|
||||
if (!inplace && (a->grad)) {
|
||||
GGML_ASSERT(false); // TODO: implement backward
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_RMS_NORM;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = NULL; // TODO: maybe store epsilon here?
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rms_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_rms_norm_impl(ctx, a, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rms_norm_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_rms_norm_impl(ctx, a, true);
|
||||
}
|
||||
|
||||
// ggml_mul_mat
|
||||
|
||||
struct ggml_tensor * ggml_mul_mat(
|
||||
|
@ -5375,6 +5441,87 @@ static void ggml_compute_forward_norm(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rms_norm_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne02 = src0->ne[2];
|
||||
const int ne03 = src0->ne[3];
|
||||
|
||||
const size_t nb01 = src0->nb[1];
|
||||
const size_t nb02 = src0->nb[2];
|
||||
const size_t nb03 = src0->nb[3];
|
||||
|
||||
const size_t nb1 = dst->nb[1];
|
||||
const size_t nb2 = dst->nb[2];
|
||||
const size_t nb3 = dst->nb[3];
|
||||
|
||||
const ggml_float eps = 1e-5f; // TODO: make this a parameter
|
||||
|
||||
// TODO: optimize
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float mean = 0.0;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
mean += x[i00] * x[i00];
|
||||
}
|
||||
|
||||
mean /= ne00;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
// for (int i00 = 0; i00 < ne00; i00++) {
|
||||
// y[i00] = x[i00];
|
||||
// }
|
||||
|
||||
const float scale = 1.0/sqrt(mean + eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rms_norm(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rms_norm_f32(params, src0, dst);
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_COUNT:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ggml_compute_forward_mul_mat
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
|
@ -8491,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_norm(params, tensor->src0, tensor);
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
|
||||
|
@ -8733,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
GGML_ASSERT(false); // TODO: not implemented
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
if (src0->grad) {
|
||||
|
@ -9159,6 +9314,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
node->n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
node->n_tasks = n_threads;
|
||||
} break;
|
||||
|
|
5
ggml.h
5
ggml.h
|
@ -230,6 +230,7 @@ enum ggml_op {
|
|||
GGML_OP_GELU,
|
||||
GGML_OP_SILU,
|
||||
GGML_OP_NORM, // normalize
|
||||
GGML_OP_RMS_NORM,
|
||||
|
||||
GGML_OP_MUL_MAT,
|
||||
|
||||
|
@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm(
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
struct ggml_tensor * ggml_rms_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// A: m rows, n columns
|
||||
// B: p rows, n columns (i.e. we transpose it internally)
|
||||
// result is m columns, p rows
|
||||
|
|
27
main.cpp
27
main.cpp
|
@ -14,6 +14,8 @@
|
|||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#elif defined (_WIN32)
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
#define ANSI_COLOR_RED "\x1b[31m"
|
||||
|
@ -547,6 +549,8 @@ bool llama_eval(
|
|||
|
||||
const int d_key = n_embd/n_head;
|
||||
|
||||
// TODO: check if this size scales with n_ctx linearly and remove constant. somehow I feel it wasn't the case
|
||||
// static size_t buf_size = hparams.n_ctx*1024*1024;
|
||||
static size_t buf_size = 512u*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
||||
|
@ -584,7 +588,7 @@ bool llama_eval(
|
|||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
cur = ggml_mul(ctx0,
|
||||
|
@ -674,7 +678,7 @@ bool llama_eval(
|
|||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
cur = ggml_mul(ctx0,
|
||||
|
@ -709,7 +713,7 @@ bool llama_eval(
|
|||
|
||||
// norm
|
||||
{
|
||||
inpL = ggml_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
|
||||
// inpL = norm*inpL
|
||||
inpL = ggml_mul(ctx0,
|
||||
|
@ -753,8 +757,9 @@ bool llama_eval(
|
|||
|
||||
static bool is_interacting = false;
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
void sigint_handler(int signo) {
|
||||
printf(ANSI_COLOR_RESET);
|
||||
if (signo == SIGINT) {
|
||||
if (!is_interacting) {
|
||||
is_interacting=true;
|
||||
|
@ -818,8 +823,7 @@ int main(int argc, char ** argv) {
|
|||
// load the model
|
||||
{
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!llama_model_load(params.model, model, vocab, 512)) { // TODO: set context from user input ??
|
||||
if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
@ -863,6 +867,8 @@ int main(int argc, char ** argv) {
|
|||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
signal(SIGINT, sigint_handler);
|
||||
#endif
|
||||
|
||||
fprintf(stderr, "%s: interactive mode on.\n", __func__);
|
||||
|
@ -892,7 +898,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (params.interactive) {
|
||||
fprintf(stderr, "== Running in interactive mode. ==\n"
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
" - Press Ctrl+C to interject at any time.\n"
|
||||
#endif
|
||||
" - Press Return to return control to LLaMa.\n"
|
||||
|
@ -1037,6 +1043,9 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
#if defined (_WIN32)
|
||||
signal(SIGINT, SIG_DFL);
|
||||
#endif
|
||||
|
||||
// report timing
|
||||
{
|
||||
|
@ -1052,5 +1061,9 @@ int main(int argc, char ** argv) {
|
|||
|
||||
ggml_free(model.ctx);
|
||||
|
||||
if (params.use_color) {
|
||||
printf(ANSI_COLOR_RESET);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
0
models/.gitignore
vendored
0
models/.gitignore
vendored
|
@ -37,6 +37,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
params.n_predict = std::stoi(argv[++i]);
|
||||
} else if (arg == "--top_k") {
|
||||
params.top_k = std::stoi(argv[++i]);
|
||||
} else if (arg == "-c" || arg == "--ctx_size") {
|
||||
params.n_ctx = std::stoi(argv[++i]);
|
||||
} else if (arg == "--top_p") {
|
||||
params.top_p = std::stof(argv[++i]);
|
||||
} else if (arg == "--temp") {
|
||||
|
@ -92,6 +94,7 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
|
|||
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
|
||||
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
|
||||
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
|
||||
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
|
|
3
utils.h
3
utils.h
|
@ -17,7 +17,8 @@ struct gpt_params {
|
|||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_predict = 128; // new tokens to predict
|
||||
int32_t repeat_last_n = 64; // last n tokens to penalize
|
||||
|
||||
int32_t n_ctx = 512; //context size
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.95f;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue