From 1fcdcc28b119a6608774d52de905931bd5f8a43d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 25 May 2023 23:07:29 +0200 Subject: [PATCH 1/3] cuda : performance optimizations (#1530) * xor hack * block y dim * loop unrolling * Fixed cmake LLAMA_CUDA_BY option * Removed hipblas compatibility code * Define GGML_CUDA_DMMV_BLOCK_Y if not defined * Fewer iters, more ops per iter * Renamed DMMV X/Y compilation options --- CMakeLists.txt | 52 ++++++++++++----------- Makefile | 12 +++++- ggml-cuda.cu | 110 +++++++++++++++++++++++++++++++------------------ 3 files changed, 110 insertions(+), 64 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 39db2e3fc..31c5bd91d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,42 +37,44 @@ endif() # # general -option(LLAMA_STATIC "llama: static link libraries" OFF) -option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) -option(LLAMA_LTO "llama: enable link time optimization" OFF) +option(LLAMA_STATIC "llama: static link libraries" OFF) +option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) +option(LLAMA_LTO "llama: enable link time optimization" OFF) # debug -option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) -option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) -option(LLAMA_GPROF "llama: enable gprof" OFF) +option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) +option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) +option(LLAMA_GPROF "llama: enable gprof" OFF) # sanitizers -option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) -option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) -option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) +option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) +option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) +option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) # instruction set specific -option(LLAMA_AVX "llama: enable AVX" ON) -option(LLAMA_AVX2 "llama: enable AVX2" ON) -option(LLAMA_AVX512 "llama: enable AVX512" OFF) -option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) -option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) -option(LLAMA_FMA "llama: enable FMA" ON) +option(LLAMA_AVX "llama: enable AVX" ON) +option(LLAMA_AVX2 "llama: enable AVX2" ON) +option(LLAMA_AVX512 "llama: enable AVX512" OFF) +option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_FMA "llama: enable FMA" ON) # in MSVC F16C is implied with AVX2/AVX512 if (NOT MSVC) - option(LLAMA_F16C "llama: enable F16C" ON) + option(LLAMA_F16C "llama: enable F16C" ON) endif() # 3rd party libs -option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) -option(LLAMA_BLAS "llama: use BLAS" OFF) -option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) -option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) -option(LLAMA_CLBLAST "llama: use CLBlast" OFF) +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_BLAS "llama: use BLAS" OFF) +option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) +option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") 
+set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
+option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
-option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_SERVER "llama: build server example" OFF)
+option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_SERVER "llama: build server example" OFF)
 #
 # Build info header
@@ -184,6 +186,8 @@ if (LLAMA_CUBLAS)
     set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
     add_compile_definitions(GGML_USE_CUBLAS)
+    add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+    add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
     if (LLAMA_STATIC)
         set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
diff --git a/Makefile b/Makefile
index 08e250314..804307b53 100644
--- a/Makefile
+++ b/Makefile
@@ -133,9 +133,19 @@ ifdef LLAMA_CUBLAS
 OBJS += ggml-cuda.o
 NVCC = nvcc
 NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
+ifdef LLAMA_CUDA_DMMV_X
+ NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+else
+ NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
+endif # LLAMA_CUDA_DMMV_X
+ifdef LLAMA_CUDA_DMMV_Y
+ NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
+else
+ NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
+endif # LLAMA_CUDA_DMMV_Y
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
-endif
+endif # LLAMA_CUBLAS
 ifdef LLAMA_CLBLAST
 CFLAGS += -DGGML_USE_CLBLAST
 CXXFLAGS += -DGGML_USE_CLBLAST
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 35d2e457c..98170a3ae 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
+#define WARP_SIZE 32
+
 #define CUDA_MUL_BLOCK_SIZE 256
+
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+// dmmv = dequantize_mul_mat_vec
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
+#endif
+#ifndef GGML_CUDA_DMMV_Y
+#define GGML_CUDA_DMMV_Y 1
+#endif
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
-template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-    const int row = blockIdx.x;
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ?
1 : qk/2;
-    __shared__ float tmp[block_size]; // separate sum for each thread
-    tmp[tid] = 0;
+    float tmp = 0; // partial sum for thread in warp
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }
     // sum up partial sums and write back result
     __syncthreads();
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
+
     if (tid == 0) {
-        dst[row] = tmp[0];
+        dst[row] = tmp;
     }
 }
@@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 }
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx,
y, dst, ncols);
 }
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 1, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {

From 66874d4fbcc7866377246efbcee938e8cc9c7d76 Mon Sep 17 00:00:00 2001
From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
Date: Thu, 25 May 2023 20:18:01 -0600
Subject: [PATCH 2/3] Some improvements to loading the session with --prompt-cache (#1550)

Improvements to loading the session with `--prompt-cache` in the `main` example.

1. Fix an issue where the `--seed` parameter was ignored when loading a cached prompt.
2. When loading a cached prompt, you previously had to specify the saved prompt (or a prefix of it) again. This pull changes that behavior to default to the prompt that was cached if a prompt wasn't specified by the user.
---
 examples/main/README.md | 2 +-
 examples/main/main.cpp | 18 ++++++++++++++----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/examples/main/README.md b/examples/main/README.md
index 7c03f92c8..e71ba6173 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -272,7 +272,7 @@ These options help improve the performance and memory usage of the LLaMA models.
 ### Prompt Caching
-- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.
+- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.
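The sketch below is an editor's illustration (not part of the patch) of the restore flow this change introduces. It assumes the llama.h and examples/common.h API of this period (`llama_load_session_file`, `llama_set_rng_seed`, the `::llama_tokenize` wrapper) and an already-created `ctx` plus parsed `params`; the helper name is hypothetical and error handling is omitted.

```cpp
// Editor's sketch: simplified --prompt-cache restore flow described above.
// Assumes llama.cpp's llama.h / examples/common.h of this era; illustrative only.
#include "common.h" // gpt_params, ::llama_tokenize
#include "llama.h"

static std::vector<llama_token> load_or_tokenize(llama_context * ctx, gpt_params & params) {
    std::vector<llama_token> session_tokens(llama_n_ctx(ctx));
    size_t n_loaded = 0;

    if (!params.path_prompt_cache.empty() &&
        llama_load_session_file(ctx, params.path_prompt_cache.c_str(),
                                session_tokens.data(), session_tokens.size(), &n_loaded)) {
        session_tokens.resize(n_loaded);
        // 1. the cached state does not include the RNG, so re-apply the user's seed
        llama_set_rng_seed(ctx, params.seed);
    } else {
        session_tokens.clear();
    }

    // 2. only re-tokenize when the user actually supplied a prompt;
    //    otherwise default to the tokens that were cached
    if (!params.prompt.empty() || session_tokens.empty()) {
        return ::llama_tokenize(ctx, " " + params.prompt, true);
    }
    return session_tokens;
}
```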
### Quantization

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 47b418d97..c7c591537 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -134,8 +134,6 @@ int main(int argc, char ** argv) {
         return 0;
     }
-    // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
@@ -155,6 +153,7 @@ int main(int argc, char ** argv) {
             return 1;
         }
         session_tokens.resize(n_token_count_out);
+        llama_set_rng_seed(ctx, params.seed);
         fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
     } else {
@@ -163,7 +162,16 @@ int main(int argc, char ** argv) {
     }
     // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    std::vector<llama_token> embd_inp;
+
+    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+        // Add a space in front of the first character to match OG llama tokenizer behavior
+        params.prompt.insert(0, 1, ' ');
+
+        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    } else {
+        embd_inp = session_tokens;
+    }
     const int n_ctx = llama_n_ctx(ctx);
@@ -181,7 +189,9 @@ int main(int argc, char ** argv) {
         }
         n_matching_session_tokens++;
     }
-    if (n_matching_session_tokens >= embd_inp.size()) {
+    if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
+        fprintf(stderr, "%s: using full prompt from session file\n", __func__);
+    } else if (n_matching_session_tokens >= embd_inp.size()) {
         fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
     } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
         fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",

From bdbda1b17afb78e8613d03c8210a57fac632397b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 27 May 2023 12:22:05 +0300
Subject: [PATCH 3/3] ggml : sync ggml core (minor additions, e.g.
ggml_get_tensor_by_name()) --- ggml.c | 46 +++++++++++++++++++++++++++++++++++++--------- ggml.h | 12 +++++++++++- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/ggml.c b/ggml.c index c0e7ec05c..c24992260 100644 --- a/ggml.c +++ b/ggml.c @@ -3494,7 +3494,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { }; static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated"); -static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { +static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "NONE", "DUP", @@ -3749,6 +3749,9 @@ const char * ggml_type_name(enum ggml_type type) { return GGML_TYPE_NAME[type]; } +const char * ggml_op_name(enum ggml_op op) { + return GGML_OP_NAME[op]; +} size_t ggml_element_size(const struct ggml_tensor * tensor) { return GGML_TYPE_SIZE[tensor->type]; @@ -4017,6 +4020,10 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) return result; } +void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) { + ctx->no_alloc = no_alloc; +} + // IMPORTANT: // when creating "opt" tensors, always save and load the scratch buffer // this is an error prone process, but it is necessary to support inplace @@ -4061,7 +4068,7 @@ struct ggml_tensor * ggml_new_tensor_impl( struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); if (ctx->scratch.data == NULL || data != NULL) { - size_needed += sizeof(struct ggml_tensor); + size_needed += GGML_TENSOR_SIZE; if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", @@ -4077,14 +4084,15 @@ struct ggml_tensor * ggml_new_tensor_impl( }; } else { if (ctx->scratch.offs + size_needed > ctx->scratch.size) { - GGML_PRINT("%s: not enough space in the scratch memory\n", __func__); + GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", + __func__, ctx->scratch.offs + size_needed, ctx->scratch.size); assert(false); return NULL; } - if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) { + if (cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE > ctx->mem_size) { GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size); + __func__, cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE, ctx->mem_size); assert(false); return NULL; } @@ -4093,7 +4101,7 @@ struct ggml_tensor * ggml_new_tensor_impl( *obj_new = (struct ggml_object) { .offs = cur_end + GGML_OBJECT_SIZE, - .size = sizeof(struct ggml_tensor), + .size = GGML_TENSOR_SIZE, .next = NULL, }; @@ -13792,11 +13800,19 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * // reached a leaf node, not part of the gradient graph (e.g. 
a constant) GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES); + if (strlen(node->name) == 0) { + snprintf(node->name, sizeof(node->name), "leaf_%d", cgraph->n_leafs); + } + cgraph->leafs[cgraph->n_leafs] = node; cgraph->n_leafs++; } else { GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES); + if (strlen(node->name) == 0) { + snprintf(node->name, sizeof(node->name), "node_%d", cgraph->n_nodes); + } + cgraph->nodes[cgraph->n_nodes] = node; cgraph->grads[cgraph->n_nodes] = node->grad; cgraph->n_nodes++; @@ -14510,6 +14526,18 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) { } } +struct ggml_tensor * ggml_get_tensor_by_name(struct ggml_cgraph * cgraph, const char * name) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + if (strcmp(node->name, name) == 0) { + return node; + } + } + + return NULL; +} + void ggml_graph_print(const struct ggml_cgraph * cgraph) { int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; @@ -14527,7 +14555,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", i, node->ne[0], node->ne[1], node->ne[2], - GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, + GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, (double) node->perf_cycles / (double) ggml_cycles_per_ms(), (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, (double) node->perf_time_us / 1000.0, @@ -14541,7 +14569,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", i, node->ne[0], node->ne[1], - GGML_OP_LABEL[node->op]); + GGML_OP_NAME[node->op]); } for (int i = 0; i < GGML_OP_COUNT; i++) { @@ -14549,7 +14577,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { continue; } - GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0); + GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_NAME[i], (double) perf_total_per_op_us[i] / 1000.0); } GGML_PRINT("========================================\n"); diff --git a/ggml.h b/ggml.h index c22d93836..0c90f5064 100644 --- a/ggml.h +++ b/ggml.h @@ -198,6 +198,7 @@ #define GGML_MAX_PARAMS 256 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_OPT 4 +#define GGML_MAX_NAME 32 #define GGML_DEFAULT_N_THREADS 4 #define GGML_ASSERT(x) \ @@ -372,11 +373,16 @@ extern "C" { void * data; - char name[32]; + char name[GGML_MAX_NAME]; char padding[16]; }; + static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); + + // use this to compute the memory overhead of a tensor + static const size_t GGML_TENSOR_OVERHEAD = (GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16); + // computation graph struct ggml_cgraph { int n_nodes; @@ -429,6 +435,7 @@ extern "C" { GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float GGML_API const char * ggml_type_name(enum ggml_type type); + GGML_API const char * ggml_op_name (enum ggml_op op); GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); @@ -445,6 +452,7 @@ extern "C" { GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); GGML_API size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch); + GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); GGML_API struct ggml_tensor * ggml_new_tensor( struct ggml_context 
* ctx, @@ -970,6 +978,8 @@ extern "C" { GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); + GGML_API struct ggml_tensor * ggml_get_tensor_by_name(struct ggml_cgraph * cgraph, const char * name); + // print info and performance information for the graph GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
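To close, here is a minimal, hedged sketch (an editor's illustration, not code from the repository) of how the newly added `ggml_get_tensor_by_name()` and `ggml_op_name()` can be used once a graph has been built. It assumes the existing ggml context/graph API of this period (`ggml_init`, `ggml_new_tensor_1d`, `ggml_mul`, `ggml_set_name`, `ggml_build_forward`); the buffer size and tensor name are illustrative.

```cpp
// Editor's sketch: look up a graph node by name after building the graph.
// Assumes the ggml API of this era; sizes and names are illustrative.
#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params ip = { /*.mem_size   =*/ 16*1024*1024,
                                   /*.mem_buffer =*/ NULL,
                                   /*.no_alloc   =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * c = ggml_mul(ctx, a, b);
    ggml_set_name(c, "a_mul_b"); // unnamed nodes would get "node_%d" / "leaf_%d"

    struct ggml_cgraph gf = ggml_build_forward(c);

    // new in this sync: fetch any node of the built graph by its name
    struct ggml_tensor * t = ggml_get_tensor_by_name(&gf, "a_mul_b");
    if (t != NULL) {
        printf("%s: op=%s ne0=%lld\n", t->name, ggml_op_name(t->op), (long long) t->ne[0]);
    }

    ggml_free(ctx);
    return 0;
}
```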