diff --git a/.gitignore b/.gitignore
index e68fd724a..e7bfd52e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,6 +34,7 @@ models/*
/perplexity
/embedding
/train-text-from-scratch
+/simple
/benchmark-matmult
/vdot
/server
diff --git a/Makefile b/Makefile
index e77a58a71..59efb1ef7 100644
--- a/Makefile
+++ b/Makefile
@@ -144,11 +144,7 @@ endif # LLAMA_NO_ACCELERATE
ifdef LLAMA_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
- ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),)
- LDFLAGS += -lopenblas -lcblas
- else
- LDFLAGS += -lopenblas
- endif
+ LDFLAGS += -lopenblas
endif # LLAMA_OPENBLAS
ifdef LLAMA_BLIS
@@ -298,9 +294,6 @@ main: examples/main/main.cpp build-info.h ggml.
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
- @echo
- @echo '==== Run ./simple -h for help. ===='
- @echo
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
diff --git a/README.md b/README.md
index b9759b00b..7defb7584 100644
--- a/README.md
+++ b/README.md
@@ -336,7 +336,6 @@ Building the program with BLAS support may lead to some performance improvements
cmake .. -DLLAMA_CUBLAS=ON
cmake --build . --config Release
```
- Note: Because llama.cpp uses multiple CUDA streams for matrix multiplication results [are not guaranteed to be reproducible](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility). If you need reproducibility, set `GGML_CUDA_MAX_STREAMS` in the file `ggml-cuda.cu` to 1.
The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index de005f3e3..cf9c4a223 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -38,6 +38,7 @@ else()
add_subdirectory(benchmark)
add_subdirectory(baby-llama)
add_subdirectory(train-text-from-scratch)
+ add_subdirectory(simple)
if (LLAMA_METAL)
add_subdirectory(metal)
endif()
diff --git a/examples/common.cpp b/examples/common.cpp
index 055383bef..fed24e027 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -106,9 +106,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}
if (arg == "-s" || arg == "--seed") {
-#if defined(GGML_USE_CUBLAS)
- fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
-#endif
if (++i >= argc) {
invalid_param = true;
break;
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a051fcbc5..941312f9c 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -354,7 +354,7 @@ int main(int argc, char ** argv) {
if ((int)embd.size() > max_embd_size) {
auto skipped_tokens = embd.size() - max_embd_size;
console_set_color(con_st, CONSOLE_COLOR_ERROR);
- printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
+ printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
fflush(stdout);
embd.resize(max_embd_size);
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 2aea3480b..662108201 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -68,6 +68,10 @@
#include "ggml-cuda.h"
#include "ggml.h"
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
#define CUDA_CHECK(err) \
@@ -1518,19 +1522,13 @@ static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
static size_t g_scratch_offset = 0;
-#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
-#define GGML_CUDA_MAX_EVENTS 64
-
static int g_device_count = -1;
static int g_main_device = 0;
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-
-static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
void ggml_init_cublas() {
static bool initialized = false;
@@ -1554,15 +1552,8 @@ void ggml_init_cublas() {
for (int id = 0; id < g_device_count; ++id) {
CUDA_CHECK(cudaSetDevice(id));
- // create streams
- for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
- CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
- CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
- }
- // create events
- for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
- CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
- }
+ // create main stream
+ CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
// create cublas handle
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
@@ -2029,6 +2020,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+ // if multiple GPUs are used they need to wait for the main GPU to finish
+ if (split && g_device_count > 1) {
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
+
for (int id = 0; id < g_device_count; ++id) {
if (!split && id != g_main_device) {
continue;
@@ -2127,9 +2124,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
const int64_t i11 = i13*ne12 + i12;
- cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
- cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
- cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[id];
// for split tensors the data begins at i0 == i0_offset_low
char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
@@ -2157,14 +2152,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
if (src1->backend == GGML_BACKEND_CPU) {
GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
int64_t nrows1 = flatten_rows ? nrows0 : ne11;
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
} else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
if (id != g_main_device) {
GGML_ASSERT(!flatten_rows);
float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
src1_ddf_i_source += i11*src1_stride;
CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
- cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
+ cudaMemcpyDeviceToDevice, cudaStream_main));
}
} else if (src1_on_device && !src1_is_contiguous) {
GGML_ASSERT(!split);
@@ -2173,7 +2168,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
GGML_ASSERT(false);
}
}
- CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
if (!src0_on_device || !src0_is_contiguous) {
if (src0_is_f32) {
@@ -2189,9 +2183,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
CUDA_CHECK(cudaGetLastError());
}
- // wait with main stream until src1 memcpy is done
- CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
-
// do the computation
op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
@@ -2229,8 +2220,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
// wait until each device is finished, then free their buffers
for (int id = 0; id < g_device_count; ++id) {
+ if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+ continue;
+ }
+
CUDA_CHECK(cudaSetDevice(id));
CUDA_CHECK(cudaDeviceSynchronize());
+
if (src0_asq[id] > 0) {
ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
}
@@ -2296,7 +2292,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
const int64_t ne02 = src0->ne[2];
CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2308,8 +2304,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
-
- CUDA_CHECK(cudaDeviceSynchronize());
}
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -2327,7 +2321,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int64_t nb02 = src0->nb[2];
CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2342,8 +2336,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int channel_stride_x = nb02 / sizeof(half);
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
-
- CUDA_CHECK(cudaDeviceSynchronize());
}
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2399,7 +2391,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
const int64_t nb12 = src1->nb[2];
CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -2417,8 +2409,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
GGML_ASSERT(false);
}
- CUDA_CHECK(cudaDeviceSynchronize());
-
(void) dst;
}
diff --git a/ggml-metal.m b/ggml-metal.m
index 0e9b56aa3..07da62a25 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -57,6 +57,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q5_k);
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
GGML_METAL_DECL_KERNEL(rms_norm);
+ GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
@@ -66,8 +67,10 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_DECL_KERNEL(rope);
+ GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f16);
#undef GGML_METAL_DECL_KERNEL
};
@@ -162,6 +165,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(get_rows_q5_k);
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
GGML_METAL_ADD_KERNEL(rms_norm);
+ GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
@@ -171,8 +175,10 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_ADD_KERNEL(rope);
+ GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f16);
#undef GGML_METAL_ADD_KERNEL
}
@@ -250,10 +256,10 @@ bool ggml_metal_add_buffer(
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
return false;
- } else {
- fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
}
+ fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
+
++ctx->n_buffers;
}
@@ -735,6 +741,70 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_NORM:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const float eps = 1e-5f;
+
+ const int nth = 256;
+
+ [encoder setComputePipelineState:ctx->pipeline_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+ const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
+ const int n_head = ((int32_t *) src1->data)[1];
+ const float max_bias = ((float *) src1->data)[2];
+
+ if (__builtin_popcount(n_head) != 1) {
+ GGML_ASSERT(false && "only power-of-two n_head implemented");
+ }
+
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+
+ [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ const int nth = 32;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
@@ -788,6 +858,14 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
};
} break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+ case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
default: GGML_ASSERT(false && "not implemented");
}
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 09e12a879..d1e49222d 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
(device float *) ((device char *) dst + i*nb1), ne00);
}
+kernel void kernel_norm(
+ device const void * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant float & eps,
+ threadgroup float * sum [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+ // MEAN
+ // parallel sum
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ sum[tpitg] += x[i00];
+ }
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ // broadcast
+ if (tpitg == 0) {
+ sum[0] /= ne00;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ const float mean = sum[0];
+
+ // recenter
+ device float * y = dst + tgpig*ne00;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = x[i00] - mean;
+ }
+
+ // VARIANCE
+ // parallel sum
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ sum[tpitg] += y[i00] * y[i00];
+ }
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ // broadcast
+ if (tpitg == 0) {
+ sum[0] /= ne00;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ const float variance = sum[0];
+
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = y[i00] * scale;
+ }
+}
+
+
kernel void kernel_rms_norm(
device const void * src0,
device float * dst,
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
}
}
+kernel void kernel_alibi_f32(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant float & m0,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ float m_k = pow(m0, i2 + 1);
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+ }
+}
+
kernel void kernel_rope(
device const void * src0,
device float * dst,
@@ -540,6 +648,47 @@ kernel void kernel_rope(
}
}
+kernel void kernel_cpy_f16_f16(
+ device const half * src0,
+ device half * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp
index 1d4db96ee..95f4cec6d 100644
--- a/ggml-opencl.cpp
+++ b/ggml-opencl.cpp
@@ -15,6 +15,10 @@
#include "ggml.h"
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
#define CL_DMMV_BLOCK_SIZE 32
#define MULTILINE_QUOTE(...) #__VA_ARGS__
diff --git a/llama.cpp b/llama.cpp
index 81f047ed2..a2916b3e8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -886,6 +886,7 @@ static bool kv_cache_init(
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+ cache.n = 0;
struct ggml_init_params params;
params.mem_size = cache.buf.size;
@@ -904,6 +905,7 @@ static bool kv_cache_init(
ggml_set_name(cache.k, "cache_k");
ggml_set_name(cache.v, "cache_v");
+ (void) n_gpu_layers;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer + 1) {
ggml_cuda_assign_buffers_no_scratch(cache.v);
@@ -1253,7 +1255,7 @@ static void llama_model_load_internal(
vram_scratch = n_batch * MB;
ggml_cuda_set_scratch_size(vram_scratch);
if (n_gpu_layers > 0) {
- fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
+ fprintf(stderr, "%s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n",
__func__, vram_scratch / MB);
}
}