sync : migrate examples and llama.cpp to dynamic graphs (wip)

This commit is contained in:
Georgi Gerganov 2023-11-02 18:29:10 +02:00
parent aa7a2c445d
commit e8190705fb
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
10 changed files with 50 additions and 41 deletions

View file

@ -9,6 +9,8 @@
#include "ggml.h"
#include "llama.h"
#define LLAMA_TRAIN_MAX_NODES 16384
typedef std::string mt19937_state;
struct train_state {

View file

@ -171,7 +171,8 @@ int main(int argc, char ** argv) {
struct ggml_tensor * m11xm2 = ggml_mul_mat(ctx, m11, m2);
// printf("Creating compute graph\n");
struct ggml_cgraph gf = ggml_build_forward(m11xm2);
struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, m11xm2);
printf("n_threads=%i\n", benchmark_params.n_threads);
@ -180,9 +181,9 @@ int main(int argc, char ** argv) {
std::vector<uint8_t> work_buffer;
ggml_graph_compute_helper(work_buffer, &gf, benchmark_params.n_threads);
ggml_graph_compute_helper(work_buffer, gf, benchmark_params.n_threads);
TENSOR_DUMP(gf.nodes[0]);
TENSOR_DUMP(gf->nodes[0]);
printf("\n------ Test 2 - Matrix Mult via %s code\n", ggml_type_name(qtype));
@ -200,7 +201,8 @@ int main(int argc, char ** argv) {
struct ggml_tensor * q31 = ggml_mul_mat(ctx, q11, m2);
// printf("Creating compute graph\n");
struct ggml_cgraph gf31 = ggml_build_forward(q31);
struct ggml_cgraph * gf31 = ggml_new_graph(ctx);
ggml_build_forward_expand(gf31, q31);
// Set up a second graph computation to make sure we override the CPU cache lines
// printf("Creating new tensor q12 & Running quantize\n");
@ -211,7 +213,8 @@ int main(int argc, char ** argv) {
struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2);
//printf("Creating compute graph\n");
struct ggml_cgraph gf32 = ggml_build_forward(q32);
struct ggml_cgraph * gf32 = ggml_new_graph(ctx);
ggml_build_forward_expand(gf32, q32);
printf("n_threads=%i\n", benchmark_params.n_threads);
const int dimx = sizex;
@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
// Let's use the F32 result from above as a reference for the quantized multiplication
float sum_of_F32_reference = tensor_sum_elements(gf.nodes[0]);
float sum_of_F32_reference = tensor_sum_elements(gf->nodes[0]);
printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; gigaFLOPS\n");
printf("=====================================================================================\n");
@ -233,7 +236,7 @@ int main(int argc, char ** argv) {
long long int start = ggml_time_us();
//printf("Running ggml_graph_compute\n");
ggml_graph_compute_helper(work_buffer, &gf31, benchmark_params.n_threads);
ggml_graph_compute_helper(work_buffer, gf31, benchmark_params.n_threads);
long long int stop = ggml_time_us();
long long int usec = stop-start;
@ -251,7 +254,7 @@ int main(int argc, char ** argv) {
// Check that the matrix multiplication result is in the right ballpark
// We cannot use the exact value from the F32 multiplication because the quantizuation will be slightly different
float sum_of_Q4_result = tensor_sum_elements(gf31.nodes[0]);
float sum_of_Q4_result = tensor_sum_elements(gf31->nodes[0]);
float delta = std::abs(sum_of_Q4_result - sum_of_F32_reference);
float allowed_delta = (sum_of_F32_reference) / 1000 / 1000; // Let's accept an epsilon of 10^-6
@ -266,7 +269,7 @@ int main(int argc, char ** argv) {
}
// Running a different graph computation to make sure we override the CPU cache lines
ggml_graph_compute_helper(work_buffer, &gf32, benchmark_params.n_threads);
ggml_graph_compute_helper(work_buffer, gf32, benchmark_params.n_threads);
}
printf("\n");
printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));

View file

@ -240,7 +240,7 @@ static struct lora_data * load_lora(struct lora_info * info) {
}
struct ggml_init_params params_ggml;
params_ggml.mem_size = ggml_tensor_overhead() * GGML_MAX_NODES;
params_ggml.mem_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
params_ggml.mem_buffer = NULL;
params_ggml.no_alloc = true;
result->ctx = ggml_init(params_ggml);
@ -334,7 +334,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;
struct ggml_init_params params;
params.mem_size = GGML_OBJECT_SIZE + GGML_GRAPH_SIZE + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
params.mem_size = GGML_OBJECT_SIZE + ggml_graph_overhead() + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
params.mem_buffer = NULL;
params.no_alloc = true;
struct ggml_context * ctx = NULL;

View file

@ -1742,8 +1742,8 @@ int main(int argc, char ** argv) {
// context for compute tensors without their data
size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2
+ (GGML_OBJECT_SIZE+ggml_graph_overhead())*(
params.common.use_checkpointing ? 3 : 2
)
);

View file

@ -664,7 +664,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
// measure mem requirement and allocate
{
static const size_t tensor_alignment = 32;
new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
new_clip->alloc = ggml_allocr_new_measure(tensor_alignment);
clip_image_f32_batch batch;
batch.size = 1;

View file

@ -34,7 +34,7 @@ int main(int argc, char ** argv) {
struct ggml_context * ctx_data = NULL;
struct ggml_context * ctx_eval = NULL;
struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
struct ggml_cgraph * gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
// this allocates all Metal resources and memory buffers
auto * ctx_metal = ggml_metal_init(1);
@ -46,13 +46,13 @@ int main(int argc, char ** argv) {
// main
{
struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
*(int32_t *) input->data = 1; // BOS
ggml_metal_set_tensor(ctx_metal, input);
// warmup
ggml_metal_graph_compute(ctx_metal, &gf);
ggml_metal_graph_compute(ctx_metal, gf);
const int n_iter = 16;
@ -60,7 +60,7 @@ int main(int argc, char ** argv) {
// the actual inference happens here
for (int i = 0; i < n_iter; ++i) {
ggml_metal_graph_compute(ctx_metal, &gf);
ggml_metal_graph_compute(ctx_metal, gf);
}
const int64_t t1 = ggml_time_us();
@ -70,7 +70,7 @@ int main(int argc, char ** argv) {
// debug output
{
struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1];
struct ggml_tensor * logits = gf->nodes[gf->n_nodes - 1];
ggml_metal_get_tensor(ctx_metal, logits);
float * ptr = (float *) ggml_get_data(logits);

View file

@ -1109,8 +1109,8 @@ int main(int argc, char ** argv) {
// context for compute tensors without their data
size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
ggml_tensor_overhead()*LLAMA_TRAIN_MAX_NODES*2
+ (GGML_OBJECT_SIZE+ggml_graph_overhead())*(
params.common.use_checkpointing ? 3 : 2
)
);

View file

@ -1,5 +1,6 @@
#import "ggml-metal.h"
#import "ggml-backend-impl.h"
#import "ggml.h"
#import <Foundation/Foundation.h>
@ -23,7 +24,7 @@
#define UNUSED(x) (void)(x)
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
#define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
struct ggml_metal_buffer {
const char * name;
@ -744,6 +745,20 @@ void ggml_metal_graph_compute(
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
struct ggml_tensor * dst = gf->nodes[i];
switch (dst->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
{
// noop -> next node
} continue;
default:
{
} break;
}
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0;
@ -797,14 +812,6 @@ void ggml_metal_graph_compute(
//}
switch (dst->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
{
// noop
} break;
case GGML_OP_CONCAT:
{
const int64_t nb = ne00;

3
ggml.h
View file

@ -1733,9 +1733,6 @@ extern "C" {
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
// graph allocation in a context
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);

View file

@ -8559,7 +8559,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const size_t elt_size = ggml_element_size(kv_self.k);
ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_cgraph gf{};
ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
@ -8577,9 +8577,9 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
@ -8687,7 +8687,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
const size_t elt_size = ggml_element_size(kv_self.k);
ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_cgraph gf{};
ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
kin3d->data = (void *) inp;
@ -8705,9 +8705,9 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
}