llama : factor out plan stuff into a helper function

This commit is contained in:
Georgi Gerganov 2023-07-06 21:12:25 +03:00
parent a67404e749
commit 2d3a5252f9
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -79,6 +79,25 @@ void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
(void) tensor; (void) tensor;
} }
//
// ggml helpers
//
void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
if (plan.work_size > 0) {
buf.resize(plan.work_size);
plan.work_data = buf.data();
}
ggml_graph_compute(graph, &plan);
}
//
// memory sizes
//
static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0() static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
{ {
static std::map<e_model, size_t> k_sizes = { static std::map<e_model, size_t> k_sizes = {
@ -761,7 +780,6 @@ struct llama_model_loader {
}; };
// //
// kv cache // kv cache
// //
@ -1623,12 +1641,7 @@ static bool llama_eval_internal(
#endif #endif
if (call_ggml_graph_compute) { if (call_ggml_graph_compute) {
ggml_cplan pf = ggml_graph_plan(&gf, n_threads); ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
if (pf.work_size > 0) {
lctx.work_buffer.resize(pf.work_size);
pf.work_data = lctx.work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
} }
if (cgraph_fname) { if (cgraph_fname) {
@ -2983,14 +2996,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
struct ggml_cgraph gf = ggml_build_forward(r); struct ggml_cgraph gf = ggml_build_forward(r);
{ ggml_graph_compute_helper(work_buffer, &gf, n_threads);
ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
if (pf.work_size > 0) {
work_buffer.resize(pf.work_size);
pf.work_data = work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
}
// we won't need these tensors again, reset the context to save memory // we won't need these tensors again, reset the context to save memory
ggml_free(lora_ctx); ggml_free(lora_ctx);
@ -3162,15 +3168,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
{
ggml_cplan pf = ggml_graph_plan(&gf, /*n_threads*/ 1);
if (pf.work_size > 0) {
ctx->work_buffer.resize(pf.work_size);
pf.work_data = ctx->work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
}
ggml_free(cpy_ctx); ggml_free(cpy_ctx);
} }
@ -3275,15 +3273,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
{
ggml_cplan pf = ggml_graph_plan(&gf, /*n_threads*/ 1);
if (pf.work_size > 0) {
ctx->work_buffer.resize(pf.work_size);
pf.work_data = ctx->work_buffer.data();
}
ggml_graph_compute(&gf, &pf);
}
ggml_free(cpy_ctx); ggml_free(cpy_ctx);
} }