token count includes ids

This commit is contained in:
Concedo 2023-12-03 15:44:53 +08:00
parent 0ca814e544
commit 6570a2005b
5 changed files with 26 additions and 9 deletions

View file

@ -232,9 +232,14 @@ extern "C"
return gpttype_generate_abort(); return gpttype_generate_abort();
} }
int token_count(const char * input) static std::vector<int> toks; //just share a static object for token counting
token_count_outputs token_count(const char * input)
{ {
std::string inputstr = input; std::string inputstr = input;
return gpttype_token_count(inputstr); token_count_outputs output;
toks = gpttype_get_token_arr(inputstr);
output.count = toks.size();
output.ids = toks.data(); //this may be slightly unsafe
return output;
} }
} }

View file

@ -83,6 +83,11 @@ struct generation_outputs
int status = -1; int status = -1;
char text[32768]; //32kb should be enough for any response char text[32768]; //32kb should be enough for any response
}; };
struct token_count_outputs
{
int count = 0;
int * ids; //we'll just use shared memory for this one, bit of a hack
};
extern std::string executable_path; extern std::string executable_path;
extern std::string lora_filename; extern std::string lora_filename;

View file

@ -1390,7 +1390,7 @@ bool gpttype_generate_abort()
return true; return true;
} }
int gpttype_token_count(const std::string & input) std::vector<int> gpttype_get_token_arr(const std::string & input)
{ {
if(debugmode==1) if(debugmode==1)
{ {
@ -1403,7 +1403,7 @@ int gpttype_token_count(const std::string & input)
{ {
printf("\nTokens Counted: %d\n",tokcount); printf("\nTokens Counted: %d\n",tokcount);
} }
return tokcount; return toks;
} }
const std::string & gpttype_get_pending_output() const std::string & gpttype_get_pending_output()

View file

@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int), _fields_ = [("status", ctypes.c_int),
("text", ctypes.c_char * 32768)] ("text", ctypes.c_char * 32768)]
class token_count_outputs(ctypes.Structure):
_fields_ = [("count", ctypes.c_int),
("ids", ctypes.POINTER(ctypes.c_int))]
handle = None handle = None
def getdirpath(): def getdirpath():
@ -218,7 +222,7 @@ def init_library():
handle.get_total_gens.restype = ctypes.c_int handle.get_total_gens.restype = ctypes.c_int
handle.get_last_stop_reason.restype = ctypes.c_int handle.get_last_stop_reason.restype = ctypes.c_int
handle.abort_generate.restype = ctypes.c_bool handle.abort_generate.restype = ctypes.c_bool
handle.token_count.restype = ctypes.c_int handle.token_count.restype = token_count_outputs
handle.get_pending_output.restype = ctypes.c_char_p handle.get_pending_output.restype = ctypes.c_char_p
def load_model(model_filename): def load_model(model_filename):
@ -729,8 +733,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
try: try:
genparams = json.loads(body) genparams = json.loads(body)
countprompt = genparams.get('prompt', "") countprompt = genparams.get('prompt', "")
count = handle.token_count(countprompt.encode("UTF-8")) rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
response_body = (json.dumps({"value": count}).encode()) countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
# the above protects the server in case the count limit got corrupted
countdata = [rawcountdata.ids[i] for i in range(countlimit)]
response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())
except Exception as e: except Exception as e:
utfprint("Count Tokens - Body Error: " + str(e)) utfprint("Count Tokens - Body Error: " + str(e))

View file

@ -68,7 +68,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output); generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
bool gpttype_generate_abort(); bool gpttype_generate_abort();
const std::string & gpttype_get_pending_output(); const std::string & gpttype_get_pending_output();
int gpttype_token_count(const std::string & input); std::vector<int> gpttype_get_token_arr(const std::string & input);
void timer_start(); void timer_start();
double timer_check(); double timer_check();