llama : support negative ith in llama_get_ API (#6519)

* llama_sampling_sample with default args is more naively usable

* Batches populated by either llama_batch_get_one or llama_batch_add work with default args
  * Previously get_one could use the default argument
  * Previously add should usually have used the last index where logits[idx] == true
* This hopefully encourages the use of llama_batch_add
  * By giving expected results when using default arguments.
* Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith
* Believed to work with any currently well behaved program
  * Default arg now works for both cases (previously would give strange results for add case)
  * Any non-negative number is unaffected and behaves as previously
  * Negative arguments were previously invalid.
* Implemented as a special case of indexing as suggested by @compilade in https://github.com/ggerganov/llama.cpp/pull/6519

* Fixed mismatch type errors

* cited in macOS CI tests
* Missed in original updates based on PR feedback in https://github.com/ggerganov/llama.cpp/pull/6519
This commit is contained in:
Rick G 2024-04-08 06:02:30 -07:00 committed by GitHub
parent beea6e1b16
commit e3c337d87c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 12 deletions

View file

@ -747,8 +747,9 @@ extern "C" {
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Logits for the ith token. Equivalent to:
// Logits for the ith token. For positive indices, Equivalent to:
// llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
// Negative indicies can be used to access logits in reverse order, -1 is the last logit.
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
@ -760,8 +761,9 @@ extern "C" {
// Otherwise, returns NULL.
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Get the embeddings for the ith token. Equivalent to:
// Get the embeddings for the ith token. For positive indices, Equivalent to:
// llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
// Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding.
// shape: [n_embd] (1-dimensional)
// returns NULL for invalid ids.
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);