setting up RPC + callback on each split completion

1. Start RPC servers on the local instance on two different ports, with 5 GB
   allocated to each.
2. Set up another callback, fired on completion of each split. This seems cleaner
   than trying to second-guess which tensor is the boundary of a split
   (see the wiring sketch below).
3. Run it with an 8B model at 4-bit quantization and observe that split_done is
   captured at a reasonable place.

Next step: bring back linear speculation and start speculating on other remote
instances.
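
For orientation, a condensed sketch of the wiring the steps above describe,
distilled from the duo example added below; the model path is a placeholder,
and the ports match the hard-coded list in duo.cpp:

#include "llama.h"
#include <cstdio>

// fires once per completed scheduler split (the callback added in this commit)
static void split_done_cb(int split) {
    fprintf(stderr, "split done: %d\n", split);
}

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 99;
    mparams.rpc_servers  = "localhost:50052,localhost:50051"; // the two local RPC instances

    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.cb_split_done = split_done_cb; // new llama_context_params field

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    // ... tokenize and llama_decode as in duo.cpp; split_done_cb fires during decode ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}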
Oleksandr Kuvshynov 2024-04-19 22:13:01 -04:00
parent c3f8d58356
commit d52d193e58
9 changed files with 221 additions and 2 deletions

Makefile

@@ -1,6 +1,6 @@
 # Define the default target now so that it is always the first target
 BUILD_TARGETS = \
-	main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
+	main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml duo \
 	simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli baby-llama beam-search \
 	retrieval speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
@@ -777,6 +777,10 @@ simple: examples/simple/simple.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+duo: examples/duo/duo.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 tokenize: examples/tokenize/tokenize.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
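
With this rule in place, the example builds via the default make invocation (duo is part of BUILD_TARGETS above) or standalone with make duo.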

examples/CMakeLists.txt

@@ -38,6 +38,7 @@ else()
     add_subdirectory(retrieval)
     add_subdirectory(save-load-state)
     add_subdirectory(simple)
+    add_subdirectory(duo)
     add_subdirectory(passkey)
     add_subdirectory(speculative)
     add_subdirectory(lookahead)

5
examples/duo/CMakeLists.txt Normal file

@@ -0,0 +1,5 @@
+set(TARGET duo)
+add_executable(${TARGET} duo.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

1
examples/duo/README.md Normal file

@@ -0,0 +1 @@
+## duo

184
examples/duo/duo.cpp Normal file

@@ -0,0 +1,184 @@
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+static void split_done_cb(int split)
+{
+    fprintf(stderr, "split done: %d\n", split);
+}
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH [PROMPT]\n", argv[0]);
+        return 1;
+    }
+
+    if (argc >= 2) {
+        params.model = argv[1];
+    }
+
+    if (argc >= 3) {
+        params.prompt = argv[2];
+    }
+
+    if (params.prompt.empty()) {
+        params.prompt = "Hello my name is";
+    }
+
+    llama_model_params model_params = llama_model_default_params();
+
+    model_params.n_gpu_layers = 99;
+    model_params.rpc_servers  = "localhost:50052,localhost:50051";
+
+    const int n_len = 128;
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    // initialize the context
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads       = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    ctx_params.cb_split_done   = split_done_cb;
+
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+        return 1;
+    }
+
+    std::vector<llama_token> tokens_list;
+    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
+
+    const int n_ctx    = llama_n_ctx(ctx);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
+
+    // make sure the KV cache is big enough to hold all the prompt and generated tokens
+    if (n_kv_req > n_ctx) {
+        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: either reduce n_len or increase n_ctx\n", __func__);
+        return 1;
+    }
+
+    // print the prompt token-by-token
+    for (auto id : tokens_list) {
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+    }
+    fflush(stderr);
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+
+    // evaluate the initial prompt
+    for (size_t i = 0; i < tokens_list.size(); i++) {
+        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    // main loop
+    int n_cur    = batch.n_tokens;
+    int n_decode = 0;
+
+    const auto t_main_start = ggml_time_us();
+
+    // we'll use logits from this position to determine next token
+    int logit_idx = batch.n_tokens - 1;
+
+    while (n_cur <= n_len) {
+        // sample the next token
+        {
+            auto   n_vocab = llama_n_vocab(model);
+            auto * logits  = llama_get_logits_ith(ctx, logit_idx);
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            // sample the most likely token
+            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of generation?
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+                LOG_TEE("\n");
+                break;
+            }
+
+            LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            fflush(stdout);
+
+            // prepare the next batch
+            llama_batch_clear(batch);
+
+            // push this new token for next evaluation
+            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+
+            // we still use the 'original' token to sample on next iteration
+            logit_idx = batch.n_tokens - 1;
+
+            n_decode += 1;
+        }
+
+        n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+
+        // remove the cached entries from mock tokens
+        llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
+    }
+
+    LOG_TEE("\n");
+
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    //llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
+    llama_batch_free(batch);
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    return 0;
+}
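
As a quick smoke test (step 3 of the commit message): running ./duo <model.gguf> "Hello" should stream the prompt and its greedy continuation, with stderr interleaving one "split done: N" line per scheduler split; with the two RPC backends configured above, those lines should appear where the schedule hands work from one instance to the other.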

ggml-backend.c

@@ -1075,6 +1075,8 @@ struct ggml_backend_sched {
     ggml_backend_sched_eval_callback callback_eval;
     void * callback_eval_user_data;
 
+    ggml_backend_sched_split_done_callback callback_split_done;
+
     // align context_buffer to GGML_MEM_ALIGN
 #ifdef _MSC_VER
     __declspec(align(GGML_MEM_ALIGN))
@@ -1708,6 +1710,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
                 ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]);
             }
         }
+
+        // split finished
+        if (sched->callback_split_done) {
+            sched->callback_split_done(i);
+        }
     }
 
     sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
@@ -1856,6 +1863,10 @@ void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
     sched->callback_eval_user_data = user_data;
 }
 
+void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback) {
+    sched->callback_split_done = callback;
+}
+
 int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
     return sched->n_splits;
 }

ggml-backend.h

@@ -175,6 +175,10 @@ extern "C" {
     //
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
+    // if set, will be called when a split has completed computation;
+    // useful for distributed task orchestration
+    typedef void (*ggml_backend_sched_split_done_callback)(int split);
+
     // Initialize a backend scheduler
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
     GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
@@ -203,6 +207,9 @@ extern "C" {
     // Set a callback to be called for each resulting node during graph compute
     GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
 
+    // Set a callback to be called after each completed split during graph compute
+    GGML_API void ggml_backend_sched_set_split_done_callback(ggml_backend_sched_t sched, ggml_backend_sched_split_done_callback callback);
+
     //
     // Utils
     //

llama.cpp

@@ -1861,6 +1861,8 @@ struct llama_cparams {
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
 
+    ggml_backend_sched_split_done_callback cb_split_done;
+
 };
 
 struct llama_layer {
@@ -11254,6 +11256,7 @@ static int llama_decode_internal(
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+    ggml_backend_sched_set_split_done_callback(lctx.sched, lctx.cparams.cb_split_done);
 
     ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
@@ -15192,6 +15195,7 @@ struct llama_context_params llama_context_default_params() {
         /*.defrag_thold      =*/ -1.0f,
         /*.cb_eval           =*/ nullptr,
         /*.cb_eval_user_data =*/ nullptr,
+        /*.cb_split_done     =*/ nullptr,
         /*.type_k            =*/ GGML_TYPE_F16,
         /*.type_v            =*/ GGML_TYPE_F16,
         /*.logits_all        =*/ false,
@@ -15403,6 +15407,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
+    cparams.cb_split_done     = params.cb_split_done;
 
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
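
Taken together, the plumbing is one-directional: the caller sets cb_split_done on llama_context_params, llama_new_context_with_model copies it into the context's cparams, and llama_decode_internal hands it to the backend scheduler before each graph compute, where ggml_backend_sched_compute_splits invokes it once per finished split.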

llama.h

@@ -289,7 +289,8 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
+        ggml_backend_sched_split_done_callback cb_split_done;
 
         enum ggml_type type_k; // data type for K cache
         enum ggml_type type_v; // data type for V cache