Updated Swift and Android bindings to use the new llama_sampling_* refactor from #8643
commit b6c9b539c9
parent dbf85440c7
4 changed files with 15 additions and 9 deletions
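Every hunk below follows the same pattern: the llama_sample_* functions become llama_sampling_*, and their first argument changes from the llama_context to a sampling handle fetched with llama_get_sampling(context). A minimal Swift sketch of the resulting call shape, using only symbols that appear in this diff (a sketch of the pattern, not the full example):

    let context = llama_new_context_with_model(model, context_params)
    let smpl = llama_get_sampling(context)       // sampling state now has its own handle

    // ... decode, then build candidates_p from the logits ...
    let new_token_id = llama_sampling_sample(smpl, &candidates_p)

    llama_print_timings(context, smpl, nil)      // timings now take the sampling handle too
    llama_sampling_free(smpl)                    // released explicitly, as in the actor's deinit below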
File 1 of 4: the standalone Swift example.

@@ -44,6 +44,8 @@ context_params.n_threads = 8
 context_params.n_threads_batch = 8
 
 let context = llama_new_context_with_model(model, context_params)
+let smpl = llama_get_sampling(context)
+
 guard context != nil else {
     print("Failed to initialize context")
     exit(1)
@@ -144,13 +146,13 @@ while n_cur <= n_len {
     let top_p: Float = 0.9
     let temp: Float = 0.4
 
-    llama_sample_top_k(context, &candidates_p, top_k, 1)
-    llama_sample_top_p(context, &candidates_p, top_p, 1)
-    llama_sample_temp(context, &candidates_p, temp)
+    llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
+    llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
+    llama_sampling_temp(smpl, &candidates_p, temp)
 
-    let new_token_id = llama_sample_token(context, &candidates_p)
+    let new_token_id = llama_sampling_sample(smpl, &candidates_p)
 
-    // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+    // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
 
     // is it an end of stream? -> mark the stream as finished
     if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
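The truncation chain is unchanged, only retargeted at the new handle: top-k cuts the candidate set to the k most likely tokens, top-p then trims it by cumulative probability, and temperature rescales the surviving logits before llama_sampling_sample draws a token. The commented-out line preserves the greedy escape hatch; a sketch of that drop-in variant, using only the symbols above:

    // Deterministic variant: skip the truncation chain and take the argmax (sketch).
    let new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p)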
@@ -212,7 +214,7 @@ let t_main_end = ggml_time_us()
 
 print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
 
-llama_print_timings(context)
+llama_print_timings(context, smpl, nil)
 
 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
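Note that llama_print_timings now takes the sampling handle as well, so sampler time can presumably be attributed to the new object; the example passes nil for the third argument, whose role is not shown in this diff.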
File 2 of 4: the Android JNI binding (completion loop).

@@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
     const auto model = llama_get_model(context);
 
@@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
     // sample the most likely token
-    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+    const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
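The Android binding mirrors the Swift change: the JNI completion loop obtains the handle once via llama_get_sampling(context), casting it to llama_sampling * just as the context and batch pointers are already reinterpret_cast from Java-side longs, and the greedy pick then goes through that handle instead of the context.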
File 3 of 4: the Swift LlamaContext actor.

@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
+    private var sampling: OpaquePointer
     private var batch: llama_batch
     private var tokens_list: [llama_token]
     var is_done: Bool = false
@@ -42,12 +43,14 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
+        self.sampling = llama_get_sampling(context)
     }
 
     deinit {
         llama_batch_free(batch)
         llama_free(context)
         llama_free_model(model)
+        llama_sampling_free(sampling)
         llama_backend_free()
     }
 
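Worth noting in the lifecycle: the handle is acquired in init immediately after the context exists, and deinit releases it together with the batch, context, and model, keeping llama_backend_free() last. A condensed sketch of that pairing, assuming the signatures shown in this diff:

    // init: the sampling handle is derived from a live context (sketch)
    self.sampling = llama_get_sampling(context)

    // deinit: free the per-context resources and the handle, backend last
    llama_batch_free(batch)
    llama_free(context)
    llama_free_model(model)
    llama_sampling_free(sampling)
    llama_backend_free()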
@@ -156,7 +159,7 @@ actor LlamaContext {
         candidates.withUnsafeMutableBufferPointer() { buffer in
             var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
 
-            new_token_id = llama_sample_token_greedy(context, &candidates_p)
+            new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
         }
 
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
File 4 of 4: the C header declarations.

@@ -1144,7 +1144,7 @@ extern "C" {
             float * mu);
 
     /// @details Selects the token with the highest probability.
-    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
+    ///          Does not compute the token probabilities. Use llama_sampling_softmax() instead.
     LLAMA_API llama_token llama_sampling_sample_greedy(
             struct llama_sampling * smpl,
             llama_token_data_array * candidates);
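Per the updated doc comment, the greedy sampler only takes the argmax and never normalizes the candidate logits. If actual probabilities are wanted first, the comment points at llama_sampling_softmax(), sketched here on the assumption that it shares the (smpl, &candidates_p) shape of the other calls in this diff:

    llama_sampling_softmax(smpl, &candidates_p)   // assumed shape; normalizes candidates into probabilities
    let tok = llama_sampling_sample_greedy(smpl, &candidates_p)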