diff --git a/.gitignore b/.gitignore
index e68fd724a..e7bfd52e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,6 +34,7 @@ models/*
 /perplexity
 /embedding
 /train-text-from-scratch
+/simple
 /benchmark-matmult
 /vdot
 /server
diff --git a/Makefile b/Makefile
index e77a58a71..59efb1ef7 100644
--- a/Makefile
+++ b/Makefile
@@ -144,11 +144,7 @@ endif # LLAMA_NO_ACCELERATE
 
 ifdef LLAMA_OPENBLAS
 	CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
-	ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),)
-		LDFLAGS += -lopenblas -lcblas
-	else
-		LDFLAGS += -lopenblas
-	endif
+	LDFLAGS += -lopenblas
 endif # LLAMA_OPENBLAS
 
 ifdef LLAMA_BLIS
@@ -298,9 +294,6 @@ main: examples/main/main.cpp build-info.h ggml.
 simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-	@echo
-	@echo '==== Run ./simple -h for help. ===='
-	@echo
 
 quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
diff --git a/README.md b/README.md
index b9759b00b..7defb7584 100644
--- a/README.md
+++ b/README.md
@@ -336,7 +336,6 @@ Building the program with BLAS support may lead to some performance improvements
   cmake .. -DLLAMA_CUBLAS=ON
   cmake --build . --config Release
   ```
-  Note: Because llama.cpp uses multiple CUDA streams for matrix multiplication results [are not guaranteed to be reproducible](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility). If you need reproducibility, set `GGML_CUDA_MAX_STREAMS` in the file `ggml-cuda.cu` to 1.
 
   The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index de005f3e3..cf9c4a223 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -38,6 +38,7 @@ else()
     add_subdirectory(benchmark)
     add_subdirectory(baby-llama)
     add_subdirectory(train-text-from-scratch)
+    add_subdirectory(simple)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
diff --git a/examples/common.cpp b/examples/common.cpp
index 055383bef..fed24e027 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -106,9 +106,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         }
         if (arg == "-s" || arg == "--seed") {
-#if defined(GGML_USE_CUBLAS)
-            fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
-#endif
            if (++i >= argc) {
                 invalid_param = true;
                 break;
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a051fcbc5..941312f9c 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -354,7 +354,7 @@ int main(int argc, char ** argv) {
         if ((int)embd.size() > max_embd_size) {
             auto skipped_tokens = embd.size() - max_embd_size;
             console_set_color(con_st, CONSOLE_COLOR_ERROR);
-            printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
+            printf("<<input too long: skipped %zu token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
             console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
             fflush(stdout);
             embd.resize(max_embd_size);
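The printf fixes in this patch (here, and in llama.cpp at the bottom, where `%ld` becomes `%zd`) swap in size-correct conversion specifiers: `skipped_tokens` is a `size_t`, and passing it to a mismatched conversion is undefined behavior. A minimal, self-contained sketch of the portable pattern (illustrative code, not part of the patch):

```cpp
// Illustrative only: %zu is the standard conversion for size_t, and the
// ternary picks the plural suffix exactly as the patched printf call does.
#include <cstdio>

int main() {
    const size_t skipped_tokens = 3;
    std::printf("skipped %zu token%s\n", skipped_tokens, skipped_tokens != 1 ? "s" : "");
    return 0;
}
```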
"s" : ""); console_set_color(con_st, CONSOLE_COLOR_DEFAULT); fflush(stdout); embd.resize(max_embd_size); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2aea3480b..662108201 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -68,6 +68,10 @@ #include "ggml-cuda.h" #include "ggml.h" +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); #define CUDA_CHECK(err) \ @@ -1518,19 +1522,13 @@ static void * g_scratch_buffer = nullptr; static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default static size_t g_scratch_offset = 0; -#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication. -#define GGML_CUDA_MAX_EVENTS 64 - static int g_device_count = -1; static int g_main_device = 0; static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; -static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr }; - -static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr }; -static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr }; +static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr }; void ggml_init_cublas() { static bool initialized = false; @@ -1554,15 +1552,8 @@ void ggml_init_cublas() { for (int id = 0; id < g_device_count; ++id) { CUDA_CHECK(cudaSetDevice(id)); - // create streams - for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) { - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking)); - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking)); - } - // create events - for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) { - CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming)); - } + // create main stream + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking)); // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); @@ -2029,6 +2020,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0}; + // if multiple GPUs are used they need to wait for the main GPU to finish + if (split && g_device_count > 1) { + CUDA_CHECK(cudaSetDevice(g_main_device)); + CUDA_CHECK(cudaDeviceSynchronize()); + } + for (int id = 0; id < g_device_count; ++id) { if (!split && id != g_main_device) { continue; @@ -2127,9 +2124,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } const int64_t i11 = i13*ne12 + i12; - cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS]; - cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS]; - cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS]; + cudaStream_t cudaStream_main = g_cudaStreams_main[id]; // for split tensors the data begins at i0 == i0_offset_low char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; @@ -2157,14 +2152,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm if (src1->backend == GGML_BACKEND_CPU) { GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1)); int64_t nrows1 = flatten_rows ? 
@@ -2029,6 +2020,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
+    // if multiple GPUs are used they need to wait for the main GPU to finish
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
     for (int id = 0; id < g_device_count; ++id) {
         if (!split && id != g_main_device) {
             continue;
@@ -2127,9 +2124,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
-                cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
-                cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
 
                 // for split tensors the data begins at i0 == i0_offset_low
                 char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
@@ -2157,14 +2152,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     if (src1->backend == GGML_BACKEND_CPU) {
                         GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
                         int64_t nrows1 = flatten_rows ? nrows0 : ne11;
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
                     } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                         if (id != g_main_device) {
                             GGML_ASSERT(!flatten_rows);
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                             src1_ddf_i_source += i11*src1_stride;
                             CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
-                                                       cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
+                                                       cudaMemcpyDeviceToDevice, cudaStream_main));
                         }
                     } else if (src1_on_device && !src1_is_contiguous) {
                         GGML_ASSERT(!split);
@@ -2173,7 +2168,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         GGML_ASSERT(false);
                     }
                 }
-                CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
 
                 if (!src0_on_device || !src0_is_contiguous) {
                     if (src0_is_f32) {
@@ -2189,9 +2183,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                // wait with main stream until src1 memcpy is done
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
-
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
@@ -2229,8 +2220,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     // wait until each device is finished, then free their buffers
     for (int id = 0; id < g_device_count; ++id) {
+        if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+            continue;
+        }
+
         CUDA_CHECK(cudaSetDevice(id));
         CUDA_CHECK(cudaDeviceSynchronize());
+
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
         }
@@ -2296,7 +2292,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     const int64_t ne02 = src0->ne[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2308,8 +2304,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -2327,7 +2321,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t nb02 = src0->nb[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2342,8 +2336,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int channel_stride_x = nb02 / sizeof(half);
 
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
-
-    CUDA_CHECK(cudaDeviceSynchronize());
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2399,7 +2391,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb12 = src1->nb[2];
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+    cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
     const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -2417,8 +2409,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
         GGML_ASSERT(false);
     }
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-
     (void) dst;
 }
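The common thread in the ggml-cuda.cu changes above: work enqueued on a single CUDA stream executes in enqueue order, so once every copy and kernel for a device goes through `g_cudaStreams_main[id]`, the per-copy streams, the `cudaEventRecord`/`cudaStreamWaitEvent` pairs, and the trailing `cudaDeviceSynchronize()` calls become unnecessary; only the host (or another device) still needs an explicit synchronization point. A self-contained sketch of that ordering guarantee (illustrative kernel and buffer names, not ggml code):

```cpp
// sketch.cu -- on one stream, the async H2D copy, the kernel, and the D2H
// copy run strictly in the order they were enqueued; no events needed.
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_CHECK(err) do { cudaError_t e = (err); \
    if (e != cudaSuccess) { std::printf("CUDA error: %s\n", cudaGetErrorString(e)); return 1; } } while (0)

__global__ void scale(float * x, float s, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}

int main() {
    const int n = 1024;
    float h[n];
    for (int i = 0; i < n; ++i) h[i] = 1.0f;

    float * d = nullptr;
    CUDA_CHECK(cudaMalloc(&d, n*sizeof(float)));

    cudaStream_t stream;
    CUDA_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

    CUDA_CHECK(cudaMemcpyAsync(d, h, n*sizeof(float), cudaMemcpyHostToDevice, stream));
    scale<<<(n + 255)/256, 256, 0, stream>>>(d, 2.0f, n);   // sees the finished copy
    CUDA_CHECK(cudaMemcpyAsync(h, d, n*sizeof(float), cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));              // host-side wait only

    std::printf("h[0] = %f\n", h[0]); // 2.0
    CUDA_CHECK(cudaFree(d));
    return 0;
}
```

The trade-off is that copies and compute no longer overlap across iterations, which is what the removed README note about irreproducible multi-stream results was paid for.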
diff --git a/ggml-metal.m b/ggml-metal.m
index 0e9b56aa3..07da62a25 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -57,6 +57,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
@@ -66,8 +67,10 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
     GGML_METAL_DECL_KERNEL(rope);
+    GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
 
 #undef GGML_METAL_DECL_KERNEL
 };
@@ -162,6 +165,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
@@ -171,8 +175,10 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
         GGML_METAL_ADD_KERNEL(rope);
+        GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
 
 #undef GGML_METAL_ADD_KERNEL
     }
@@ -250,10 +256,10 @@ bool ggml_metal_add_buffer(
         if (ctx->buffers[ctx->n_buffers].metal == nil) {
             fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
             return false;
-        } else {
-            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
         }
 
+        fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
+
         ++ctx->n_buffers;
     }
@@ -735,6 +741,70 @@ void ggml_metal_graph_compute(
                 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
+        case GGML_OP_NORM:
+            {
+                if (encoder == nil) {
+                    encoder = [command_buffer computeCommandEncoder];
+                }
+
+                const float eps = 1e-5f;
+
+                const int nth = 256;
+
+                [encoder setComputePipelineState:ctx->pipeline_norm];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ALIBI:
+            {
+                if (encoder == nil) {
+                    encoder = [command_buffer computeCommandEncoder];
+                }
+
+                GGML_ASSERT((src0t == GGML_TYPE_F32));
+
+                const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
+                const int   n_head   = ((int32_t *) src1->data)[1];
+                const float max_bias = ((float *)   src1->data)[2];
+
+                if (__builtin_popcount(n_head) != 1) {
+                    GGML_ASSERT(false && "only power-of-two n_head implemented");
+                }
+
+                const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+                const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+
+                [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
+                const int nth = 32;
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_ROPE:
             {
                 if (encoder == nil) {
@@ -788,6 +858,14 @@ void ggml_metal_graph_compute(
                             default: GGML_ASSERT(false && "not implemented");
                         };
                     } break;
+                case GGML_TYPE_F16:
+                    {
+                        switch (dstt) {
+                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        };
+                    } break;
                 default: GGML_ASSERT(false && "not implemented");
             }
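For reference, the slope setup in the GGML_OP_ALIBI case follows the ALiBi formulation of Press et al.: with `n_head` a power of two (which the GGML_ASSERT above enforces), head k gets slope m0^(k+1) where m0 = 2^(-max_bias / n_head). A small host-side check of that formula (illustrative code, using the customary max_bias = 8; not part of the patch):

```cpp
// Illustrative check of the slope formula used above: for n_head = 8 and
// max_bias = 8, m0 = 2^-1 and the per-head slopes come out 1/2, 1/4, ..., 1/256,
// the geometric sequence from the ALiBi paper.
#include <cmath>
#include <cstdio>

int main() {
    const int   n_head   = 8;    // assumed power of two, as in the kernel
    const float max_bias = 8.0f; // customary ALiBi setting

    const int   n_heads_log2_floor = 1 << (int) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -max_bias / n_heads_log2_floor);

    for (int k = 0; k < n_head; ++k) {
        std::printf("head %d: slope = %g\n", k, std::pow(m0, (float)(k + 1)));
    }
    return 0;
}
```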
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 09e12a879..d1e49222d 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(
         (device float *) ((device char *) dst + i*nb1), ne00);
 }
 
+kernel void kernel_norm(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant     float & eps,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    // MEAN
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += x[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float mean = sum[0];
+
+    // recenter
+    device float * y = dst + tgpig*ne00;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        sum[tpitg] += y[i00] * y[i00];
+    }
+    // reduce
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    for (uint i = ntg/2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    const float variance = sum[0];
+
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
 kernel void kernel_rms_norm(
         device const void * src0,
         device float * dst,
@@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+kernel void kernel_alibi_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant     float & m0,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    float m_k = pow(m0, i2 + 1);
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+    }
+}
+
 kernel void kernel_rope(
         device const void * src0,
         device float * dst,
@@ -540,6 +648,47 @@ kernel void kernel_rope(
     }
 }
 
+kernel void kernel_cpy_f16_f16(
+        device const half * src0,
+        device       half * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device half * dst,
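What kernel_norm computes, per row, is a plain (non-RMS) layer normalization: subtract the row mean, then scale by 1/sqrt(variance + eps); the threadgroup code above is just a parallel tree reduction of those two sums. A scalar reference of the same math (illustrative sketch, not the ggml CPU implementation):

```cpp
// Reference sketch of the per-row math in kernel_norm above (plain C++):
// y = (x - mean(x)) / sqrt(var(x) + eps).
#include <cmath>
#include <cstdio>
#include <vector>

static void norm_row(const float * x, float * y, int ne00, float eps = 1e-5f) {
    float mean = 0.0f;
    for (int i = 0; i < ne00; ++i) mean += x[i];
    mean /= ne00;

    float variance = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        y[i] = x[i] - mean;          // recenter, as the kernel does in place in dst
        variance += y[i] * y[i];
    }
    variance /= ne00;

    const float scale = 1.0f / std::sqrt(variance + eps);
    for (int i = 0; i < ne00; ++i) y[i] *= scale;
}

int main() {
    std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f}, y(4);
    norm_row(x.data(), y.data(), 4);
    for (float v : y) std::printf("%f\n", v);
    return 0;
}
```

Note the reduction loop assumes `ntg` is a power of two, which holds for the fixed nth = 256 threadgroup the host code dispatches.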
diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp
index 1d4db96ee..95f4cec6d 100644
--- a/ggml-opencl.cpp
+++ b/ggml-opencl.cpp
@@ -15,6 +15,10 @@
 
 #include "ggml.h"
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 #define CL_DMMV_BLOCK_SIZE 32
 
 #define MULTILINE_QUOTE(...) #__VA_ARGS__
diff --git a/llama.cpp b/llama.cpp
index 81f047ed2..a2916b3e8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -886,6 +886,7 @@ static bool kv_cache_init(
     const int64_t n_elements = n_embd*n_mem;
 
     cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+    cache.n = 0;
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size;
@@ -904,6 +905,7 @@ static bool kv_cache_init(
     ggml_set_name(cache.k, "cache_k");
     ggml_set_name(cache.v, "cache_v");
 
+    (void) n_gpu_layers;
 #ifdef GGML_USE_CUBLAS
     if (n_gpu_layers > n_layer + 1) {
         ggml_cuda_assign_buffers_no_scratch(cache.v);
@@ -1253,7 +1255,7 @@ static void llama_model_load_internal(
             vram_scratch = n_batch * MB;
             ggml_cuda_set_scratch_size(vram_scratch);
             if (n_gpu_layers > 0) {
-                fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
+                fprintf(stderr, "%s: allocating batch_size x 1 MB = %zd MB VRAM for the scratch buffer\n",
                         __func__, vram_scratch / MB);
             }
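The two llama.cpp one-liners are defensive fixes: `cache.n = 0;` gives the cache's counter a defined initial value, and `(void) n_gpu_layers;` keeps the parameter referenced when `GGML_USE_CUBLAS` is not defined and the `#ifdef` block that uses it compiles away. A minimal sketch of both idioms (hypothetical struct and function names, not the llama.cpp API):

```cpp
// Illustrative only: initialize conditionally-used members explicitly, and
// void-cast a parameter whose only use can be compiled out by an #ifdef.
struct kv_cache_like {
    int n; // must be given a defined value before first use
};

static bool kv_cache_init_like(kv_cache_like & cache, int n_gpu_layers) {
    cache.n = 0;          // defined starting state, as in the patch

    (void) n_gpu_layers;  // still "used" even when the block below is absent
#ifdef GGML_USE_CUBLAS
    if (n_gpu_layers > 0) {
        // offload the cache tensors to the GPU here
    }
#endif
    return true;
}
```

Without the cast, builds that lack `GGML_USE_CUBLAS` would emit an unused-parameter warning, which matters for the warning-clean MSVC builds this patch is also working toward.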