server : add server_task_type field to server_task_result
This commit adds a server_task_type field to the server_task_result struct. This field is used to identify the type of the server task. The motivation for adding this is that it might allow us to avoid using dynamic_cast's when checking the type of the server_task_result. For example, this could then be replaced with checks like this: ```c++ GGML_ASSERT(result.get() != nullptr); GGML_ASSERT(result.get()->get_server_task_type() == type); ```
This commit is contained in:
parent
2d2d07618e
commit
6cc2956f3f
1 changed files with 46 additions and 6 deletions
|
@ -68,6 +68,7 @@ enum server_task_type {
|
||||||
SERVER_TASK_TYPE_SLOT_RESTORE,
|
SERVER_TASK_TYPE_SLOT_RESTORE,
|
||||||
SERVER_TASK_TYPE_SLOT_ERASE,
|
SERVER_TASK_TYPE_SLOT_ERASE,
|
||||||
SERVER_TASK_TYPE_SET_LORA,
|
SERVER_TASK_TYPE_SET_LORA,
|
||||||
|
SERVER_TASK_TYPE_NONE,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum oaicompat_type {
|
enum oaicompat_type {
|
||||||
|
@ -480,6 +481,7 @@ struct result_timings {
|
||||||
struct server_task_result {
|
struct server_task_result {
|
||||||
int id = -1;
|
int id = -1;
|
||||||
int id_slot = -1;
|
int id_slot = -1;
|
||||||
|
server_task_type type;
|
||||||
virtual bool is_error() {
|
virtual bool is_error() {
|
||||||
// only used by server_task_result_error
|
// only used by server_task_result_error
|
||||||
return false;
|
return false;
|
||||||
|
@ -491,6 +493,7 @@ struct server_task_result {
|
||||||
virtual int get_index() {
|
virtual int get_index() {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
virtual server_task_type get_server_task_type() = 0;
|
||||||
virtual json to_json() = 0;
|
virtual json to_json() = 0;
|
||||||
virtual ~server_task_result() = default;
|
virtual ~server_task_result() = default;
|
||||||
};
|
};
|
||||||
|
@ -794,6 +797,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_COMPLETION;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_cmpl_partial : server_task_result {
|
struct server_task_result_cmpl_partial : server_task_result {
|
||||||
|
@ -962,6 +969,10 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
|
|
||||||
return std::vector<json>({ret});
|
return std::vector<json>({ret});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_NONE;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_embd : server_task_result {
|
struct server_task_result_embd : server_task_result {
|
||||||
|
@ -997,6 +1008,10 @@ struct server_task_result_embd : server_task_result {
|
||||||
{"tokens_evaluated", n_tokens},
|
{"tokens_evaluated", n_tokens},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_EMBEDDING;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_rerank : server_task_result {
|
struct server_task_result_rerank : server_task_result {
|
||||||
|
@ -1016,6 +1031,10 @@ struct server_task_result_rerank : server_task_result {
|
||||||
{"tokens_evaluated", n_tokens},
|
{"tokens_evaluated", n_tokens},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_RERANK;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// this function maybe used outside of server_task_result_error
|
// this function maybe used outside of server_task_result_error
|
||||||
|
@ -1071,6 +1090,10 @@ struct server_task_result_error : server_task_result {
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
return format_error_response(err_msg, err_type);
|
return format_error_response(err_msg, err_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_NONE;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_metrics : server_task_result {
|
struct server_task_result_metrics : server_task_result {
|
||||||
|
@ -1127,6 +1150,10 @@ struct server_task_result_metrics : server_task_result {
|
||||||
{ "slots", slots_data },
|
{ "slots", slots_data },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_METRICS;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_slot_save_load : server_task_result {
|
struct server_task_result_slot_save_load : server_task_result {
|
||||||
|
@ -1160,6 +1187,10 @@ struct server_task_result_slot_save_load : server_task_result {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_SLOT_SAVE;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_slot_erase : server_task_result {
|
struct server_task_result_slot_erase : server_task_result {
|
||||||
|
@ -1171,12 +1202,20 @@ struct server_task_result_slot_erase : server_task_result {
|
||||||
{ "n_erased", n_erased },
|
{ "n_erased", n_erased },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_SLOT_ERASE;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_task_result_apply_lora : server_task_result {
|
struct server_task_result_apply_lora : server_task_result {
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
return json {{ "success", true }};
|
return json {{ "success", true }};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
server_task_type get_server_task_type() {
|
||||||
|
return SERVER_TASK_TYPE_NONE;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_slot {
|
struct server_slot {
|
||||||
|
@ -2751,6 +2790,11 @@ struct server_context {
|
||||||
res->id = task.id;
|
res->id = task.id;
|
||||||
queue_results.send(std::move(res));
|
queue_results.send(std::move(res));
|
||||||
} break;
|
} break;
|
||||||
|
case SERVER_TASK_TYPE_NONE:
|
||||||
|
{
|
||||||
|
// do nothing
|
||||||
|
GGML_ASSERT(false && "Invalid task.type (SERVER_TASK_TYPE_NONE)\n");
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3693,12 +3737,8 @@ int main(int argc, char ** argv) {
|
||||||
res_error(res, result->to_json());
|
res_error(res, result->to_json());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
GGML_ASSERT(result.get() != nullptr);
|
||||||
if (type == SERVER_TASK_TYPE_SLOT_SAVE) {
|
GGML_ASSERT(result.get()->get_server_task_type() == type);
|
||||||
GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
|
|
||||||
} else if (type == SERVER_TASK_TYPE_SLOT_ERASE) {
|
|
||||||
GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
|
|
||||||
}
|
|
||||||
res_ok(res, result->to_json());
|
res_ok(res, result->to_json());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue