token count includes ids
This commit is contained in:
parent
0ca814e544
commit
6570a2005b
5 changed files with 26 additions and 9 deletions
11
expose.cpp
11
expose.cpp
|
@ -194,7 +194,7 @@ extern "C"
|
||||||
return gpttype_generate(inputs, output);
|
return gpttype_generate(inputs, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* new_token(int idx) {
|
const char * new_token(int idx) {
|
||||||
if (generated_tokens.size() <= idx || idx < 0) return nullptr;
|
if (generated_tokens.size() <= idx || idx < 0) return nullptr;
|
||||||
|
|
||||||
return generated_tokens[idx].c_str();
|
return generated_tokens[idx].c_str();
|
||||||
|
@ -232,9 +232,14 @@ extern "C"
|
||||||
return gpttype_generate_abort();
|
return gpttype_generate_abort();
|
||||||
}
|
}
|
||||||
|
|
||||||
int token_count(const char * input)
|
static std::vector<int> toks; //just share a static object for token counting
|
||||||
|
token_count_outputs token_count(const char * input)
|
||||||
{
|
{
|
||||||
std::string inputstr = input;
|
std::string inputstr = input;
|
||||||
return gpttype_token_count(inputstr);
|
token_count_outputs output;
|
||||||
|
toks = gpttype_get_token_arr(inputstr);
|
||||||
|
output.count = toks.size();
|
||||||
|
output.ids = toks.data(); //this may be slightly unsafe
|
||||||
|
return output;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
5
expose.h
5
expose.h
|
@ -83,6 +83,11 @@ struct generation_outputs
|
||||||
int status = -1;
|
int status = -1;
|
||||||
char text[32768]; //32kb should be enough for any response
|
char text[32768]; //32kb should be enough for any response
|
||||||
};
|
};
|
||||||
|
struct token_count_outputs
|
||||||
|
{
|
||||||
|
int count = 0;
|
||||||
|
int * ids; //we'll just use shared memory for this one, bit of a hack
|
||||||
|
};
|
||||||
|
|
||||||
extern std::string executable_path;
|
extern std::string executable_path;
|
||||||
extern std::string lora_filename;
|
extern std::string lora_filename;
|
||||||
|
|
|
@ -1390,7 +1390,7 @@ bool gpttype_generate_abort()
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int gpttype_token_count(const std::string & input)
|
std::vector<int> gpttype_get_token_arr(const std::string & input)
|
||||||
{
|
{
|
||||||
if(debugmode==1)
|
if(debugmode==1)
|
||||||
{
|
{
|
||||||
|
@ -1403,7 +1403,7 @@ int gpttype_token_count(const std::string & input)
|
||||||
{
|
{
|
||||||
printf("\nTokens Counted: %d\n",tokcount);
|
printf("\nTokens Counted: %d\n",tokcount);
|
||||||
}
|
}
|
||||||
return tokcount;
|
return toks;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & gpttype_get_pending_output()
|
const std::string & gpttype_get_pending_output()
|
||||||
|
|
13
koboldcpp.py
13
koboldcpp.py
|
@ -77,6 +77,10 @@ class generation_outputs(ctypes.Structure):
|
||||||
_fields_ = [("status", ctypes.c_int),
|
_fields_ = [("status", ctypes.c_int),
|
||||||
("text", ctypes.c_char * 32768)]
|
("text", ctypes.c_char * 32768)]
|
||||||
|
|
||||||
|
class token_count_outputs(ctypes.Structure):
|
||||||
|
_fields_ = [("count", ctypes.c_int),
|
||||||
|
("ids", ctypes.POINTER(ctypes.c_int))]
|
||||||
|
|
||||||
handle = None
|
handle = None
|
||||||
|
|
||||||
def getdirpath():
|
def getdirpath():
|
||||||
|
@ -218,7 +222,7 @@ def init_library():
|
||||||
handle.get_total_gens.restype = ctypes.c_int
|
handle.get_total_gens.restype = ctypes.c_int
|
||||||
handle.get_last_stop_reason.restype = ctypes.c_int
|
handle.get_last_stop_reason.restype = ctypes.c_int
|
||||||
handle.abort_generate.restype = ctypes.c_bool
|
handle.abort_generate.restype = ctypes.c_bool
|
||||||
handle.token_count.restype = ctypes.c_int
|
handle.token_count.restype = token_count_outputs
|
||||||
handle.get_pending_output.restype = ctypes.c_char_p
|
handle.get_pending_output.restype = ctypes.c_char_p
|
||||||
|
|
||||||
def load_model(model_filename):
|
def load_model(model_filename):
|
||||||
|
@ -729,8 +733,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
try:
|
try:
|
||||||
genparams = json.loads(body)
|
genparams = json.loads(body)
|
||||||
countprompt = genparams.get('prompt', "")
|
countprompt = genparams.get('prompt', "")
|
||||||
count = handle.token_count(countprompt.encode("UTF-8"))
|
rawcountdata = handle.token_count(countprompt.encode("UTF-8"))
|
||||||
response_body = (json.dumps({"value": count}).encode())
|
countlimit = rawcountdata.count if (rawcountdata.count>=0 and rawcountdata.count<50000) else 0
|
||||||
|
# the above protects the server in case the count limit got corrupted
|
||||||
|
countdata = [rawcountdata.ids[i] for i in range(countlimit)]
|
||||||
|
response_body = (json.dumps({"value": len(countdata),"ids": countdata}).encode())
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
utfprint("Count Tokens - Body Error: " + str(e))
|
utfprint("Count Tokens - Body Error: " + str(e))
|
||||||
|
|
|
@ -68,7 +68,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
|
generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
|
||||||
bool gpttype_generate_abort();
|
bool gpttype_generate_abort();
|
||||||
const std::string & gpttype_get_pending_output();
|
const std::string & gpttype_get_pending_output();
|
||||||
int gpttype_token_count(const std::string & input);
|
std::vector<int> gpttype_get_token_arr(const std::string & input);
|
||||||
|
|
||||||
void timer_start();
|
void timer_start();
|
||||||
double timer_check();
|
double timer_check();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue