server : add server_task_type field to server_task_result

This commit adds a server_task_type field to the server_task_result
struct. This field is used to identify the type of the server task.

The motivation for adding this is that it might allow us to avoid using
dynamic_cast's when checking the type of the server_task_result.
For example, this could then be replaced with checks like this:
```c++
        GGML_ASSERT(result.get() != nullptr);
        GGML_ASSERT(result.get()->get_server_task_type() == type);
```
This commit is contained in:
Daniel Bevenius 2025-01-31 14:14:39 +01:00
parent 2d2d07618e
commit 6cc2956f3f

View file

@ -68,6 +68,7 @@ enum server_task_type {
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_SET_LORA,
SERVER_TASK_TYPE_NONE,
};
enum oaicompat_type {
@ -480,6 +481,7 @@ struct result_timings {
struct server_task_result {
int id = -1;
int id_slot = -1;
server_task_type type;
virtual bool is_error() {
// only used by server_task_result_error
return false;
@ -491,6 +493,7 @@ struct server_task_result {
virtual int get_index() {
return -1;
}
virtual server_task_type get_server_task_type() = 0;
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
@ -794,6 +797,10 @@ struct server_task_result_cmpl_final : server_task_result {
return ret;
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_COMPLETION;
}
};
struct server_task_result_cmpl_partial : server_task_result {
@ -962,6 +969,10 @@ struct server_task_result_cmpl_partial : server_task_result {
return std::vector<json>({ret});
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_NONE;
}
};
struct server_task_result_embd : server_task_result {
@ -997,6 +1008,10 @@ struct server_task_result_embd : server_task_result {
{"tokens_evaluated", n_tokens},
};
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_EMBEDDING;
}
};
struct server_task_result_rerank : server_task_result {
@ -1016,6 +1031,10 @@ struct server_task_result_rerank : server_task_result {
{"tokens_evaluated", n_tokens},
};
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_RERANK;
}
};
// this function maybe used outside of server_task_result_error
@ -1071,6 +1090,10 @@ struct server_task_result_error : server_task_result {
virtual json to_json() override {
return format_error_response(err_msg, err_type);
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_NONE;
}
};
struct server_task_result_metrics : server_task_result {
@ -1127,6 +1150,10 @@ struct server_task_result_metrics : server_task_result {
{ "slots", slots_data },
};
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_METRICS;
}
};
struct server_task_result_slot_save_load : server_task_result {
@ -1160,6 +1187,10 @@ struct server_task_result_slot_save_load : server_task_result {
};
}
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_SLOT_SAVE;
}
};
struct server_task_result_slot_erase : server_task_result {
@ -1171,12 +1202,20 @@ struct server_task_result_slot_erase : server_task_result {
{ "n_erased", n_erased },
};
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_SLOT_ERASE;
}
};
struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override {
return json {{ "success", true }};
}
server_task_type get_server_task_type() {
return SERVER_TASK_TYPE_NONE;
}
};
struct server_slot {
@ -2751,6 +2790,11 @@ struct server_context {
res->id = task.id;
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_NONE:
{
// do nothing
GGML_ASSERT(false && "Invalid task.type (SERVER_TASK_TYPE_NONE)\n");
} break;
}
}
@ -3693,12 +3737,8 @@ int main(int argc, char ** argv) {
res_error(res, result->to_json());
return;
}
if (type == SERVER_TASK_TYPE_SLOT_SAVE) {
GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
} else if (type == SERVER_TASK_TYPE_SLOT_ERASE) {
GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
}
GGML_ASSERT(result.get() != nullptr);
GGML_ASSERT(result.get()->get_server_task_type() == type);
res_ok(res, result->to_json());
};