diff --git a/cmake/FindSIMD.cmake b/cmake/FindSIMD.cmake new file mode 100644 index 000000000..33377ec44 --- /dev/null +++ b/cmake/FindSIMD.cmake @@ -0,0 +1,100 @@ +include(CheckCSourceRuns) + +set(AVX_CODE " + #include + int main() + { + __m256 a; + a = _mm256_set1_ps(0); + return 0; + } +") + +set(AVX512_CODE " + #include + int main() + { + __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0); + __m512i b = a; + __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ); + return 0; + } +") + +set(AVX2_CODE " + #include + int main() + { + __m256i a = {0}; + a = _mm256_abs_epi16(a); + __m256i x; + _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code + return 0; + } +") + +set(FMA_CODE " + #include + int main() + { + __m256 acc = _mm256_setzero_ps(); + const __m256 d = _mm256_setzero_ps(); + const __m256 p = _mm256_setzero_ps(); + acc = _mm256_fmadd_ps( d, p, acc ); + return 0; + } +") + +macro(check_sse type flags) + set(__FLAG_I 1) + set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + foreach (__FLAG ${flags}) + if (NOT ${type}_FOUND) + set(CMAKE_REQUIRED_FLAGS ${__FLAG}) + check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I}) + if (HAS_${type}_${__FLAG_I}) + set(${type}_FOUND TRUE CACHE BOOL "${type} support") + set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags") + endif() + math(EXPR __FLAG_I "${__FLAG_I}+1") + endif() + endforeach() + set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + + if (NOT ${type}_FOUND) + set(${type}_FOUND FALSE CACHE BOOL "${type} support") + set(${type}_FLAGS "" CACHE STRING "${type} flags") + endif() + + mark_as_advanced(${type}_FOUND ${type}_FLAGS) +endmacro() + +# flags are for MSVC only! +check_sse("AVX" " ;/arch:AVX") +if (NOT ${AVX_FOUND}) + set(LLAMA_AVX OFF) +else() + set(LLAMA_AVX ON) +endif() + +check_sse("AVX2" " ;/arch:AVX2") +check_sse("FMA" " ;/arch:AVX2") +if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND})) + set(LLAMA_AVX2 OFF) +else() + set(LLAMA_AVX2 ON) +endif() + +check_sse("AVX512" " ;/arch:AVX512") +if (NOT ${AVX512_FOUND}) + set(LLAMA_AVX512 OFF) +else() + set(LLAMA_AVX512 ON) +endif() diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 47e325fe7..d64bde5c1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -39,10 +39,6 @@ #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess -#define cudaDeviceGetMemPool hipDeviceGetMemPool -#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold -#define cudaMemPoolSetAttribute hipMemPoolSetAttribute -#define cudaMemPool_t hipMemPool_t #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t @@ -52,7 +48,6 @@ #define cudaEvent_t hipEvent_t #define cudaEventDestroy hipEventDestroy #define cudaFree hipFree -#define cudaFreeAsync hipFreeAsync #define cudaFreeHost hipHostFree #define cudaGetDevice hipGetDevice #define cudaGetDeviceCount hipGetDeviceCount @@ -60,7 +55,6 @@ #define cudaGetErrorString hipGetErrorString #define cudaGetLastError hipGetLastError #define cudaMalloc hipMalloc -#define cudaMallocFromPoolAsync hipMallocFromPoolAsync #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMemcpy hipMemcpy #define cudaMemcpy2DAsync hipMemcpy2DAsync @@ -187,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cudaError_t err_ = (err); \ if (err_ != cudaSuccess) { \ - int dev_id; \ - cudaGetDevice(&dev_id); \ + int id; \ + cudaGetDevice(&id); \ fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ cudaGetErrorString(err_)); \ - fprintf(stderr, "current device: %d\n", dev_id); \ + fprintf(stderr, "current device: %d\n", id); \ exit(1); \ } \ } while (0) @@ -201,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cublasStatus_t err_ = (err); \ if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int dev_id; \ - cudaGetDevice(&dev_id); \ + int id; \ + cudaGetDevice(&id); \ fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ - fprintf(stderr, "current device: %d\n", dev_id); \ + fprintf(stderr, "current device: %d\n", id); \ exit(1); \ } \ } while (0) @@ -471,7 +465,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MAX_STREAMS 8 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; -static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr }; struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors @@ -5777,16 +5770,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) { - if (g_cudaMemPools[id] == nullptr) { - return ggml_cuda_pool_malloc(size, actual_size); - } - void *ptr; - CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream)); - *actual_size = size; - return ptr; -} - static void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; @@ -5805,13 +5788,6 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { } -static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) { - if (g_cudaMemPools[id] == nullptr) { - return ggml_cuda_pool_free(ptr, actual_size); - } - CUDA_CHECK(cudaFreeAsync(ptr, stream)); -} - void ggml_init_cublas() { static bool initialized = false; @@ -5858,13 +5834,6 @@ void ggml_init_cublas() { // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); - - // configure memory pool - cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id); - if (err == cudaSuccess) { - size_t treshold = UINT64_MAX; - CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold)); - } } // configure logging to stdout @@ -6458,7 +6427,7 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = row_diff*ne00; - src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream); + src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); } const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; @@ -6469,12 +6438,13 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = src1_ncols*ne10; - src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream); + src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; - size_t dst_f16_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream); + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; @@ -6492,15 +6462,14 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); - if (dst_f16_as != 0) { - ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream); - } + ggml_cuda_pool_free(dst_f16, dst_as); if (src0_as != 0) { - ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream); + ggml_cuda_pool_free(src0_as_f16, src0_as); } + if (src1_as != 0) { - ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream); + ggml_cuda_pool_free(src1_as_f16, src1_as); } } else { @@ -6510,7 +6479,7 @@ inline void ggml_cuda_op_mul_mat_cublas( if (src0->type != GGML_TYPE_F32) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); GGML_ASSERT(to_fp32_cuda != nullptr); - src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT + src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); } const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; @@ -6527,7 +6496,7 @@ inline void ggml_cuda_op_mul_mat_cublas( &beta, dst_dd_i, ldc)); if (src0_as != 0) { - ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream); + ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); } } @@ -6950,22 +6919,21 @@ static void ggml_cuda_op_mul_mat( src0_dd[id] = (char *) src0_extra->data_device[id]; } else { const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); - src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream); + src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); } if (src1_on_device && src1_is_contiguous) { src1_ddf[id] = (float *) src1_extra->data_device[id]; } else { - src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream); + src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); } if (convert_src1_to_q8_1) { - const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; - src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream); + src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); if (src1_on_device && src1_is_contiguous) { quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); - // CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaGetLastError()); } } @@ -6973,7 +6941,7 @@ static void ggml_cuda_op_mul_mat( dst_dd[id] = (float *) dst_extra->data_device[id]; } else { const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); - dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream); + dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); } } @@ -7099,6 +7067,24 @@ static void ggml_cuda_op_mul_mat( } } + for (int64_t id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); + + // free buffers again when done + if (src0_as[id] > 0) { + ggml_cuda_pool_free(src0_dd[id], src0_as[id]); + } + if (src1_asf[id] > 0) { + ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); + } + if (src1_asq[id] > 0) { + ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); + } + if (dst_as[id] > 0) { + ggml_cuda_pool_free(dst_dd[id], dst_as[id]); + } + } + // main device waits for all other devices to be finished if (split && g_device_count > 1) { int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; @@ -7116,21 +7102,6 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaDeviceSynchronize()); } - - for (int64_t id = 0; id < g_device_count; ++id) { - if (src0_as[id] > 0) { - ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]); - } - if (src1_asf[id] > 0) { - ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]); - } - if (src1_asq[id] > 0) { - ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]); - } - if (dst_as[id] > 0) { - ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]); - } - } } static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7317,11 +7288,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const GGML_ASSERT(to_fp16_cuda != nullptr); size_t src1_as = 0; - half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream); + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream); + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -7375,8 +7346,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const size_t ptrs_src_s = 0; size_t ptrs_dst_s = 0; - ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream); - ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream); + ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); + ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); dim3 block_dims(ne13, ne12); k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( @@ -7389,6 +7360,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const dst->nb[2], dst->nb[3], r2, r3); CUDA_CHECK(cudaGetLastError()); + CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, @@ -7400,22 +7372,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_GEMM_DEFAULT_TENSOR_OP)); if (ptrs_src_s != 0) { - ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream); + ggml_cuda_pool_free(ptrs_src, ptrs_src_s); } if (ptrs_dst_s != 0) { - ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream); + ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); } } #endif const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); - if (src1_as != 0) { - ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream); - } - if (dst_as != 0) { - ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream); - } + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); } static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index 727b4e554..a2271d225 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -393,6 +393,7 @@ class TensorNameMap: "layers.{bid}.attention_norm", # llama-pth "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon + "model.layers.{bid}.ln1", # yi ), # Attention norm 2 @@ -464,6 +465,7 @@ class TensorNameMap: "layers.{bid}.ffn_norm", # llama-pth "encoder.layer.{bid}.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon + "model.layers.{bid}.ln2", # yi ), # Feed-forward up diff --git a/llama.cpp b/llama.cpp index cc80dee15..4440fb394 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5196,11 +5196,12 @@ static int llama_decode_internal( // If all tensors can be run on the GPU then using more than 1 thread is detrimental. const bool full_offload_supported = - model.arch == LLM_ARCH_LLAMA || - model.arch == LLM_ARCH_BAICHUAN || - model.arch == LLM_ARCH_FALCON || - model.arch == LLM_ARCH_REFACT || - model.arch == LLM_ARCH_MPT; + model.arch == LLM_ARCH_LLAMA || + model.arch == LLM_ARCH_BAICHUAN || + model.arch == LLM_ARCH_FALCON || + model.arch == LLM_ARCH_REFACT || + model.arch == LLM_ARCH_MPT || + model.arch == LLM_ARCH_STARCODER; const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {