diff --git a/expose.cpp b/expose.cpp index ed6fd4b06..febfcba45 100644 --- a/expose.cpp +++ b/expose.cpp @@ -229,6 +229,9 @@ extern "C" int get_last_token_count() { return last_token_count; } + int get_last_stop_reason() { + return (int)last_stop_reason; + } const char* get_pending_output() { return gpttype_get_pending_output().c_str(); diff --git a/expose.h b/expose.h index 4665cac97..a3114aeb0 100644 --- a/expose.h +++ b/expose.h @@ -14,6 +14,13 @@ enum samplers KCPP_SAMPLER_REP_PEN=6, KCPP_SAMPLER_MAX }; +enum stop_reason +{ + INVALID=-1, + OUT_OF_TOKENS=0, + EOS_TOKEN=1, + CUSTOM_STOPPER=2, +}; struct load_model_inputs { const int threads; @@ -76,3 +83,4 @@ extern bool generation_finished; extern float last_eval_time; extern float last_process_time; extern int last_token_count; +extern stop_reason last_stop_reason; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index bcaead4d4..058ba9351 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -36,6 +36,7 @@ bool generation_finished; float last_process_time = 0; float last_eval_time = 0; int last_token_count = 0; +stop_reason last_stop_reason = stop_reason::INVALID; std::vector generated_tokens; //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt) @@ -871,6 +872,7 @@ const std::string & gpttype_get_pending_output() generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output) { concat_output = ""; + last_stop_reason = stop_reason::OUT_OF_TOKENS; stop_sequence.clear(); for(int x=0;x)", matched.c_str()); } + last_stop_reason = stop_reason::CUSTOM_STOPPER; break; } } diff --git a/koboldcpp.py b/koboldcpp.py index d704bc653..e242bbba5 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -165,6 +165,7 @@ def init_library(): handle.get_last_eval_time.restype = ctypes.c_float handle.get_last_process_time.restype = ctypes.c_float handle.get_last_token_count.restype = ctypes.c_int + handle.get_last_stop_reason.restype = ctypes.c_int handle.abort_generate.restype = ctypes.c_bool handle.get_pending_output.restype = ctypes.c_char_p @@ -470,7 +471,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): lastp = handle.get_last_process_time() laste = handle.get_last_eval_time() lastc = handle.get_last_token_count() - response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc}).encode()) + stopreason = handle.get_last_stop_reason() + response_body = (json.dumps({"last_process":lastp,"last_eval":laste,"last_token_count":lastc, "stop_reason":stopreason}).encode()) if response_body is None: self.send_response(404)