Merge branch 'master' into concedo_experimental
# Conflicts: # CMakeLists.txt # Makefile # README.md
This commit is contained in:
commit
278427d9a4
11 changed files with 172 additions and 82 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -31,6 +31,7 @@ out/
|
|||
/perplexity
|
||||
/embedding
|
||||
/train-text-from-scratch
|
||||
/simple
|
||||
/benchmark-matmult
|
||||
/vdot
|
||||
/server
|
||||
|
|
|
@ -38,6 +38,7 @@ else()
|
|||
add_subdirectory(benchmark)
|
||||
add_subdirectory(baby-llama)
|
||||
add_subdirectory(train-text-from-scratch)
|
||||
add_subdirectory(simple)
|
||||
if (LLAMA_METAL)
|
||||
add_subdirectory(metal)
|
||||
endif()
|
||||
|
|
|
@ -106,9 +106,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
}
|
||||
|
||||
if (arg == "-s" || arg == "--seed") {
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
|
||||
#endif
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
|
|
|
@ -354,7 +354,7 @@ int main(int argc, char ** argv) {
|
|||
if ((int)embd.size() > max_embd_size) {
|
||||
auto skipped_tokens = embd.size() - max_embd_size;
|
||||
console_set_color(con_st, CONSOLE_COLOR_ERROR);
|
||||
printf("<<input too long: skipped %" PRIu64 " token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
|
||||
printf("<<input too long: skipped %zu token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
|
||||
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
|
||||
fflush(stdout);
|
||||
embd.resize(max_embd_size);
|
||||
|
|
58
ggml-cuda.cu
58
ggml-cuda.cu
|
@ -13,6 +13,10 @@
|
|||
#include "ggml-cuda.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
|
||||
#define CUDA_CHECK(err) \
|
||||
|
@ -1463,19 +1467,13 @@ static void * g_scratch_buffer = nullptr;
|
|||
static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
|
||||
static size_t g_scratch_offset = 0;
|
||||
|
||||
#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
|
||||
#define GGML_CUDA_MAX_EVENTS 64
|
||||
|
||||
static int g_device_count = -1;
|
||||
static int g_main_device = 0;
|
||||
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
|
||||
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
|
||||
|
||||
static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
|
||||
static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
|
||||
static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
|
||||
|
||||
void ggml_init_cublas() {
|
||||
static bool initialized = false;
|
||||
|
@ -1499,15 +1497,8 @@ void ggml_init_cublas() {
|
|||
for (int id = 0; id < g_device_count; ++id) {
|
||||
CUDA_CHECK(cudaSetDevice(id));
|
||||
|
||||
// create streams
|
||||
for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
|
||||
}
|
||||
// create events
|
||||
for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
|
||||
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
|
||||
}
|
||||
// create main stream
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
|
||||
|
||||
// create cublas handle
|
||||
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
||||
|
@ -1974,6 +1965,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
|||
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
|
||||
// if multiple GPUs are used they need to wait for the main GPU to finish
|
||||
if (split && g_device_count > 1) {
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
for (int id = 0; id < g_device_count; ++id) {
|
||||
if (!split && id != g_main_device) {
|
||||
continue;
|
||||
|
@ -2072,9 +2069,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
|||
}
|
||||
const int64_t i11 = i13*ne12 + i12;
|
||||
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
|
||||
cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
|
||||
cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[id];
|
||||
|
||||
// for split tensors the data begins at i0 == i0_offset_low
|
||||
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
|
||||
|
@ -2102,14 +2097,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
|||
if (src1->backend == GGML_BACKEND_CPU) {
|
||||
GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
|
||||
int64_t nrows1 = flatten_rows ? nrows0 : ne11;
|
||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
|
||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
|
||||
} else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
|
||||
if (id != g_main_device) {
|
||||
GGML_ASSERT(!flatten_rows);
|
||||
float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
|
||||
src1_ddf_i_source += i11*src1_stride;
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
|
||||
cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
|
||||
cudaMemcpyDeviceToDevice, cudaStream_main));
|
||||
}
|
||||
} else if (src1_on_device && !src1_is_contiguous) {
|
||||
GGML_ASSERT(!split);
|
||||
|
@ -2118,7 +2113,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
|||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
|
||||
|
||||
if (!src0_on_device || !src0_is_contiguous) {
|
||||
if (src0_is_f32) {
|
||||
|
@ -2134,9 +2128,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
|||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
// wait with main stream until src1 memcpy is done
|
||||
CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
|
||||
|
||||
// do the computation
|
||||
op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
|
||||
|
||||
|
@ -2174,8 +2165,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
|
|||
|
||||
// wait until each device is finished, then free their buffers
|
||||
for (int id = 0; id < g_device_count; ++id) {
|
||||
if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(id));
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
if (src0_asq[id] > 0) {
|
||||
ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
|
||||
}
|
||||
|
@ -2241,7 +2237,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
const int64_t ne02 = src0->ne[2];
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
|
@ -2253,8 +2249,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
|
@ -2272,7 +2266,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
const int64_t nb02 = src0->nb[2];
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
|
||||
|
||||
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
|
@ -2287,8 +2281,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
const int channel_stride_x = nb02 / sizeof(half);
|
||||
|
||||
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
@ -2344,7 +2336,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
|||
const int64_t nb12 = src1->nb[2];
|
||||
|
||||
CUDA_CHECK(cudaSetDevice(g_main_device));
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
|
||||
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
|
||||
|
||||
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
|
@ -2362,8 +2354,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
|
|||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
|
|
|
@ -41,12 +41,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx);
|
|||
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
|
||||
// - the mapping is used during computation to determine the arguments of the compute kernels
|
||||
// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
|
||||
// - max_size specifies the maximum size of a tensor and is used to create shared views such
|
||||
// that it is guaranteed that the tensor will fit in at least one of the views
|
||||
//
|
||||
bool ggml_metal_add_buffer(
|
||||
struct ggml_metal_context * ctx,
|
||||
const char * name,
|
||||
void * data,
|
||||
size_t size);
|
||||
size_t size,
|
||||
size_t max_size);
|
||||
|
||||
// set data from host memory into the device
|
||||
void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
|
||||
|
|
109
ggml-metal.m
109
ggml-metal.m
|
@ -183,6 +183,14 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
|||
#undef GGML_METAL_ADD_KERNEL
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||
if (ctx->device.maxTransferRate != 0) {
|
||||
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
|
@ -199,10 +207,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
|
||||
//fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
||||
|
||||
const int64_t tsize = ggml_nbytes(t);
|
||||
|
||||
// find the view that contains the tensor fully
|
||||
for (int i = 0; i < ctx->n_buffers; ++i) {
|
||||
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
||||
|
||||
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
||||
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
||||
*offs = (size_t) ioffs;
|
||||
|
||||
//fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
||||
|
@ -220,7 +231,8 @@ bool ggml_metal_add_buffer(
|
|||
struct ggml_metal_context * ctx,
|
||||
const char * name,
|
||||
void * data,
|
||||
size_t size) {
|
||||
size_t size,
|
||||
size_t max_size) {
|
||||
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
||||
fprintf(stderr, "%s: too many buffers\n", __func__);
|
||||
return false;
|
||||
|
@ -237,31 +249,69 @@ bool ggml_metal_add_buffer(
|
|||
}
|
||||
}
|
||||
|
||||
size_t page_size = getpagesize();
|
||||
size_t aligned_size = size;
|
||||
if ((aligned_size % page_size) != 0) {
|
||||
aligned_size += (page_size - (aligned_size % page_size));
|
||||
const size_t size_page = getpagesize();
|
||||
|
||||
size_t size_aligned = size;
|
||||
if ((size_aligned % size_page) != 0) {
|
||||
size_aligned += (size_page - (size_aligned % size_page));
|
||||
}
|
||||
|
||||
// the buffer fits into the max buffer size allowed by the device
|
||||
if (size_aligned <= ctx->device.maxBufferLength) {
|
||||
ctx->buffers[ctx->n_buffers].name = name;
|
||||
ctx->buffers[ctx->n_buffers].data = data;
|
||||
ctx->buffers[ctx->n_buffers].size = size;
|
||||
|
||||
if (ctx->device.maxBufferLength < aligned_size) {
|
||||
fprintf(stderr, "%s: buffer '%s' size %zu is larger than buffer maximum of %zu\n", __func__, name, aligned_size, ctx->device.maxBufferLength);
|
||||
return false;
|
||||
}
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil];
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
|
||||
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
|
||||
++ctx->n_buffers;
|
||||
} else {
|
||||
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
|
||||
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
|
||||
// one of the views
|
||||
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
|
||||
const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
|
||||
const size_t size_view = ctx->device.maxBufferLength;
|
||||
|
||||
for (size_t i = 0; i < size; i += size_step) {
|
||||
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
|
||||
|
||||
ctx->buffers[ctx->n_buffers].name = name;
|
||||
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
|
||||
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
|
||||
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
||||
if (i + size_step < size) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
++ctx->n_buffers;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, ", (%8.2f / %8.2f)",
|
||||
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
|
||||
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
||||
fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
|
||||
} else {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -765,18 +815,23 @@ void ggml_metal_graph_compute(
|
|||
} break;
|
||||
case GGML_OP_ALIBI:
|
||||
{
|
||||
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
||||
const int n_past = ((int32_t *) src1->data)[0];
|
||||
const int n_head = ((int32_t *) src1->data)[1];
|
||||
const float max_bias = ((float *) src1->data)[2];
|
||||
if (__builtin_popcount(n_head) != 1) {
|
||||
GGML_ASSERT(false && "only power-of-two n_head implemented");
|
||||
}
|
||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
if (encoder == nil) {
|
||||
encoder = [command_buffer computeCommandEncoder];
|
||||
}
|
||||
|
||||
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
||||
|
||||
const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
|
||||
const int n_head = ((int32_t *) src1->data)[1];
|
||||
const float max_bias = ((float *) src1->data)[2];
|
||||
|
||||
if (__builtin_popcount(n_head) != 1) {
|
||||
GGML_ASSERT(false && "only power-of-two n_head implemented");
|
||||
}
|
||||
|
||||
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
@ -904,4 +959,14 @@ void ggml_metal_graph_compute(
|
|||
dispatch_barrier_sync(queue, ^{});
|
||||
|
||||
[command_buffers[n_cb - 1] waitUntilCompleted];
|
||||
|
||||
// check status of command buffers
|
||||
// needed to detect if the device ran out-of-memory for example (#1881)
|
||||
for (int i = 0; i < n_cb; i++) {
|
||||
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,10 @@
|
|||
|
||||
#include "ggml.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
#define CL_DMMV_BLOCK_SIZE 32
|
||||
|
||||
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
||||
|
|
24
ggml.c
24
ggml.c
|
@ -4154,14 +4154,34 @@ void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
|
|||
ctx->no_alloc = no_alloc;
|
||||
}
|
||||
|
||||
void * ggml_get_mem_buffer(struct ggml_context * ctx) {
|
||||
void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
|
||||
return ctx->mem_buffer;
|
||||
}
|
||||
|
||||
size_t ggml_get_mem_size(struct ggml_context * ctx) {
|
||||
size_t ggml_get_mem_size(const struct ggml_context * ctx) {
|
||||
return ctx->mem_size;
|
||||
}
|
||||
|
||||
size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
|
||||
size_t max_size = 0;
|
||||
|
||||
struct ggml_object * obj = ctx->objects_begin;
|
||||
|
||||
while (obj != NULL) {
|
||||
struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
|
||||
|
||||
const size_t size = ggml_nbytes(tensor);
|
||||
|
||||
if (max_size < size) {
|
||||
max_size = size;
|
||||
}
|
||||
|
||||
obj = obj->next;
|
||||
}
|
||||
|
||||
return max_size;
|
||||
}
|
||||
|
||||
// IMPORTANT:
|
||||
// when creating "opt" tensors, always save and load the scratch buffer
|
||||
// this is an error prone process, but it is necessary to support inplace
|
||||
|
|
5
ggml.h
5
ggml.h
|
@ -500,8 +500,9 @@ extern "C" {
|
|||
GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
|
||||
GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
|
||||
|
||||
GGML_API void * ggml_get_mem_buffer(struct ggml_context * ctx);
|
||||
GGML_API size_t ggml_get_mem_size (struct ggml_context * ctx);
|
||||
GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
|
||||
GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx);
|
||||
GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_new_tensor(
|
||||
struct ggml_context * ctx,
|
||||
|
|
26
llama.cpp
26
llama.cpp
|
@ -886,6 +886,7 @@ static bool kv_cache_init(
|
|||
const int64_t n_elements = n_embd*n_mem;
|
||||
|
||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
|
||||
cache.n = 0;
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = cache.buf.size;
|
||||
|
@ -904,6 +905,7 @@ static bool kv_cache_init(
|
|||
ggml_set_name(cache.k, "cache_k");
|
||||
ggml_set_name(cache.v, "cache_v");
|
||||
|
||||
(void) n_gpu_layers;
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (n_gpu_layers > n_layer + 1) {
|
||||
ggml_cuda_assign_buffers_no_scratch(cache.v);
|
||||
|
@ -1253,7 +1255,7 @@ static void llama_model_load_internal(
|
|||
vram_scratch = n_batch * MB;
|
||||
ggml_cuda_set_scratch_size(vram_scratch);
|
||||
if (n_gpu_layers > 0) {
|
||||
fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
|
||||
fprintf(stderr, "%s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n",
|
||||
__func__, vram_scratch / MB);
|
||||
}
|
||||
}
|
||||
|
@ -2694,16 +2696,21 @@ struct llama_context * llama_init_from_file(
|
|||
// this allocates all Metal resources and memory buffers
|
||||
ctx->ctx_metal = ggml_metal_init();
|
||||
|
||||
void *data_ptr = NULL;
|
||||
void * data_ptr = NULL;
|
||||
size_t data_size = 0;
|
||||
|
||||
if (params.use_mmap) {
|
||||
data_ptr = ctx->model.mapping->addr;
|
||||
data_size= ctx->model.mapping->size;
|
||||
data_size = ctx->model.mapping->size;
|
||||
} else {
|
||||
data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
|
||||
data_size= ggml_get_mem_size(ctx->model.ctx);
|
||||
data_size = ggml_get_mem_size (ctx->model.ctx);
|
||||
}
|
||||
|
||||
const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx);
|
||||
|
||||
printf("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0);
|
||||
|
||||
#define LLAMA_METAL_CHECK_BUF(result) \
|
||||
if (!(result)) { \
|
||||
fprintf(stderr, "%s: failed to add buffer\n", __func__); \
|
||||
|
@ -2711,12 +2718,13 @@ struct llama_context * llama_init_from_file(
|
|||
return NULL; \
|
||||
}
|
||||
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size));
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size));
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size));
|
||||
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size));
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size));
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size));
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0));
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size, 0));
|
||||
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size, 0));
|
||||
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size, 0));
|
||||
#undef LLAMA_METAL_CHECK_BUF
|
||||
}
|
||||
#endif
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue