server : add server_task_type field to server_task_result

This commit adds a server_task_type field to the server_task_result
struct. This field is used to identify the type of the server task.

The motivation for adding this is that it might allow us to avoid using
dynamic_cast's when checking the type of the server_task_result.
For example, this could then be replaced with checks like this:
```c++
        GGML_ASSERT(result.get() != nullptr);
        GGML_ASSERT(result.get()->get_server_task_type() == type);
```
This commit is contained in:
Daniel Bevenius 2025-01-31 14:14:39 +01:00
parent 2d2d07618e
commit 6cc2956f3f

View file

@ -68,6 +68,7 @@ enum server_task_type {
SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE, SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_SET_LORA, SERVER_TASK_TYPE_SET_LORA,
SERVER_TASK_TYPE_NONE,
}; };
enum oaicompat_type { enum oaicompat_type {
@ -480,6 +481,7 @@ struct result_timings {
struct server_task_result { struct server_task_result {
int id = -1; int id = -1;
int id_slot = -1; int id_slot = -1;
server_task_type type;
virtual bool is_error() { virtual bool is_error() {
// only used by server_task_result_error // only used by server_task_result_error
return false; return false;
@ -491,6 +493,7 @@ struct server_task_result {
virtual int get_index() { virtual int get_index() {
return -1; return -1;
} }
virtual server_task_type get_server_task_type() = 0;
virtual json to_json() = 0; virtual json to_json() = 0;
virtual ~server_task_result() = default; virtual ~server_task_result() = default;
}; };
@ -794,6 +797,10 @@ struct server_task_result_cmpl_final : server_task_result {
return ret; return ret;
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_COMPLETION;
}
}; };
struct server_task_result_cmpl_partial : server_task_result { struct server_task_result_cmpl_partial : server_task_result {
@ -962,6 +969,10 @@ struct server_task_result_cmpl_partial : server_task_result {
return std::vector<json>({ret}); return std::vector<json>({ret});
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_NONE;
}
}; };
struct server_task_result_embd : server_task_result { struct server_task_result_embd : server_task_result {
@ -997,6 +1008,10 @@ struct server_task_result_embd : server_task_result {
{"tokens_evaluated", n_tokens}, {"tokens_evaluated", n_tokens},
}; };
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_EMBEDDING;
}
}; };
struct server_task_result_rerank : server_task_result { struct server_task_result_rerank : server_task_result {
@ -1016,6 +1031,10 @@ struct server_task_result_rerank : server_task_result {
{"tokens_evaluated", n_tokens}, {"tokens_evaluated", n_tokens},
}; };
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_RERANK;
}
}; };
// this function maybe used outside of server_task_result_error // this function maybe used outside of server_task_result_error
@ -1071,6 +1090,10 @@ struct server_task_result_error : server_task_result {
virtual json to_json() override { virtual json to_json() override {
return format_error_response(err_msg, err_type); return format_error_response(err_msg, err_type);
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_NONE;
}
}; };
struct server_task_result_metrics : server_task_result { struct server_task_result_metrics : server_task_result {
@ -1127,6 +1150,10 @@ struct server_task_result_metrics : server_task_result {
{ "slots", slots_data }, { "slots", slots_data },
}; };
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_METRICS;
}
}; };
struct server_task_result_slot_save_load : server_task_result { struct server_task_result_slot_save_load : server_task_result {
@ -1160,6 +1187,10 @@ struct server_task_result_slot_save_load : server_task_result {
}; };
} }
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_SLOT_SAVE;
}
}; };
struct server_task_result_slot_erase : server_task_result { struct server_task_result_slot_erase : server_task_result {
@ -1171,12 +1202,20 @@ struct server_task_result_slot_erase : server_task_result {
{ "n_erased", n_erased }, { "n_erased", n_erased },
}; };
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_SLOT_ERASE;
}
}; };
struct server_task_result_apply_lora : server_task_result { struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override { virtual json to_json() override {
return json {{ "success", true }}; return json {{ "success", true }};
} }
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_NONE;
}
}; };
struct server_slot { struct server_slot {
@ -2751,6 +2790,11 @@ struct server_context {
res->id = task.id; res->id = task.id;
queue_results.send(std::move(res)); queue_results.send(std::move(res));
} break; } break;
case SERVER_TASK_TYPE_NONE:
{
// do nothing
GGML_ASSERT(false && "Invalid task.type (SERVER_TASK_TYPE_NONE)\n");
} break;
} }
} }
@ -3693,12 +3737,8 @@ int main(int argc, char ** argv) {
res_error(res, result->to_json()); res_error(res, result->to_json());
return; return;
} }
GGML_ASSERT(result.get() != nullptr);
if (type == SERVER_TASK_TYPE_SLOT_SAVE) { GGML_ASSERT(result.get()->get_server_task_type() == type);
GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
} else if (type == SERVER_TASK_TYPE_SLOT_ERASE) {
GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
}
res_ok(res, result->to_json()); res_ok(res, result->to_json());
}; };