expose some useful info that can be used in performance statistics

shutup 2023-07-07 11:52:58 +08:00
parent 3d2907d208
commit 1727e652f1
4 changed files with 20 additions and 1 deletion


@@ -220,6 +220,14 @@ extern "C"
     return generation_finished;
 }
+float get_prompt_eval_time() {
+    return prompt_eval_time;
+}
+float get_prompt_process_time() {
+    return prompt_process_time;
+}
 const char* get_pending_output() {
     return gpttype_get_pending_output().c_str();
 }
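The two accessors are plain C functions returning floats, so any ctypes client can read them once a generation has completed. A minimal sketch of calling them directly (the shared-library filename and manual loading are assumptions; the real bindings are registered in init_library() below):

    import ctypes

    # Hypothetical library path; the real loader picks the platform-specific binary.
    handle = ctypes.CDLL("./koboldcpp.so")
    handle.get_prompt_eval_time.restype = ctypes.c_float
    handle.get_prompt_process_time.restype = ctypes.c_float

    # Only meaningful after a generation has finished:
    print("prompt process time:", handle.get_prompt_process_time())
    print("prompt eval time:", handle.get_prompt_eval_time())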


@@ -54,3 +54,5 @@ extern std::string lora_filename;
 extern std::string lora_base;
 extern std::vector<std::string> generated_tokens;
 extern bool generation_finished;
+extern float prompt_eval_time;
+extern float prompt_process_time;


@@ -33,6 +33,8 @@ std::string executable_path = "";
 std::string lora_filename = "";
 std::string lora_base = "";
 bool generation_finished;
+float prompt_process_time;
+float prompt_eval_time;
 std::vector<std::string> generated_tokens;
 //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
@@ -807,6 +809,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     bool stream_sse = inputs.stream_sse;
     generation_finished = false; // Set current generation status
+    prompt_eval_time = 0;
+    prompt_process_time = 0;
     generated_tokens.clear(); // New Generation, new tokens
     if (params.repeat_last_n < 1)
@@ -1327,6 +1331,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     fflush(stdout);
     output.status = 1;
     generation_finished = true;
+    prompt_eval_time = pt2;
+    prompt_process_time = pt1;
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
     return output;
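Both globals are zeroed when gpttype_generate starts and are only assigned from pt1/pt2 after the generation loop returns, so readers should treat the values as valid only once generation_finished is true. A hedged sketch of that read pattern over the ctypes handle (the helper name and polling interval are assumptions):

    import time

    def wait_for_timings(handle, poll_interval=0.1):
        # Values are reset to 0 at generation start and are only
        # populated once has_finished() reports true.
        while not handle.has_finished():
            time.sleep(poll_interval)
        return handle.get_prompt_process_time(), handle.get_prompt_eval_time()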


@@ -151,6 +151,8 @@ def init_library():
     handle.new_token.argtypes = [ctypes.c_int]
     handle.get_stream_count.restype = ctypes.c_int
     handle.has_finished.restype = ctypes.c_bool
+    handle.get_prompt_eval_time.restype = ctypes.c_float
+    handle.get_prompt_process_time.restype = ctypes.c_float
     handle.abort_generate.restype = ctypes.c_bool
     handle.get_pending_output.restype = ctypes.c_char_p
@@ -485,7 +487,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     newprompt = fullprompt
     gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
+    gen['prompt_process_time'] = handle.get_prompt_process_time()
+    gen['prompt_eval_time'] = handle.get_prompt_eval_time()
     try:
         self.send_response(200)
         self.end_headers()
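With the two keys attached to the generation payload, an HTTP client can read the timings straight from the JSON response. A minimal sketch (the endpoint path and port are assumptions modeled on a typical KoboldAI-style API; adjust to wherever the server listens):

    import requests

    # Hypothetical endpoint/port for the local server.
    resp = requests.post("http://localhost:5001/api/v1/generate",
                         json={"prompt": "Hello", "max_length": 32})
    data = resp.json()
    print("prompt_process_time:", data.get("prompt_process_time"))
    print("prompt_eval_time:", data.get("prompt_eval_time"))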