server : add server_task_type field to server_task_result
This commit adds a server_task_type field to the server_task_result struct, together with a virtual get_server_task_type() accessor that each result type implements. They are used to identify the type of server task that produced a result. The motivation is that this might allow us to avoid using dynamic_casts when checking the concrete type of a server_task_result. For example, the dynamic_cast assertions in the slot save/erase handler could then be replaced with checks like this:
```c++
GGML_ASSERT(result.get() != nullptr);
GGML_ASSERT(result.get()->get_server_task_type() == type);
```
This commit is contained in:
parent
2d2d07618e
commit
6cc2956f3f
1 changed files with 46 additions and 6 deletions
|
@ -68,6 +68,7 @@ enum server_task_type {
|
|||
SERVER_TASK_TYPE_SLOT_RESTORE,
|
||||
SERVER_TASK_TYPE_SLOT_ERASE,
|
||||
SERVER_TASK_TYPE_SET_LORA,
|
||||
SERVER_TASK_TYPE_NONE,
|
||||
};
|
||||
|
||||
enum oaicompat_type {
|
||||
|
@ -480,6 +481,7 @@ struct result_timings {
|
|||
struct server_task_result {
|
||||
int id = -1;
|
||||
int id_slot = -1;
|
||||
server_task_type type;
|
||||
virtual bool is_error() {
|
||||
// only used by server_task_result_error
|
||||
return false;
|
||||
|
@ -491,6 +493,7 @@ struct server_task_result {
|
|||
virtual int get_index() {
|
||||
return -1;
|
||||
}
|
||||
virtual server_task_type get_server_task_type() = 0;
|
||||
virtual json to_json() = 0;
|
||||
virtual ~server_task_result() = default;
|
||||
};
|
||||
|
@ -794,6 +797,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
return ret;
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_COMPLETION;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_partial : server_task_result {
|
||||
|
@ -962,6 +969,10 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
|
||||
return std::vector<json>({ret});
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_NONE;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_embd : server_task_result {
|
||||
|
@ -997,6 +1008,10 @@ struct server_task_result_embd : server_task_result {
|
|||
{"tokens_evaluated", n_tokens},
|
||||
};
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_EMBEDDING;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_rerank : server_task_result {
|
||||
|
@ -1016,6 +1031,10 @@ struct server_task_result_rerank : server_task_result {
|
|||
{"tokens_evaluated", n_tokens},
|
||||
};
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_RERANK;
|
||||
}
|
||||
};
|
||||
|
||||
// this function may be used outside of server_task_result_error
|
||||
|
@ -1071,6 +1090,10 @@ struct server_task_result_error : server_task_result {
|
|||
virtual json to_json() override {
|
||||
return format_error_response(err_msg, err_type);
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_NONE;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_metrics : server_task_result {
|
||||
|
@ -1127,6 +1150,10 @@ struct server_task_result_metrics : server_task_result {
|
|||
{ "slots", slots_data },
|
||||
};
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_METRICS;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_slot_save_load : server_task_result {
|
||||
|
@ -1160,6 +1187,10 @@ struct server_task_result_slot_save_load : server_task_result {
|
|||
};
|
||||
}
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_SLOT_SAVE;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_slot_erase : server_task_result {
|
||||
|
@ -1171,12 +1202,20 @@ struct server_task_result_slot_erase : server_task_result {
|
|||
{ "n_erased", n_erased },
|
||||
};
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_SLOT_ERASE;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_apply_lora : server_task_result {
|
||||
virtual json to_json() override {
|
||||
return json {{ "success", true }};
|
||||
}
|
||||
|
||||
server_task_type get_server_task_type() {
|
||||
return SERVER_TASK_TYPE_NONE;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_slot {
|
||||
|
@ -2751,6 +2790,11 @@ struct server_context {
|
|||
res->id = task.id;
|
||||
queue_results.send(std::move(res));
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_NONE:
|
||||
{
|
||||
// SERVER_TASK_TYPE_NONE is not a valid task type to process; abort if it is ever reached
|
||||
GGML_ASSERT(false && "Invalid task.type (SERVER_TASK_TYPE_NONE)\n");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3693,12 +3737,8 @@ int main(int argc, char ** argv) {
|
|||
res_error(res, result->to_json());
|
||||
return;
|
||||
}
|
||||
|
||||
if (type == SERVER_TASK_TYPE_SLOT_SAVE) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
|
||||
} else if (type == SERVER_TASK_TYPE_SLOT_ERASE) {
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
|
||||
}
|
||||
GGML_ASSERT(result.get() != nullptr);
|
||||
GGML_ASSERT(result.get()->get_server_task_type() == type);
|
||||
res_ok(res, result->to_json());
|
||||
};
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue