From 62ce167b22580e4b697be2e31e4f61a53fd10475 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Sat, 1 Apr 2023 13:02:10 -0400
Subject: [PATCH] Update low level api example

---
 examples/llama_cpp.py               | 35 +++++++++++++++++++++++++++++++++--
 examples/low_level_api_llama_cpp.py | 10 +++++-----
 2 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/examples/llama_cpp.py b/examples/llama_cpp.py
index 156139f71..03232560f 100644
--- a/examples/llama_cpp.py
+++ b/examples/llama_cpp.py
@@ -1,5 +1,4 @@
 import ctypes
-
 from ctypes import (
     c_int,
     c_float,
@@ -8,7 +7,9 @@ from ctypes import (
     c_bool,
     POINTER,
     Structure,
-    Array
+    Array,
+    c_uint8,
+    c_size_t
 )
 import pathlib
 from itertools import chain
@@ -109,6 +110,36 @@ def llama_model_quantize(
 _lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
 _lib.llama_model_quantize.restype = c_int
 
+# Returns the KV cache that will contain the context for the
+# ongoing prediction with the model.
+def llama_get_kv_cache(ctx: llama_context_p):
+    return _lib.llama_get_kv_cache(ctx)
+
+_lib.llama_get_kv_cache.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache.restype = POINTER(c_uint8)
+
+# Returns the size of the KV cache
+def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
+    return _lib.llama_get_kv_cache_size(ctx)
+
+_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache_size.restype = c_size_t
+
+# Returns the number of tokens in the KV cache
+def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
+    return _lib.llama_get_kv_cache_token_count(ctx)
+
+_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
+_lib.llama_get_kv_cache_token_count.restype = c_int
+
+
+# Sets the KV cache containing the current context for the model
+def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
+    return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
+
+_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
+_lib.llama_set_kv_cache.restype = None
+
 
 # Run the llama inference to obtain the logits and probabilities for the next token.
 # tokens + n_tokens is the provided batch of new tokens to process
diff --git a/examples/low_level_api_llama_cpp.py b/examples/low_level_api_llama_cpp.py
index 4a888c355..2a639aad5 100644
--- a/examples/low_level_api_llama_cpp.py
+++ b/examples/low_level_api_llama_cpp.py
@@ -35,7 +35,7 @@ remaining_tokens = n_predict
 
 embd = []
 last_n_size = 64
-last_n_tokens = [0] * last_n_size
+last_n_tokens_data = [0] * last_n_size
 n_batch = 24
 
 while remaining_tokens > 0:
@@ -49,21 +49,21 @@ while remaining_tokens > 0:
     if len(embd_inp) <= input_consumed:
         id = llama_cpp.llama_sample_top_p_top_k(
             ctx,
-            (llama_cpp.c_int * len(last_n_tokens))(*last_n_tokens),
-            len(last_n_tokens),
+            (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data),
+            len(last_n_tokens_data),
             40,
             0.8,
             0.2,
             1.0 / 0.85,
         )
-        last_n_tokens = last_n_tokens[1:] + [id]
+        last_n_tokens_data = last_n_tokens_data[1:] + [id]
         embd.append(id)
         input_noecho = False
         remaining_tokens -= 1
     else:
         while len(embd_inp) > input_consumed:
             embd.append(embd_inp[input_consumed])
-            last_n_tokens = last_n_tokens[1:] + [embd_inp[input_consumed]]
+            last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
             input_consumed += 1
             if len(embd) >= n_batch:
                 break
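
The new bindings make it possible to snapshot and restore a context's KV cache from Python. Below is a minimal, untested sketch of how they might be combined; the model path, the prompt-evaluation step, and the llama_context_default_params / llama_init_from_file calls are illustrative assumptions borrowed from the surrounding low-level example, not part of this patch.

    import ctypes
    import llama_cpp  # the examples/llama_cpp.py bindings updated by this patch

    # Assumed setup (model path is a placeholder).
    lparams = llama_cpp.llama_context_default_params()
    ctx = llama_cpp.llama_init_from_file(b"models/ggml-model-q4_0.bin", lparams)

    # ... tokenize the prompt and run llama_eval over it here ...

    # Snapshot the KV cache once the prompt has been evaluated.
    kv_size = llama_cpp.llama_get_kv_cache_size(ctx)
    kv_tokens = llama_cpp.llama_get_kv_cache_token_count(ctx)
    kv_state = ctypes.string_at(llama_cpp.llama_get_kv_cache(ctx), kv_size)  # copy out the bytes

    # ... generate as usual; later, restore the snapshot so generation can
    # resume from the cached prompt without re-evaluating it ...
    buf = (ctypes.c_uint8 * kv_size).from_buffer_copy(kv_state)
    llama_cpp.llama_set_kv_cache(ctx, buf, kv_size, kv_tokens)

Copying with ctypes.string_at matters here: llama_get_kv_cache returns a pointer into memory owned by the context, so the bytes need to be duplicated before further evaluation mutates the cache or the context is freed.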