Update low level api example

This commit is contained in:
Andrei Betlen 2023-04-01 13:02:10 -04:00 committed by Don Mahurin
parent a71cda6546
commit 62ce167b22
2 changed files with 38 additions and 7 deletions

View file

@@ -1,5 +1,4 @@
import ctypes import ctypes
from ctypes import ( from ctypes import (
c_int, c_int,
c_float, c_float,
@@ -8,7 +7,9 @@ from ctypes import (
c_bool, c_bool,
POINTER, POINTER,
Structure, Structure,
Array Array,
c_uint8,
c_size_t
) )
import pathlib import pathlib
from itertools import chain from itertools import chain
@@ -109,6 +110,36 @@ def llama_model_quantize(
_lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int] _lib.llama_model_quantize.argtypes = [c_char_p, c_char_p, c_int]
_lib.llama_model_quantize.restype = c_int _lib.llama_model_quantize.restype = c_int
# Returns the KV cache that will contain the context for the
# ongoing prediction with the model.
def llama_get_kv_cache(ctx: llama_context_p):
    """Return a raw POINTER(c_uint8) to the context's KV cache buffer.

    NOTE(review): only the pointer is exposed here; query the buffer's
    byte length separately via the size accessor on the C library.
    """
    return _lib.llama_get_kv_cache(ctx)
# ctypes metadata for the underlying C symbol:
# opaque context pointer in, uint8 pointer out.
_lib.llama_get_kv_cache.argtypes = [llama_context_p]
_lib.llama_get_kv_cache.restype = POINTER(c_uint8)
# Returns the size of the KV cache
def llama_get_kv_cache_size(ctx: llama_context_p) -> c_size_t:
    """Return the size (c_size_t, presumably bytes — TODO confirm against
    the C header) of the KV cache held by *ctx*."""
    return _lib.llama_get_kv_cache_size(ctx)
# ctypes metadata: opaque context pointer in, size_t out.
_lib.llama_get_kv_cache_size.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_size.restype = c_size_t
# Returns the number of tokens in the KV cache
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
    """Return the number of tokens currently stored in the KV cache of *ctx*."""
    return _lib.llama_get_kv_cache_token_count(ctx)
# ctypes metadata: opaque context pointer in, int out.
_lib.llama_get_kv_cache_token_count.argtypes = [llama_context_p]
_lib.llama_get_kv_cache_token_count.restype = c_int
# Sets the KV cache containing the current context for the model
def llama_set_kv_cache(ctx: llama_context_p, kv_cache, n_size: c_size_t, n_token_count: c_int):
    """Restore a previously captured KV cache into *ctx*.

    kv_cache:      POINTER(c_uint8) buffer (as returned by the cache getter)
    n_size:        size of that buffer, matching the size accessor's value
    n_token_count: number of tokens the buffer represents
    Returns None (the C function is declared void via restype below).
    """
    return _lib.llama_set_kv_cache(ctx, kv_cache, n_size, n_token_count)
# ctypes metadata: (context, uint8*, size_t, int) -> void.
_lib.llama_set_kv_cache.argtypes = [llama_context_p, POINTER(c_uint8), c_size_t, c_int]
_lib.llama_set_kv_cache.restype = None
# Run the llama inference to obtain the logits and probabilities for the next token. # Run the llama inference to obtain the logits and probabilities for the next token.
# tokens + n_tokens is the provided batch of new tokens to process # tokens + n_tokens is the provided batch of new tokens to process

View file

@@ -35,7 +35,7 @@ remaining_tokens = n_predict
embd = [] embd = []
last_n_size = 64 last_n_size = 64
last_n_tokens = [0] * last_n_size last_n_tokens_data = [0] * last_n_size
n_batch = 24 n_batch = 24
while remaining_tokens > 0: while remaining_tokens > 0:
@@ -49,21 +49,21 @@ while remaining_tokens > 0:
if len(embd_inp) <= input_consumed: if len(embd_inp) <= input_consumed:
id = llama_cpp.llama_sample_top_p_top_k( id = llama_cpp.llama_sample_top_p_top_k(
ctx, ctx,
(llama_cpp.c_int * len(last_n_tokens))(*last_n_tokens), (llama_cpp.c_int * len(last_n_tokens_data))(*last_n_tokens_data),
len(last_n_tokens), len(last_n_tokens_data),
40, 40,
0.8, 0.8,
0.2, 0.2,
1.0 / 0.85, 1.0 / 0.85,
) )
last_n_tokens = last_n_tokens[1:] + [id] last_n_tokens_data = last_n_tokens_data[1:] + [id]
embd.append(id) embd.append(id)
input_noecho = False input_noecho = False
remaining_tokens -= 1 remaining_tokens -= 1
else: else:
while len(embd_inp) > input_consumed: while len(embd_inp) > input_consumed:
embd.append(embd_inp[input_consumed]) embd.append(embd_inp[input_consumed])
last_n_tokens = last_n_tokens[1:] + [embd_inp[input_consumed]] last_n_tokens_data = last_n_tokens_data[1:] + [embd_inp[input_consumed]]
input_consumed += 1 input_consumed += 1
if len(embd) >= n_batch: if len(embd) >= n_batch:
break break