token count includes ids
This commit is contained in:
parent
0ca814e544
commit
6570a2005b
5 changed files with 26 additions and 9 deletions
11
expose.cpp
11
expose.cpp
|
@ -194,7 +194,7 @@ extern "C"
|
|||
return gpttype_generate(inputs, output);
|
||||
}
|
||||
|
||||
const char* new_token(int idx) {
|
||||
const char * new_token(int idx) {
|
||||
if (generated_tokens.size() <= idx || idx < 0) return nullptr;
|
||||
|
||||
return generated_tokens[idx].c_str();
|
||||
|
@ -232,9 +232,14 @@ extern "C"
|
|||
return gpttype_generate_abort();
|
||||
}
|
||||
|
||||
int token_count(const char * input)
|
||||
static std::vector<int> toks; // static buffer reused by every token_count() call; it backs the ids pointer handed back to the caller
|
||||
token_count_outputs token_count(const char * input)
|
||||
{
|
||||
std::string inputstr = input;
|
||||
return gpttype_token_count(inputstr);
|
||||
token_count_outputs output;
|
||||
toks = gpttype_get_token_arr(inputstr);
|
||||
output.count = toks.size();
|
||||
output.ids = toks.data(); // slightly unsafe: points into the static toks buffer, so it is only valid until the next token_count() call reassigns toks
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
|
5
expose.h
5
expose.h
|
@ -83,6 +83,11 @@ struct generation_outputs
|
|||
int status = -1;
|
||||
char text[32768]; //32kb should be enough for any response
|
||||
};
|
||||
struct token_count_outputs
|
||||
{
|
||||
int count = 0;
|
||||
int * ids; // pointer to a buffer shared between calls rather than a per-call copy (a bit of a hack)
|
||||
};
|
||||
|
||||
extern std::string executable_path;
|
||||
extern std::string lora_filename;
|
||||
|
|
|
@ -1390,7 +1390,7 @@ bool gpttype_generate_abort()
|
|||
return true;
|
||||
}
|
||||
|
||||
int gpttype_token_count(const std::string & input)
|
||||
std::vector<int> gpttype_get_token_arr(const std::string & input)
|
||||
{
|
||||
if(debugmode==1)
|
||||
{
|
||||
|
@ -1403,7 +1403,7 @@ int gpttype_token_count(const std::string & input)
|
|||
{
|
||||
printf("\nTokens Counted: %d\n",tokcount);
|
||||
}
|
||||
return tokcount;
|
||||
return toks;
|
||||
}
|
||||
|
||||
const std::string & gpttype_get_pending_output()
|
||||
|
|
13
koboldcpp.py
13
koboldcpp.py
|
@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
|
|||
_fields_ = [("status", ctypes.c_int),
|
||||
("text", ctypes.c_char * 32768)]
|
||||
|
||||
class token_count_outputs(ctypes.Structure):
|
||||
_fields_ = [("count", ctypes.c_int),
|
||||
("ids", ctypes.POINTER(ctypes.c_int))]
|
||||
|
||||
handle = None
|
||||
|
||||
def getdirpath():
|
||||
|
@ -218,7 +222,7 @@ def init_library():
|
|||
handle.get_total_gens.restype = ctypes.c_int
|
||||
handle.get_last_stop_reason.restype = ctypes.c_int
|
||||
handle.abort_generate.restype = ctypes.c_bool
|
||||
handle.token_count.restype = ctypes.c_int
|
||||
handle.token_count.restype = token_count_outputs
|
||||
handle.get_pending_output.restype = ctypes.c_char_p
|
||||
|
||||
def load_model(model_filename):
|
||||
|
@ -729,8 +733,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
try:
|
||||
genparams = json.loads(body)
|
||||
countprompt = genparams.get('prompt', "")
|
||||
count = handle.token_count(countprompt.encode("UTF-8"))
|
||||
response_body = (json.dumps({"value": count}).encode())
|
||||
rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
|
||||
countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
|
||||
# the bounds check above protects the server in case the returned count is corrupted (negative or 50000 and above)
|
||||
countdata = [rawcountdata.ids[i] for i in range(countlimit)]
|
||||
response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())
|
||||
|
||||
except Exception as e:
|
||||
utfprint("Count Tokens - Body Error: " + str(e))
|
||||
|
|
|
@ -68,7 +68,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
|
||||
bool gpttype_generate_abort();
|
||||
const std::string & gpttype_get_pending_output();
|
||||
int gpttype_token_count(const std::string & input);
|
||||
std::vector<int> gpttype_get_token_arr(const std::string & input);
|
||||
|
||||
void timer_start();
|
||||
double timer_check();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue