ggml : add graph tensor allocator

This commit is contained in:
slaren 2023-07-26 17:13:58 +02:00
parent 1a941869cb
commit 768ecfcc28
7 changed files with 708 additions and 43 deletions

View file

@ -497,6 +497,8 @@ endif()
add_library(ggml OBJECT add_library(ggml OBJECT
ggml.c ggml.c
ggml.h ggml.h
ggml-alloc.c
ggml-alloc.h
${GGML_SOURCES_CUDA} ${GGML_SOURCES_CUDA}
${GGML_SOURCES_OPENCL} ${GGML_SOURCES_OPENCL}
${GGML_SOURCES_METAL} ${GGML_SOURCES_METAL}

View file

@ -318,7 +318,12 @@ $(info )
ggml.o: ggml.c ggml.h ggml-cuda.h ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@ $(CC) $(CFLAGS) -c $< -o $@
llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
$(CC) $(CFLAGS) -c $< -o $@
OBJS += ggml-alloc.o
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
common.o: examples/common.cpp examples/common.h common.o: examples/common.cpp examples/common.h

488
ggml-alloc.c Normal file
View file

@ -0,0 +1,488 @@
#include "ggml-alloc.h"
#include "ggml.h"
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define UNUSED(x) (void)(x)
#define MAX(a, b) ((a) > (b) ? (a) : (b))
//#define GGML_ALLOCATOR_DEBUG
//#define AT_PRINTF printf
#define AT_PRINTF(...) ((void)0)
// TODO: GGML_PAD ?
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
assert(alignment && !(alignment & (alignment - 1))); // power of 2
size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
return offset + align;
}
struct free_block {
void * addr;
size_t size;
};
#define MAX_FREE_BLOCKS 128
struct ggml_allocator {
void * data;
size_t size;
size_t alignment;
int n_free_blocks;
struct free_block free_blocks[MAX_FREE_BLOCKS];
size_t max_size;
bool measure;
#ifdef GGML_ALLOCATOR_DEBUG
struct ggml_tensor * allocated_tensors[1024];
#endif
};
#ifdef GGML_ALLOCATOR_DEBUG
static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == NULL) {
alloc->allocated_tensors[i] = tensor;
return;
}
}
GGML_ASSERT(!"out of allocated_tensors");
}
static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == tensor ||
(alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
alloc->allocated_tensors[i] = NULL;
return;
}
}
printf("tried to free tensor %s not found\n", tensor->name);
GGML_ASSERT(!"tensor not found");
}
#endif
static size_t ggml_allocator_get_alloc_size(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
return ggml_nbytes(tensor);
UNUSED(alloc);
}
void ggml_allocator_alloc_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
size_t max_avail = 0;
// find the best fitting free block
int best_fit_block = -1;
size_t best_fit_size = SIZE_MAX;
for (int i = 0; i < alloc->n_free_blocks; i++) {
struct free_block * block = &alloc->free_blocks[i];
max_avail = MAX(max_avail, block->size);
if (block->size >= size && block->size <= best_fit_size) {
best_fit_block = i;
best_fit_size = block->size;
}
}
AT_PRINTF("block %d\n", best_fit_block);
if (best_fit_block == -1) {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
__func__, size, max_avail);
GGML_ASSERT(!"not enough space in the buffer");
return;
}
struct free_block * block = &alloc->free_blocks[best_fit_block];
void * addr = block->addr;
block->addr = (char*)block->addr + size;
block->size -= size;
if (block->size == 0) {
// remove block if empty
alloc->n_free_blocks--;
for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
alloc->free_blocks[j] = alloc->free_blocks[j+1];
}
}
tensor->data = addr;
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
size_t cur_max = (char*)addr - (char*)alloc->data + size;
if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i]) {
printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
}
}
printf("\n");
}
#endif
alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
}
// this is a very naive implementation, but for our case the number of free blocks should be very small
static void ggml_allocator_free_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
void * ptr = tensor->data;
if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
// the tensor was not allocated in this buffer
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
// the easiest way to deal with this is just to ignore it
return;
}
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
#ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
#endif
// see if we can merge with an existing block
for (int i = 0; i < alloc->n_free_blocks; i++) {
struct free_block * block = &alloc->free_blocks[i];
// check if ptr is at the end of the block
if ((char*)block->addr + block->size == ptr) {
block->size += size;
// check if we can merge with the next block
if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
block->size += alloc->free_blocks[i+1].size;
alloc->n_free_blocks--;
for (int j = i+1; j < alloc->n_free_blocks; j++) {
alloc->free_blocks[j] = alloc->free_blocks[j+1];
}
}
return;
}
// check if ptr is at the beginning of the block
if ((char*)ptr + size == block->addr) {
block->addr = ptr;
block->size += size;
// check if we can merge with the previous block
if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
alloc->free_blocks[i-1].size += block->size;
alloc->n_free_blocks--;
for (int j = i; j < alloc->n_free_blocks; j++) {
alloc->free_blocks[j] = alloc->free_blocks[j+1];
}
}
return;
}
}
// otherwise, add a new block
GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
// insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
int insert_pos = 0;
while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
insert_pos++;
}
// shift all blocks from insert_pos onward to make room for the new block
for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
alloc->free_blocks[i] = alloc->free_blocks[i-1];
}
// insert the new block
alloc->free_blocks[insert_pos].addr = ptr;
alloc->free_blocks[insert_pos].size = size;
alloc->n_free_blocks++;
}
void ggml_allocator_reset(struct ggml_allocator * alloc) {
alloc->n_free_blocks = 1;
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
alloc->free_blocks[0].size = alloc->size - align_offset;
}
struct ggml_allocator * ggml_allocator_new(void * data, size_t size, size_t alignment) {
struct ggml_allocator * alloc = (struct ggml_allocator *)malloc(sizeof(struct ggml_allocator) /* + n_free_blocks * sizeof(struct free_block) */);
*alloc = (struct ggml_allocator){
/*.data = */ data,
/*.size = */ size,
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0},
#endif
};
ggml_allocator_reset(alloc);
return alloc;
}
// address and size of the buffer when measuring
// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
struct ggml_allocator * ggml_allocator_new_measure(size_t alignment) {
struct ggml_allocator * alloc = (struct ggml_allocator *)malloc(sizeof(struct ggml_allocator) /* + n_free_blocks * sizeof(struct free_block) */);
*alloc = (struct ggml_allocator){
/*.data = */ MEASURE_BASE_ADDR,
/*.size = */ MEASURE_MAX_SIZE,
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ true,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ = {0},
#endif
};
ggml_allocator_reset(alloc);
return alloc;
}
void ggml_allocator_free(struct ggml_allocator * alloc) {
free(alloc);
}
bool ggml_allocator_is_measure(struct ggml_allocator * alloc) {
return alloc->measure;
}
//////////// compute graph allocator
static bool ggml_is_view(struct ggml_tensor * t) {
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
}
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (a->ne[i] != b->ne[i]) {
return false;
}
if (a->nb[i] != b->nb[i]) {
return false;
}
}
return true;
}
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
switch (t->op) {
case GGML_OP_PERMUTE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
case GGML_OP_VIEW:
return t->src[0];
case GGML_OP_CPY:
return t->src[1];
default:
return NULL;
}
}
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
struct ggml_tensor * parent = t;
do {
parent = get_view_parent(parent);
} while (ggml_is_view(parent));
return parent;
}
static void allocate_node(struct ggml_allocator * alloc, struct ggml_tensor * node) {
if (node->data == NULL) {
if (ggml_is_view(node)) {
size_t offset;
switch(node->op) {
case GGML_OP_VIEW:
memcpy(&offset, node->op_params, sizeof(size_t));
node->data = (char *) node->src[0]->data + offset;
break;
case GGML_OP_PERMUTE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
node->data = node->src[0]->data;
break;
case GGML_OP_CPY:
node->data = node->src[1]->data;
break;
default:
GGML_ASSERT(!"unknown view op");
break;
}
} else {
// see if we can reuse a parent's buffer (inplace)
for (int i = 0; i < GGML_MAX_SRC; i++) {
struct ggml_tensor * parent = node->src[i];
if (parent == NULL) {
break;
}
// TODO: make a list of operations that can be safely made inplace
if (parent->data != NULL && parent->n_children == 1 && parent->n_views == 0 && ggml_are_same_layout(node, parent) && node->op != GGML_OP_MUL_MAT) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = get_view_source(parent);
if (view_src->n_views == 1 && view_src->n_children == 0 && view_src->data == parent->data) {
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
// the parent's data that it will need later (same layout requirement). the problem is that then
// we cannot free the tensor because the original address of the allocation is lost.
// adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
// for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
node->data = parent->data;
return;
}
}
else {
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
node->data = parent->data;
}
return;
}
}
ggml_allocator_alloc_tensor(alloc, node);
}
}
}
static size_t ggml_allocator_alloc_graph_tensors_n(
struct ggml_allocator * alloc,
struct ggml_cgraph ** graphs, int n_graphs,
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
// reset counters
for (int g = 0; g < n_graphs; g++) {
struct ggml_cgraph * gf = graphs[g];
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
node->n_children = 0;
node->n_views = 0;
}
for (int i = 0; i < gf->n_leafs; i++) {
struct ggml_tensor * leaf = gf->leafs[i];
leaf->n_children = 0;
leaf->n_views = 0;
}
}
// count number of children and views
for (int g = 0; g < n_graphs; g++) {
struct ggml_cgraph * gf = graphs[g];
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
if (ggml_is_view(node)) {
struct ggml_tensor * view_src = get_view_source(node);
view_src->n_views += 1;
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
parent->n_children += 1;
}
}
}
// allocate tensors
for (int g = 0; g < n_graphs; g++) {
struct ggml_cgraph * gf = graphs[g];
AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
// graph inputs are allocated first to ensure that they are never overwritten
if (inputs != NULL && inputs[g] != NULL) {
for (int i = 0; inputs[g][i] != NULL; i++) {
struct ggml_tensor * input = inputs[g][i];
AT_PRINTF("input: %s\n", input->name);
allocate_node(alloc, input);
}
}
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
// allocate parents (leafs)
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
allocate_node(alloc, parent);
}
// allocate node
allocate_node(alloc, node);
AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
AT_PRINTF("%s", parent->name);
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
AT_PRINTF(", ");
}
}
AT_PRINTF("\n");
// update parents
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
parent->n_children -= 1;
//AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
if (parent->n_children == 0 && parent->n_views == 0) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = get_view_source(parent);
view_src->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
if (view_src->n_views == 0 && view_src->n_children == 0 && view_src->data != node->data) {
ggml_allocator_free_tensor(alloc, view_src);
}
}
else {
if (parent->data != node->data) {
ggml_allocator_free_tensor(alloc, parent);
}
}
}
}
AT_PRINTF("\n");
}
// free graph outputs here that wouldn't be freed otherwise because they have no children
if (outputs != NULL && outputs[g] != NULL) {
for (int i = 0; outputs[g][i] != NULL; i++) {
struct ggml_tensor * output = outputs[g][i];
AT_PRINTF("output: %s\n", output->name);
ggml_allocator_free_tensor(alloc, output);
}
}
}
return alloc->max_size;
}
size_t ggml_allocator_alloc_graph_tensors(struct ggml_allocator * alloc, struct ggml_cgraph * graph) {
return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
}

21
ggml-alloc.h Normal file
View file

@ -0,0 +1,21 @@
#pragma once
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
GGML_API struct ggml_allocator * ggml_allocator_new(void * data, size_t size, size_t alignment);
GGML_API struct ggml_allocator * ggml_allocator_new_measure(size_t alignment);
GGML_API void ggml_allocator_free(struct ggml_allocator * alloc);
GGML_API bool ggml_allocator_is_measure(struct ggml_allocator * alloc);
GGML_API void ggml_allocator_reset(struct ggml_allocator * alloc);
GGML_API void ggml_allocator_alloc_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_allocator_alloc_graph_tensors(struct ggml_allocator * alloc, struct ggml_cgraph * graph);
#ifdef __cplusplus
}
#endif

14
ggml.c
View file

@ -4610,6 +4610,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data, /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
/*.name =*/ { 0 }, /*.name =*/ { 0 },
/*.extra =*/ NULL, /*.extra =*/ NULL,
/*.n_children =*/ 0,
/*.n_views =*/ 0,
/*.padding =*/ { 0 }, /*.padding =*/ { 0 },
}; };
@ -6741,6 +6743,18 @@ struct ggml_tensor * ggml_rope_inplace(
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true); return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
} }
struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
}
struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,

19
ggml.h
View file

