Updated Swift and Android bindings to use the new llama_sampling_* refactor from #8643
This commit is contained in:
parent
dbf85440c7
commit
b6c9b539c9
4 changed files with 15 additions and 9 deletions
|
@ -44,6 +44,8 @@ context_params.n_threads = 8
|
|||
context_params.n_threads_batch = 8
|
||||
|
||||
let context = llama_new_context_with_model(model, context_params)
|
||||
let smpl = llama_get_sampling(context)
|
||||
|
||||
guard context != nil else {
|
||||
print("Failed to initialize context")
|
||||
exit(1)
|
||||
|
@ -144,13 +146,13 @@ while n_cur <= n_len {
|
|||
let top_p: Float = 0.9
|
||||
let temp: Float = 0.4
|
||||
|
||||
llama_sample_top_k(context, &candidates_p, top_k, 1)
|
||||
llama_sample_top_p(context, &candidates_p, top_p, 1)
|
||||
llama_sample_temp(context, &candidates_p, temp)
|
||||
llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
|
||||
llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
|
||||
llama_sampling_temp(smpl, &candidates_p, temp)
|
||||
|
||||
let new_token_id = llama_sample_token(context, &candidates_p)
|
||||
let new_token_id = llama_sampling_sample(smpl, &candidates_p)
|
||||
|
||||
// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
|
@ -212,7 +214,7 @@ let t_main_end = ggml_time_us()
|
|||
|
||||
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
|
||||
|
||||
llama_print_timings(context)
|
||||
llama_print_timings(context, smpl, nil)
|
||||
|
||||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
let utf8Count = text.utf8.count
|
||||
|
|
|
@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
jobject intvar_ncur
|
||||
) {
|
||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
||||
const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
|
||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||
const auto model = llama_get_model(context);
|
||||
|
||||
|
@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
|
||||
const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);
|
||||
|
||||
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
|
|
|
@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
|
|||
actor LlamaContext {
|
||||
private var model: OpaquePointer
|
||||
private var context: OpaquePointer
|
||||
private var sampling: OpaquePointer
|
||||
private var batch: llama_batch
|
||||
private var tokens_list: [llama_token]
|
||||
var is_done: Bool = false
|
||||
|
@ -42,12 +43,14 @@ actor LlamaContext {
|
|||
self.tokens_list = []
|
||||
self.batch = llama_batch_init(512, 0, 1)
|
||||
self.temporary_invalid_cchars = []
|
||||
self.sampling = llama_get_sampling(context)
|
||||
}
|
||||
|
||||
deinit {
|
||||
llama_batch_free(batch)
|
||||
llama_free(context)
|
||||
llama_free_model(model)
|
||||
llama_sampling_free(sampling)
|
||||
llama_backend_free()
|
||||
}
|
||||
|
||||
|
@ -156,7 +159,7 @@ actor LlamaContext {
|
|||
candidates.withUnsafeMutableBufferPointer() { buffer in
|
||||
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
|
||||
|
||||
new_token_id = llama_sample_token_greedy(context, &candidates_p)
|
||||
new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
|
||||
}
|
||||
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
|
|
|
@ -1144,7 +1144,7 @@ extern "C" {
|
|||
float * mu);
|
||||
|
||||
/// @details Selects the token with the highest probability.
|
||||
/// Does not compute the token probabilities. Use llama_sample_softmax() instead.
|
||||
/// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
|
||||
LLAMA_API llama_token llama_sampling_sample_greedy(
|
||||
struct llama_sampling * smpl,
|
||||
llama_token_data_array * candidates);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue