Merge branch 'master' into concedo_experimental

# Conflicts:
#	CMakeLists.txt
#	Makefile
Concedo 2023-05-27 17:44:14 +08:00
commit 92a0d77712
6 changed files with 145 additions and 55 deletions

Makefile

@@ -139,9 +139,19 @@ ifdef LLAMA_CUBLAS
 	OBJS        += ggml-cuda.o
 	NVCC        = nvcc
 	NVCCFLAGS   = --forward-unknown-to-host-compiler -arch=native
+ifdef LLAMA_CUDA_DMMV_X
+	NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+else
+	NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
+endif # LLAMA_CUDA_DMMV_X
+ifdef LLAMA_CUDA_DMMV_Y
+	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
+else
+	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
+endif # LLAMA_CUDA_DMMV_Y
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
-endif
+endif # LLAMA_CUBLAS
 ifdef LLAMA_GPROF
 	CFLAGS   += -pg
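These Makefile switches feed the new GGML_CUDA_DMMV_X / GGML_CUDA_DMMV_Y compile-time constants used by ggml-cuda.cu below. For example (values illustrative), a cuBLAS build can be tuned with: make LLAMA_CUBLAS=1 LLAMA_CUDA_DMMV_X=64 LLAMA_CUDA_DMMV_Y=2. Leaving the variables unset keeps the defaults of 32 and 1, which match the previous kernel configuration.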

examples/main/README.md

@@ -272,7 +272,7 @@ These options help improve the performance and memory usage of the LLaMA models.

 ### Prompt Caching

-- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.
+- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.

 ### Quantization
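As a usage sketch for the prompt-cache option (model path and prompt text are illustrative): the first run of ./main -m models/7B/ggml-model-q4_0.bin -p "Once upon a time" --prompt-cache prompt.cache writes the cache file, and later runs with the same prompt reuse it to skip the initial prompt evaluation, subject to the note above about RNG state.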

examples/main/main.cpp

@@ -134,8 +134,6 @@ int main(int argc, char ** argv) {
         return 0;
     }

-    // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
-
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
@@ -155,6 +153,7 @@ int main(int argc, char ** argv) {
             return 1;
         }
         session_tokens.resize(n_token_count_out);
+        llama_set_rng_seed(ctx, params.seed);

         fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
     } else {
@@ -163,7 +162,16 @@
     }

     // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    std::vector<llama_token> embd_inp;
+
+    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+        // Add a space in front of the first character to match OG llama tokenizer behavior
+        params.prompt.insert(0, 1, ' ');
+
+        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    } else {
+        embd_inp = session_tokens;
+    }

     const int n_ctx = llama_n_ctx(ctx);
@@ -181,7 +189,9 @@
             }
             n_matching_session_tokens++;
         }
-        if (n_matching_session_tokens >= embd_inp.size()) {
+        if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
+            fprintf(stderr, "%s: using full prompt from session file\n", __func__);
+        } else if (n_matching_session_tokens >= embd_inp.size()) {
             fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
         } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
             fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",

ggml-cuda.cu

@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");

+#define WARP_SIZE 32
+
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+// dmmv = dequantize_mul_mat_vec
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
+#endif
+#ifndef GGML_CUDA_DMMV_Y
+#define GGML_CUDA_DMMV_Y 1
+#endif

 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) {
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }

-template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-    const int row = blockIdx.x;
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ? 1 : qk/2;

-    __shared__ float tmp[block_size]; // separate sum for each thread
-    tmp[tid] = 0;
+    float tmp = 0; // partial sum for thread in warp

-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index

-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
-
-        // matrix multiplication
-        tmp[tid] += v0 * y[iybs + iqs + 0];
-        tmp[tid] += v1 * y[iybs + iqs + y_offset];
+        // processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }

     // sum up partial sums and write back result
     __syncthreads();
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        __syncthreads();
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
     if (tid == 0) {
-        dst[row] = tmp[0];
+        dst[row] = tmp;
     }
 }
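Two design points in the rewritten kernel: partial sums now live in registers and are combined with an XOR-shuffle butterfly (__shfl_xor_sync over masks 16, 8, 4, 2, 1), which removes the shared-memory buffer and the per-step __syncthreads() inside the old reduction loop; and because the thread block is now WARP_SIZE x GGML_CUDA_DMMV_Y, each 32-thread warp accumulates one matrix row and a single block can process GGML_CUDA_DMMV_Y rows. With the default GGML_CUDA_DMMV_X = 32, iter_stride is 64 and vals_per_iter is 2 per thread, matching the old kernel's per-thread work; larger values simply give each thread more quantized values per outer iteration.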
@@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 }

 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }

 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }

 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }

 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }

 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }

 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 }

 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }

 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
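Each launcher now starts nrows/GGML_CUDA_DMMV_Y blocks of WARP_SIZE x GGML_CUDA_DMMV_Y threads. As a worked example with illustrative numbers, nrows = 4096 and GGML_CUDA_DMMV_Y = 2 gives 2048 blocks of 64 threads, where each 32-thread warp accumulates one output row; the new GGML_ASSERT on nrows guards the integer division in the grid size.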

ggml.c

@@ -3494,7 +3494,7 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
 };
 static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");

-static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
+static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",

     "DUP",
@@ -3749,6 +3749,9 @@ const char * ggml_type_name(enum ggml_type type) {
     return GGML_TYPE_NAME[type];
 }

+const char * ggml_op_name(enum ggml_op op) {
+    return GGML_OP_NAME[op];
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
@@ -4017,6 +4020,10 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
     return result;
 }

+void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
+    ctx->no_alloc = no_alloc;
+}
+
 // IMPORTANT:
 // when creating "opt" tensors, always save and load the scratch buffer
 // this is an error prone process, but it is necessary to support inplace
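One way this can be used, sketched under the assumption that a context's no_alloc flag makes tensor creation skip reserving data memory: toggle it on to build tensor metadata only, measure sizes, then switch allocation back on.

    // minimal sketch, assuming a ggml_context * ctx created with ggml_init() beforehand
    ggml_set_no_alloc(ctx, true);                    // tensors created from here on get metadata only
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 1024);
    size_t data_bytes = ggml_nbytes(t);              // how much data memory this tensor would need
    ggml_set_no_alloc(ctx, false);                   // subsequent tensors allocate data again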
@@ -4061,7 +4068,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);

     if (ctx->scratch.data == NULL || data != NULL) {
-        size_needed += sizeof(struct ggml_tensor);
+        size_needed += GGML_TENSOR_SIZE;

         if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
             GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
@@ -4077,14 +4084,15 @@ struct ggml_tensor * ggml_new_tensor_impl(
         };
     } else {
         if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
-            GGML_PRINT("%s: not enough space in the scratch memory\n", __func__);
+            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
+                    __func__, ctx->scratch.offs + size_needed, ctx->scratch.size);
             assert(false);
             return NULL;
         }

-        if (cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE > ctx->mem_size) {
+        if (cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE > ctx->mem_size) {
             GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
-                    __func__, cur_end + sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE, ctx->mem_size);
+                    __func__, cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE, ctx->mem_size);
             assert(false);
             return NULL;
         }
@@ -4093,7 +4101,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
     *obj_new = (struct ggml_object) {
         .offs = cur_end + GGML_OBJECT_SIZE,
-        .size = sizeof(struct ggml_tensor),
+        .size = GGML_TENSOR_SIZE,
         .next = NULL,
     };
@@ -13792,11 +13800,19 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
         // reached a leaf node, not part of the gradient graph (e.g. a constant)
         GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES);

+        if (strlen(node->name) == 0) {
+            snprintf(node->name, sizeof(node->name), "leaf_%d", cgraph->n_leafs);
+        }
+
         cgraph->leafs[cgraph->n_leafs] = node;
         cgraph->n_leafs++;
     } else {
         GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES);

+        if (strlen(node->name) == 0) {
+            snprintf(node->name, sizeof(node->name), "node_%d", cgraph->n_nodes);
+        }
+
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->grads[cgraph->n_nodes] = node->grad;
         cgraph->n_nodes++;
@@ -14510,6 +14526,18 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     }
 }

+struct ggml_tensor * ggml_get_tensor_by_name(struct ggml_cgraph * cgraph, const char * name) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        if (strcmp(node->name, name) == 0) {
+            return node;
+        }
+    }
+
+    return NULL;
+}
+
 void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0};
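A hedged usage sketch for the new lookup (tensor shapes are illustrative, and it assumes the by-value ggml_build_forward API of this ggml revision): build a small graph, then fetch a node either by an explicitly assigned name or by the automatic "node_%d"/"leaf_%d" names given above, and describe it with the new ggml_op_name().

    #include <stdio.h>
    #include "ggml.h"

    void example(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * c = ggml_add(ctx, a, b);   // gets the automatic name "node_0" if not named explicitly

        struct ggml_cgraph gf = ggml_build_forward(c);

        struct ggml_tensor * probe = ggml_get_tensor_by_name(&gf, "node_0");
        if (probe != NULL) {
            printf("found %s node with %d elements\n", ggml_op_name(probe->op), (int) ggml_nelements(probe));
        }

        ggml_free(ctx);
    }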
@@ -14527,7 +14555,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+                GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
                 (double) node->perf_time_us / 1000.0,
@@ -14541,7 +14569,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
-                GGML_OP_LABEL[node->op]);
+                GGML_OP_NAME[node->op]);
     }

     for (int i = 0; i < GGML_OP_COUNT; i++) {
@@ -14549,7 +14577,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
             continue;
         }

-        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
+        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_NAME[i], (double) perf_total_per_op_us[i] / 1000.0);
     }

     GGML_PRINT("========================================\n");

ggml.h

@@ -198,6 +198,7 @@
 #define GGML_MAX_PARAMS        256
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_OPT           4
+#define GGML_MAX_NAME          32
 #define GGML_DEFAULT_N_THREADS 4

 #define GGML_ASSERT(x) \
@@ -372,11 +373,16 @@ extern "C" {

         void * data;

-        char name[32];
+        char name[GGML_MAX_NAME];

         char padding[16];
     };

+    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
+
+    // use this to compute the memory overhead of a tensor
+    static const size_t GGML_TENSOR_OVERHEAD = (GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16);
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
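As an illustrative sketch of what the overhead constant is for (tensor count and shapes are made up, and ggml.h is assumed to be included): when sizing a context buffer up front, each tensor costs its data plus GGML_TENSOR_OVERHEAD bytes of object and metadata bookkeeping.

    // reserve room for two f32 tensors plus their bookkeeping
    const size_t n_tensors = 2;
    size_t mem_size = n_tensors*GGML_TENSOR_OVERHEAD;
    mem_size += 1024*1024*sizeof(float);   // data of a 1024x1024 f32 tensor
    mem_size += 1024*sizeof(float);        // data of a 1024-element f32 tensor

    struct ggml_init_params params = {
        /*.mem_size   =*/ mem_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);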
@@ -429,6 +435,7 @@ extern "C" {
     GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float

     GGML_API const char * ggml_type_name(enum ggml_type type);
+    GGML_API const char * ggml_op_name  (enum ggml_op   op);

     GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -445,6 +452,7 @@ extern "C" {
     GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);

     GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API void   ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);

     GGML_API struct ggml_tensor * ggml_new_tensor(
             struct ggml_context * ctx,
@@ -970,6 +978,8 @@ extern "C" {
     GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
     GGML_API void ggml_graph_reset  (struct ggml_cgraph * cgraph);

+    GGML_API struct ggml_tensor * ggml_get_tensor_by_name(struct ggml_cgraph * cgraph, const char * name);
+
     // print info and performance information for the graph
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);