@ -451,7 +451,11 @@ extern "C" {
void * extra; // extra things e.g. for ggml-cuda.cu void * extra; // extra things e.g. for ggml-cuda.cu
char padding[4]; // temp - used by allocator
int n_children;
int n_views;
char padding[16];
}; };
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@ -1170,7 +1174,18 @@ extern "C" {
int mode, int mode,
int n_ctx); int n_ctx);
// custom RoPE, in-place, returns view(a) // custom RoPE
GGML_API struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode,
int n_ctx,
float freq_base,
float freq_scale);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_custom_inplace( GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,

198
llama.cpp
View file

@ -56,7 +56,13 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_CLBLAST) && !defined(GGML_USE_METAL)
# include "ggml-alloc.h"
# define LLAMA_USE_ALLOCATOR
#else
# define LLAMA_USE_SCRATCH # define LLAMA_USE_SCRATCH
#endif
#define LLAMA_MAX_SCRATCH_BUFFERS 16 #define LLAMA_MAX_SCRATCH_BUFFERS 16
// available llama models // available llama models
@ -371,7 +377,17 @@ struct llama_context {
// memory buffers used to evaluate the model // memory buffers used to evaluate the model
// TODO: move in llama_state // TODO: move in llama_state
llama_ctx_buffer buf_compute; llama_ctx_buffer buf_compute;
#ifdef LLAMA_USE_ALLOCATOR
llama_ctx_buffer buf_alloc;
ggml_allocator * alloc = NULL;
#endif
#ifdef LLAMA_USE_SCRATCH
llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS]; llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
int buf_last = 0;
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
#endif
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
ggml_metal_context * ctx_metal = NULL; ggml_metal_context * ctx_metal = NULL;
@ -381,9 +397,6 @@ struct llama_context {
ggml_mpi_context * ctx_mpi = NULL; ggml_mpi_context * ctx_mpi = NULL;
#endif #endif
int buf_last = 0;
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
void use_buf(struct ggml_context * ctx, int i) { void use_buf(struct ggml_context * ctx, int i) {
#if defined(LLAMA_USE_SCRATCH) #if defined(LLAMA_USE_SCRATCH)
size_t last_size = 0; size_t last_size = 0;
@ -1360,32 +1373,15 @@ static bool llama_model_load(
} }
} }
// evaluate the transformer static struct ggml_cgraph * llama_build_graph(
//
// - lctx: llama context
// - tokens: new batch of tokens to process
// - embd embeddings input
// - n_tokens number of tokens
// - n_past: the context size so far
// - n_threads: number of threads to use
//
static bool llama_eval_internal(
llama_context & lctx, llama_context & lctx,
const llama_token * tokens, const llama_token * tokens,
const float * embd, const float * embd,
int n_tokens, int n_tokens,
int n_past, int n_past) {
int n_threads,
const char * cgraph_fname) {
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
#ifdef GGML_USE_MPI
ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
#endif
const int64_t t_start_us = ggml_time_us();
const int N = n_tokens; const int N = n_tokens;
const auto & model = lctx.model; const auto & model = lctx.model;
@ -1401,10 +1397,9 @@ static bool llama_eval_internal(
const int64_t n_head = hparams.n_head; const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_vocab = hparams.n_vocab; //const int64_t n_vocab = hparams.n_vocab;
const int64_t n_embd_gqa = hparams.n_embd_gqa(); const int64_t n_embd_gqa = hparams.n_embd_gqa();
LLAMA_ASSERT(n_embd_head == hparams.n_rot); LLAMA_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base; const float freq_base = hparams.rope_freq_base;
@ -1413,29 +1408,40 @@ static bool llama_eval_internal(
const int n_gpu_layers = model.n_gpu_layers; const int n_gpu_layers = model.n_gpu_layers;
auto & mem_per_token = lctx.mem_per_token;
auto & buf_compute = lctx.buf_compute; auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size, /*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.addr, /*.mem_buffer =*/ buf_compute.addr,
/*.no_alloc =*/ false, /*.no_alloc =*/ false,
}; };
#ifdef LLAMA_USE_ALLOCATOR
# define ggml_rope_custom_inplace ggml_rope_custom
# define ggml_scale_inplace ggml_scale
# define ggml_diag_mask_inf_inplace ggml_diag_mask_inf
# define ggml_soft_max_inplace ggml_soft_max
params.no_alloc = true;
#endif
struct ggml_context * ctx0 = ggml_init(params); struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0); ggml_cgraph * gf = ggml_new_graph(ctx0);
// for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
if (tokens) { if (tokens) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
#ifdef LLAMA_USE_ALLOCATOR
ggml_allocator_alloc_tensor(lctx.alloc, inp_tokens);
if (!ggml_allocator_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
}
#else
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
#endif
ggml_set_name(inp_tokens, "inp_tokens"); ggml_set_name(inp_tokens, "inp_tokens");
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
@ -1472,6 +1478,17 @@ static bool llama_eval_internal(
} }
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
#ifdef LLAMA_USE_ALLOCATOR
ggml_allocator_alloc_tensor(lctx.alloc, KQ_scale);
if (!ggml_allocator_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
#else
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
#endif
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); ggml_format_name(inpL, "layer_inp_%d", il);
@ -1567,9 +1584,6 @@ static bool llama_eval_internal(
ggml_set_name(KQ, "KQ"); ggml_set_name(KQ, "KQ");
// KQ_scaled = KQ / sqrt(n_embd_head) // KQ_scaled = KQ / sqrt(n_embd_head)
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_scaled shape [n_past + N, N, n_head, 1] // KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled); offload_func_kq(KQ_scaled);
@ -1700,6 +1714,9 @@ static bool llama_eval_internal(
ggml_set_name(cur, "result_norm"); ggml_set_name(cur, "result_norm");
embeddings = cur; embeddings = cur;
#ifdef LLAMA_USE_ALLOCATOR
// TODO: ensure that embeddings is not freed
#endif
} }
// lm_head // lm_head
@ -1711,11 +1728,84 @@ static bool llama_eval_internal(
// logits -> probs // logits -> probs
//cur = ggml_soft_max_inplace(ctx0, cur); //cur = ggml_soft_max_inplace(ctx0, cur);
// run the computation
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
// outputs: cur, embeddings
ggml_free(ctx0);
return gf;
#ifdef LLAMA_USE_ALLOCATOR
# undef ggml_rope_custom
# undef ggml_scale
# undef ggml_diag_mask_inf
# undef ggml_soft_max
#endif
}
// evaluate the transformer
//
// - lctx: llama context
// - tokens: new batch of tokens to process
// - embd embeddings input
// - n_tokens number of tokens
// - n_past: the context size so far
// - n_threads: number of threads to use
//
static bool llama_eval_internal(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
int n_tokens,
int n_past,
int n_threads,
const char * cgraph_fname) {
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
const int64_t t_start_us = ggml_time_us();
#ifdef GGML_USE_MPI
ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
#endif
const int N = n_tokens;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & kv_self = lctx.kv_self;
LLAMA_ASSERT(!!kv_self.ctx);
const int64_t n_embd = hparams.n_embd;
//const int64_t n_layer = hparams.n_layer;
//const int64_t n_ctx = hparams.n_ctx;
//const int64_t n_head = hparams.n_head;
//const int64_t n_head_kv = hparams.n_head_kv;
//const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_vocab = hparams.n_vocab;
//const int64_t n_embd_gqa = hparams.n_embd_gqa();
//auto & mem_per_token = lctx.mem_per_token;
#ifdef LLAMA_USE_ALLOCATOR
ggml_allocator_reset(lctx.alloc);
#endif
ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
#ifdef LLAMA_USE_ALLOCATOR
size_t sz = ggml_allocator_alloc_graph_tensors(lctx.alloc, gf);
//fprintf(stderr, "%s: compute buffer size: %.3f MB\n", __func__, sz / 1024.0 / 1024.0);
#endif
// fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs); // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
// for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
#if GGML_USE_MPI #if GGML_USE_MPI
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
#endif #endif
@ -1760,6 +1850,10 @@ static bool llama_eval_internal(
lctx.kv_self.n = n_past + N; lctx.kv_self.n = n_past + N;
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embeddings = NULL;
LLAMA_ASSERT(strcmp(res->name, "result_output") == 0);
//LLAMA_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
if (cgraph_fname) { if (cgraph_fname) {
ggml_graph_export(gf, cgraph_fname); ggml_graph_export(gf, cgraph_fname);
@ -1798,9 +1892,9 @@ static bool llama_eval_internal(
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
} }
if (mem_per_token == 0) { //if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N; // mem_per_token = ggml_used_mem(ctx0)/N;
} //}
#if 0 #if 0
printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__, printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
@ -1811,8 +1905,6 @@ static bool llama_eval_internal(
n_past, N); n_past, N);
#endif #endif
ggml_free(ctx0);
// measure the performance only for the single-token evals // measure the performance only for the single-token evals
if (N == 1) { if (N == 1) {
lctx.t_eval_us += ggml_time_us() - t_start_us; lctx.t_eval_us += ggml_time_us() - t_start_us;
@ -3178,10 +3270,38 @@ struct llama_context * llama_new_context_with_model(
ctx->embedding.resize(hparams.n_embd); ctx->embedding.resize(hparams.n_embd);
} }
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead()); #ifdef LLAMA_USE_ALLOCATOR
ctx->buf_compute.resize(ggml_tensor_overhead() * 3072 + ggml_graph_overhead());
// measure memory requirements for worst-case graph
ctx->alloc = ggml_allocator_new_measure(32);
// build worst-case graph
int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
int n_past = hparams.n_ctx - n_tokens;
std::vector<llama_token> tokens(n_tokens, llama_token_bos());
ggml_cgraph * gf = llama_build_graph(*ctx, tokens.data(), NULL, n_tokens, n_past);
size_t size = ggml_allocator_alloc_graph_tensors(ctx->alloc, gf);
fprintf(stderr, "%s: worst-case graph size = %7.2f MB\n", __func__, size / 1024.0 / 1024.0);
fprintf(stderr, "%s: compute buffer total size: %7.2f MB\n", __func__, (ctx->buf_compute.size + size) / 1024.0 / 1024.0);
size_t prev_req = MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) + MEM_REQ_SCRATCH1().at(ctx->model.type) + MEM_REQ_EVAL().at(ctx->model.type);
fprintf(stderr, "%s: equivalent with scratch buffer: %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
// recreate allocator with exact memory requirements
ggml_allocator_free(ctx->alloc);
ctx->buf_alloc.resize(size);
ctx->alloc = ggml_allocator_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, 32);
#else
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
#endif
#ifdef LLAMA_USE_SCRATCH
ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
#endif
} }
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL