Merge branch 'ggerganov:master' into master
This commit is contained in:
commit
498fb769a5
12 changed files with 79 additions and 52 deletions
|
@ -768,7 +768,7 @@ The time per token is measured on a MacBook M1 Pro 32GB RAM using 4 and 8 thread
|
|||
|
||||
#### How to run
|
||||
|
||||
1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
1. Download/extract: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw`
|
||||
3. Output:
|
||||
```
|
||||
|
|
|
@ -219,7 +219,7 @@ function gg_run_open_llama_3b_v2 {
|
|||
gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/resolve/main/pytorch_model.bin
|
||||
gg_wget models-mnt/open-llama/3B-v2/ https://huggingface.co/openlm-research/open_llama_3b_v2/raw/main/generation_config.json
|
||||
|
||||
gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
|
||||
gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/
|
||||
head -n 60 models-mnt/wikitext/wikitext-2-raw/wiki.test.raw > models-mnt/wikitext/wikitext-2-raw/wiki.test-60.raw
|
||||
|
||||
|
@ -401,7 +401,7 @@ function gg_run_open_llama_7b_v2 {
|
|||
gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/resolve/main/pytorch_model-00002-of-00002.bin
|
||||
gg_wget models-mnt/open-llama/7B-v2/ https://huggingface.co/openlm-research/open_llama_7b_v2/raw/main/generation_config.json
|
||||
|
||||
gg_wget models-mnt/wikitext/ https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
|
||||
gg_wget models-mnt/wikitext/ https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
unzip -o models-mnt/wikitext/wikitext-2-raw-v1.zip -d models-mnt/wikitext/
|
||||
|
||||
path_models="../models-mnt/open-llama/7B-v2"
|
||||
|
|
|
@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||
}
|
||||
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
|
||||
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
|
||||
fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
|
||||
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
|
||||
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
|
||||
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
|
||||
|
|
|
@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl(
|
|||
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||
} else {
|
||||
sampler_queue(ctx_main, params, cur_p, 1);
|
||||
// temperature sampling
|
||||
size_t min_keep = std::max(1, params.min_keep);
|
||||
|
||||
sampler_queue(ctx_main, params, cur_p, min_keep);
|
||||
|
||||
id = llama_sample_token(ctx_main, &cur_p);
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ enum class llama_sampler_type : char {
|
|||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 64; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float min_p = 0.05f; // 0.0 = disabled
|
||||
|
|
|
@ -309,7 +309,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
|
|||
}
|
||||
|
||||
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
|
||||
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
// Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
@ -447,7 +447,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
return perplexity_v2(ctx, params);
|
||||
}
|
||||
|
||||
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
|
||||
// Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
|
||||
// Output: `perplexity: 13.5106 [114/114]`
|
||||
// BOS tokens will be added for each chunk before eval
|
||||
|
|
|
@ -199,6 +199,8 @@ node index.js
|
|||
|
||||
`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)
|
||||
|
||||
`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0)
|
||||
|
||||
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
|
||||
|
||||
`slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)
|
||||
|
|
|
@ -234,6 +234,7 @@
|
|||
mirostat_eta: 0.1, // learning rate
|
||||
grammar: '',
|
||||
n_probs: 0, // no completion_probabilities,
|
||||
min_keep: 0, // min probs from each sampler,
|
||||
image_data: [],
|
||||
cache_prompt: true,
|
||||
api_key: ''
|
||||
|
@ -791,6 +792,9 @@
|
|||
<fieldset>
|
||||
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<label for="api_key">API Key</label>
|
||||
<input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
|
||||
|
|
|
@ -548,6 +548,7 @@ struct llama_server_context
|
|||
slot->params.seed = json_value(data, "seed", default_params.seed);
|
||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
|
||||
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
|
||||
// Might be better to reject the request with a 400 ?
|
||||
|
@ -1093,6 +1094,7 @@ struct llama_server_context
|
|||
{"stream", slot.params.stream},
|
||||
{"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence}
|
||||
};
|
||||
|
|
|
@ -4027,7 +4027,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4170,7 +4173,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4306,7 +4312,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4424,7 +4433,10 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
y4 += 16 * 32;
|
||||
}
|
||||
#else
|
||||
// TODO
|
||||
(void) x;
|
||||
(void) y;
|
||||
(void) yl;
|
||||
(void) nb32;
|
||||
#endif
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
|
@ -4659,6 +4671,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
#else
|
||||
(void) get_scale_min_k4_just2;
|
||||
|
||||
q = q + 16 * (il&1);
|
||||
device const uint8_t * s = xb->scales;
|
||||
device const half2 * dh = (device const half2 *)xb->d;
|
||||
|
|
|
@ -1837,9 +1837,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
|
|||
float sigma2 = sumx2/QK_K;
|
||||
for (int j = 0; j < QK_K/16; ++j) {
|
||||
const float * restrict qw = quant_weights + QK_K * i + 16*j;
|
||||
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
|
||||
for (int l = 0; l < 16; ++l) sw[j] += weight[l];
|
||||
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
||||
for (int l = 0; l < QK_K/16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
|
||||
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
|
||||
scales[j] = make_qkx3_quants(QK_K/16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
|
||||
}
|
||||
|
||||
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
|
||||
|
@ -3855,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) {
|
|||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -3866,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q4_0 * restrict x = vx;
|
||||
|
@ -4024,15 +4024,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
|
||||
|
||||
__m128i bx = _mm_and_si128(lowMask, tmp);
|
||||
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
|
||||
bx = _mm_sub_epi8(bx, off);
|
||||
const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
|
||||
__m128i bx_0 = _mm_and_si128(lowMask, tmp);
|
||||
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
|
||||
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
|
||||
|
||||
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
|
||||
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
||||
bx = _mm_sub_epi8(bx, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
|
||||
bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
|
||||
by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
|
||||
bx_0 = _mm_sub_epi8(bx_0, off);
|
||||
const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
|
||||
|
||||
// Convert int32_t to float
|
||||
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
|
||||
|
@ -4222,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -4233,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q4_1 * restrict x = vx;
|
||||
|
@ -4440,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -4448,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(qk == QK5_0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_0 * restrict x = vx;
|
||||
|
@ -4618,21 +4618,21 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
/* Compute combined scale for the block */
|
||||
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
__m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
||||
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
||||
bxhil = _mm_andnot_si128(bxhil, mask);
|
||||
bxhih = _mm_andnot_si128(bxhih, mask);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx_0);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx_0, 1);
|
||||
bxl = _mm_or_si128(bxl, bxhil);
|
||||
bxh = _mm_or_si128(bxh, bxhih);
|
||||
bx = MM256_SET_M128I(bxh, bxl);
|
||||
bx_0 = MM256_SET_M128I(bxh, bxl);
|
||||
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
|
||||
|
||||
/* Multiply q with scale and accumulate */
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
|
||||
|
@ -4731,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -4739,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(qk == QK5_1);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q5_1 * restrict x = vx;
|
||||
|
@ -4925,22 +4925,22 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
__m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
||||
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
||||
bxhil = _mm_and_si128(bxhil, mask);
|
||||
bxhih = _mm_and_si128(bxhih, mask);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx_0);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx_0, 1);
|
||||
bxl = _mm_or_si128(bxl, bxhil);
|
||||
bxh = _mm_or_si128(bxh, bxhih);
|
||||
bx = MM256_SET_M128I(bxh, bxl);
|
||||
bx_0 = MM256_SET_M128I(bxh, bxl);
|
||||
|
||||
const __m256 dy = _mm256_set1_ps(y[i].d);
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_us8_pairs_float(bx, by);
|
||||
const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
|
||||
}
|
||||
|
@ -5035,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) {
|
||||
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
|
@ -5046,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
assert(nrc == 1);
|
||||
#endif
|
||||
UNUSED(nrc);
|
||||
UNUSED(bbx);
|
||||
UNUSED(bby);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q8_0 * restrict x = vx;
|
||||
|
@ -5169,10 +5169,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// load elements
|
||||
vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
|
||||
vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
|
||||
vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
|
||||
vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
|
||||
|
||||
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
|
||||
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
|
||||
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
|
||||
vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
|
||||
wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
|
||||
|
||||
echo "Usage:"
|
||||
echo ""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue