diff --git a/.devops/tools.sh b/.devops/tools.sh
index 9a86e6ea0..41a6b1e55 100755
--- a/.devops/tools.sh
+++ b/.devops/tools.sh
@@ -13,9 +13,13 @@ elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then
     exec ./llama-quantize "$@"
 elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then
     exec ./llama-cli "$@"
+elif [[ "$arg1" == '--bench' || "$arg1" == '-b' ]]; then
+    exec ./llama-bench "$@"
+elif [[ "$arg1" == '--perplexity' || "$arg1" == '-p' ]]; then
+    exec ./llama-perplexity "$@"
 elif [[ "$arg1" == '--all-in-one' || "$arg1" == '-a' ]]; then
     echo "Converting PTH to GGML..."
-    for i in `ls $1/$2/ggml-model-f16.bin*`; do
+    for i in $(ls $1/$2/ggml-model-f16.bin*); do
         if [ -f "${i/f16/q4_0}" ]; then
             echo "Skip model quantization, it already exists: ${i/f16/q4_0}"
         else
@@ -30,6 +34,10 @@ else
     echo "Available commands: "
     echo "  --run (-r): Run a model previously converted into ggml"
     echo "              ex: -m /models/7B/ggml-model-q4_0.bin -p \"Building a website can be done in 10 simple steps:\" -n 512"
+    echo "  --bench (-b): Benchmark the performance of the inference for various parameters."
+    echo "            ex: -m model.gguf"
+    echo "  --perplexity (-p): Measure the perplexity of a model over a given text."
+    echo "                     ex: -m model.gguf -f file.txt"
     echo "  --convert (-c): Convert a llama model into ggml"
     echo "            ex: --outtype f16 \"/models/7B/\" "
     echo "  --quantize (-q): Optimize with quantization process ggml"
diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile
index b5bd3b6d2..9064f3838 100644
--- a/.devops/vulkan.Dockerfile
+++ b/.devops/vulkan.Dockerfile
@@ -1,4 +1,4 @@
-ARG UBUNTU_VERSION=22.04
+ARG UBUNTU_VERSION=24.04

 FROM ubuntu:$UBUNTU_VERSION AS build

@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget

 # Install Vulkan SDK and cURL
 RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
-    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
+    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
     apt update -y && \
     apt-get install -y vulkan-sdk libcurl4-openssl-dev curl

@@ -55,8 +55,9 @@ RUN apt-get update \
         git \
         python3 \
         python3-pip \
-    && pip install --upgrade pip setuptools wheel \
-    && pip install -r requirements.txt \
+        python3-wheel \
+    && pip install --break-system-packages --upgrade setuptools \
+    && pip install --break-system-packages -r requirements.txt \
     && apt autoremove -y \
     && apt clean -y \
     && rm -rf /tmp/* /var/tmp/* \
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 6bf22eb66..6955a7dc8 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -28,7 +28,7 @@ jobs:
   push_to_registry:
     name: Push Docker image to Docker Hub

-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     env:
       COMMIT_SHA: ${{ github.sha }}
     strategy:
@@ -36,8 +36,7 @@ jobs:
       matrix:
         config:
           # Multi-stage build
-          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
-          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, freediskspace: false}
+          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
diff --git a/cmake/llama-config.cmake.in b/cmake/llama-config.cmake.in
index 40ade96e5..90cbec5b6 100644
--- a/cmake/llama-config.cmake.in
+++ b/cmake/llama-config.cmake.in
@@ -9,7 +9,7 @@ set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")
 set_and_check(LLAMA_LIB_DIR     "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
 set_and_check(LLAMA_BIN_DIR     "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")

-find_package(ggml REQUIRED)
+find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)

 find_library(llama_LIBRARY llama
     REQUIRED
diff --git a/common/arg.cpp b/common/arg.cpp
index a6226a34b..f5e9b294f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -877,7 +877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.warmup = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(
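Note on the `common/arg.cpp` change above: `set_examples` controls which llama.cpp tools register a given flag, so appending `LLAMA_EXAMPLE_EMBEDDING` makes `--no-warmup` available to `llama-embedding` as well. A minimal sketch of that gating pattern (hypothetical stand-in types for illustration, not the actual `common_arg` implementation):

```cpp
#include <iostream>
#include <set>
#include <string>
#include <utility>

// Hypothetical stand-ins for llama.cpp's LLAMA_EXAMPLE_* enum and common_arg.
enum llama_example { LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING };

struct arg_def {
    std::string             flag;
    std::set<llama_example> examples;  // tools this option is registered for

    arg_def & set_examples(std::set<llama_example> ex) {
        examples = std::move(ex);
        return *this;
    }
    bool in_example(llama_example ex) const { return examples.count(ex) > 0; }
};

int main() {
    arg_def no_warmup{"--no-warmup", {}};
    // After this change, the embedding tool also accepts the flag:
    no_warmup.set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING});
    std::cout << no_warmup.in_example(LLAMA_EXAMPLE_EMBEDDING) << "\n";  // prints 1
}
```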
diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 2c38d1ef6..fe5e7b6d8 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -181,6 +181,10 @@ class Opt {
             }
         }

+        if (model_.empty()){
+            return 1;
+        }
+
         return 0;
     }

@@ -319,6 +323,10 @@ class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
              const bool progress, std::string * response_str = nullptr) {
+        if (std::filesystem::exists(output_file)) {
+            return 0;
+        }
+
         std::string output_file_partial;
         curl = curl_easy_init();
         if (!curl) {
@@ -346,7 +354,11 @@ class HttpClient {
         data.file_size = set_resume_point(output_file_partial);
         set_progress_options(progress, data);
         set_headers(headers);
-        perform(url);
+        CURLcode res = perform(url);
+        if (res != CURLE_OK){
+            printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
+            return 1;
+        }
         if (!output_file.empty()) {
             std::filesystem::rename(output_file_partial, output_file);
         }
@@ -411,16 +423,12 @@ class HttpClient {
         }
     }

-    void perform(const std::string & url) {
-        CURLcode res;
+    CURLcode perform(const std::string & url) {
         curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
         curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
         curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
         curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
-        res = curl_easy_perform(curl);
-        if (res != CURLE_OK) {
-            printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
-        }
+        return curl_easy_perform(curl);
     }

     static std::string human_readable_time(double seconds) {
@@ -558,13 +566,14 @@ class LlamaData {
         }

         sampler = initialize_sampler(opt);
+
         return 0;
     }

  private:
 #ifdef LLAMA_USE_CURL
-    int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
-                 const bool progress, std::string * response_str = nullptr) {
+    int download(const std::string & url, const std::string & output_file, const bool progress,
+                 const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
         if (http.init(url, headers, output_file, progress, response_str)) {
             return 1;
@@ -573,48 +582,85 @@ class LlamaData {
         return 0;
     }
 #else
-    int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
+    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
                  std::string * = nullptr) {
         printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+
         return 1;
     }
 #endif

-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
-        // Find the second occurrence of '/' after protocol string
-        size_t pos = model.find('/');
-        pos = model.find('/', pos + 1);
-        if (pos == std::string::npos) {
-            return 1;
-        }
-
-        const std::string hfr = model.substr(0, pos);
-        const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
-    }
-
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
-        if (model.find('/') == std::string::npos) {
-            model = "library/" + model;
-        }
-
-        std::string model_tag = "latest";
-        size_t colon_pos = model.find(':');
+    // Helper function to handle model tag extraction and URL construction
+    std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
+        std::string  model_tag = "latest";
+        const size_t colon_pos = model.find(':');
         if (colon_pos != std::string::npos) {
             model_tag = model.substr(colon_pos + 1);
             model     = model.substr(0, colon_pos);
         }

-        std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
+        std::string url = base_url + model + "/manifests/" + model_tag;
+
+        return { model, url };
+    }
+
+    // Helper function to download and parse the manifest
+    int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
+                                    nlohmann::json & manifest) {
         std::string manifest_str;
-        const int ret = download(manifest_url, headers, "", false, &manifest_str);
+        int ret = download(url, "", false, headers, &manifest_str);
         if (ret) {
             return ret;
         }

-        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        manifest = nlohmann::json::parse(manifest_str);
+
+        return 0;
+    }
+
+    int huggingface_dl(std::string & model, const std::string & bn) {
+        // Find the second occurrence of '/' after protocol string
+        size_t pos = model.find('/');
+        pos        = model.find('/', pos + 1);
+        std::string hfr, hff;
+        std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
+        std::string url;
+
+        if (pos == std::string::npos) {
+            auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
+            hfr = model_name;
+
+            nlohmann::json manifest;
+            int ret = download_and_parse_manifest(manifest_url, headers, manifest);
+            if (ret) {
+                return ret;
+            }
+
+            hff = manifest["ggufFile"]["rfilename"];
+        } else {
+            hfr = model.substr(0, pos);
+            hff = model.substr(pos + 1);
+        }
+
+        url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+
+        return download(url, bn, true, headers);
+    }
+
+    int ollama_dl(std::string & model, const std::string & bn) {
+        const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
+        if (model.find('/') == std::string::npos) {
+            model = "library/" + model;
+        }
+
+        auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
+        nlohmann::json manifest;
+        int ret = download_and_parse_manifest(manifest_url, {}, manifest);
+        if (ret) {
+            return ret;
+        }
+
+        std::string layer;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
                 layer = l["digest"];
@@ -622,8 +668,34 @@ class LlamaData {
             }
         }

-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
+
+        return download(blob_url, bn, true, headers);
+    }
+
+    int github_dl(const std::string & model, const std::string & bn) {
+        std::string  repository = model;
+        std::string  branch     = "main";
+        const size_t at_pos     = model.find('@');
+        if (at_pos != std::string::npos) {
+            repository = model.substr(0, at_pos);
+            branch     = model.substr(at_pos + 1);
+        }
+
+        const std::vector<std::string> repo_parts = string_split(repository, "/");
+        if (repo_parts.size() < 3) {
+            printe("Invalid GitHub repository format\n");
+            return 1;
+        }
+
+        const std::string & org     = repo_parts[0];
+        const std::string & project = repo_parts[1];
+        std::string         url     = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch;
+        for (size_t i = 2; i < repo_parts.size(); ++i) {
+            url += "/" + repo_parts[i];
+        }
+
+        return download(url, bn, true);
     }

     std::string basename(const std::string & path) {
@@ -653,22 +725,23 @@ class LlamaData {
             return ret;
         }

-        const std::string bn = basename(model_);
-        const std::vector<std::string> headers = { "--header",
-                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            rm_until_substring(model_, "://");
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "hf.co/")) {
+        const std::string bn = basename(model_);
+        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") ||
+            string_starts_with(model_, "hf.co/")) {
             rm_until_substring(model_, "hf.co/");
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
             rm_until_substring(model_, "://");
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
-            ret = download(model_, headers, bn, true);
-        } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, bn);
+        } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
+                   !string_starts_with(model_, "https://ollama.com/library/")) {
+            ret = download(model_, bn, true);
+        } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) {
+            rm_until_substring(model_, "github:");
+            rm_until_substring(model_, "://");
+            ret = github_dl(model_, bn);
+        } else {  // ollama:// or nothing
+            rm_until_substring(model_, "ollama.com/library/");
+            rm_until_substring(model_, "://");
+            ret = ollama_dl(model_, bn);
         }

         model_ = bn;
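The refactored download path above normalizes every scheme (`hf://`, `hf.co/`, `ollama://`, `github:`, plain `https://`) to a bare model name and routes it to a protocol-specific helper, with `extract_model_and_tag` splitting the optional `:tag` suffix and building the registry manifest URL. A self-contained sketch of that split (a standalone rewrite for illustration; the `main` driver is not from the patch):

```cpp
#include <iostream>
#include <string>
#include <utility>

// Standalone rewrite of the tag-splitting logic for illustration.
static std::pair<std::string, std::string> extract_model_and_tag(std::string model, const std::string & base_url) {
    std::string  tag       = "latest";                 // default when no ":tag" suffix is given
    const size_t colon_pos = model.find(':');
    if (colon_pos != std::string::npos) {
        tag   = model.substr(colon_pos + 1);
        model = model.substr(0, colon_pos);
    }
    return { model, base_url + model + "/manifests/" + tag };
}

int main() {
    auto [name, url] = extract_model_and_tag("library/llama3.2:1b", "https://registry.ollama.ai/v2/");
    std::cout << name << "\n" << url << "\n";
    // library/llama3.2
    // https://registry.ollama.ai/v2/library/llama3.2/manifests/1b
}
```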
                layer = l["digest"];
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index c1fc12462..0ed5b99be 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -87,7 +87,7 @@ def test_completion_stream_vs_non_stream():
     assert content_stream == res_non_stream.body["content"]


-def test_completion_stream_with_openai_library():
+def test_completion_with_openai_library():
     global server
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
@@ -102,7 +102,7 @@ def test_completion_stream_with_openai_library():
     assert match_regex("(going|bed)+", res.choices[0].text)


-def test_completion_with_openai_library():
+def test_completion_stream_with_openai_library():
     global server
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 8d2b948fb..566709135 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -308,7 +308,7 @@ if (GGML_CPU_ALL_VARIANTS)
        # MSVC doesn't support AMX
        ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
     endif()
-else ()
+elseif (GGML_CPU)
     ggml_add_cpu_backend_variant_impl("")
 endif()

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 9e627da8c..e809f05d2 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1302,7 +1302,7 @@ struct ggml_threadpool {
     // these are atomic as an annotation for thread-sanitizer
     atomic_bool stop;         // Used for stopping the threadpool altogether
     atomic_bool pause;        // Used for pausing the threadpool or individual threads
-    atomic_bool abort;        // Used for aborting processing of a graph
+    atomic_int  abort;        // Used for aborting processing of a graph

     struct ggml_compute_state * workers;   // per thread state
     int          n_threads_max; // number of threads in the pool
@@ -13851,14 +13851,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         /*.threadpool=*/ tp,
     };

-    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
+    for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];

         ggml_compute_forward(&params, node);

         if (state->ith == 0 && cplan->abort_callback &&
                 cplan->abort_callback(cplan->abort_callback_data)) {
-            tp->abort = true;
+            atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
             tp->ec    = GGML_STATUS_ABORTED;
         }

@@ -14031,7 +14031,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
     threadpool->current_chunk    = 0;
     threadpool->stop             = false;
     threadpool->pause            = tpp->paused;
-    threadpool->abort            = false;
+    threadpool->abort            = -1;
     threadpool->workers          = NULL;
     threadpool->n_threads_max    = tpp->n_threads;
     threadpool->n_threads_cur    = tpp->n_threads;
@@ -14110,7 +14110,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         threadpool->cgraph           = cgraph;
         threadpool->cplan            = cplan;
         threadpool->current_chunk    = 0;
-        threadpool->abort            = false;
+        threadpool->abort            = -1;
         threadpool->ec               = GGML_STATUS_SUCCESS;
     }
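The `ggml-cpu.c` change above replaces the boolean abort flag with a node index: `-1` means no abort, and a stored value of `n` tells every worker to stop before computing node `n`. Because each graph node is computed by all workers together, a plain bool can be observed at different loop iterations by different threads, letting one thread skip a node its peers still expect it to participate in; pinning the stop point to a concrete node index keeps all workers aligned on the same boundary. A reduced sketch of the protocol (C++ `std::atomic` used here in place of the file's C11 atomics; not the patch's code):

```cpp
#include <atomic>

// -1 means "no abort requested"; otherwise the value is the index of the
// first node that must NOT be executed. Every worker evaluates the same
// condition, so all of them stop at the same node boundary.
std::atomic<int> abort_node{-1};

void worker(int n_nodes) {
    for (int node_n = 0;
         node_n < n_nodes && abort_node.load(std::memory_order_relaxed) != node_n;
         node_n++) {
        // ... compute node node_n together with the other workers ...
        bool user_requested_abort = false;  // stand-in for cplan->abort_callback(...)
        if (user_requested_abort) {
            // stop everyone *after* the node that was just computed
            abort_node.store(node_n + 1, std::memory_order_relaxed);
        }
    }
}
```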
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index bb6120568..a66322da0 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -46,20 +46,20 @@
 #define GGML_CUDA_CC_VOLTA      700
 #define GGML_CUDA_CC_TURING     750
 #define GGML_CUDA_CC_AMPERE     800
-#define GGML_CUDA_CC_OFFSET_AMD 1000000
+#define GGML_CUDA_CC_OFFSET_AMD 0x1000000

 // GCN/CNDA, wave size is 64
-#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
-#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
-#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
-#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
-#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 910)  // MI210, minimum acc register renameing
-#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 942)  // MI300
+#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
+#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
+#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
+#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing
+#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300

 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
-#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 1010)  // RX 5000
-#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 1030)  // RX 6000, minimum for dp4a
-#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 1100)  // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010)  // RX 5000
+#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030)  // RX 6000, minimum for dp4a
+#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100)  // RX 7000, minimum for WMMA

 #define GGML_CUDA_CC_QY1        210
 #define GGML_CUDA_CC_QY2        220
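Switching the AMD compute-capability constants to hexadecimal lines each field of the `gfx` target ID up with its own hex digits: `cc = OFFSET_AMD + major*0x100 + minor*0x10 + stepping`, so gfx906 maps to `OFFSET_AMD + 0x906` and a mask like `cc & 0xff00` can test the major version directly. A small worked check (illustrative only):

```cpp
#include <cassert>

// With hex encoding each gfx field occupies its own nibble(s):
// cc = OFFSET_AMD + major*0x100 + minor*0x10 + stepping
constexpr int OFFSET_AMD = 0x1000000;
constexpr int cc_vega20  = OFFSET_AMD + 0x906;  // gfx906

int main() {
    assert(((cc_vega20 - OFFSET_AMD) >> 8)   == 0x9);   // major version
    assert(((cc_vega20 - OFFSET_AMD) & 0xff) == 0x06);  // minor.stepping
    assert((cc_vega20 & 0xff00) != 0);                  // "valid major" check used by the patch
    // The old decimal encoding (offset 1000000, + 906) had no such digit alignment.
}
```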
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 85178abd2..de3f9c2ca 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -42,6 +42,7 @@
 #include <algorithm>
 #include <array>
 #include <atomic>
+#include <charconv>
 #include <cinttypes>
 #include <cstddef>
 #include <cstdint>
@@ -119,12 +120,78 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
 #endif
 }

+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+static int ggml_cuda_parse_id(char devName[]) {
+    // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
+    // these values are not stable so this is susceptible to breakage
+    // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
+    int archMajor = 0x0;
+    int archMinor = 0x0;
+    int archNum = GGML_CUDA_CC_OFFSET_AMD;
+    int archLen = strlen(devName);
+    char archName[archLen + 1];
+
+    // strip leading 'gfx' while copying into our buffer
+    if (archLen > 3) {
+        strcpy(archName, &devName[3]);
+        archLen -= 3;
+    }
+
+    // trim trailing :xnack- or :sramecc- statuses
+    archLen = strcspn(archName, ":");
+    archName[archLen] = '\0';
+
+    // tease out the version information
+    if (archLen > 8) {
+        // versions labeled generic use '-' as delimiter
+        // strip the trailing "-generic" then iterate through what remains
+        if ((strstr(archName, "-generic"))) {
+            archName[archLen - 8] = '\0';
+            char * pch;
+            if ((pch = strtok(archName, "-"))) {
+                archMajor = (int)strtoul(pch, 0, 16);
+                if ((pch = strtok(NULL, "-"))) {
+                    archMinor = 0x10 * (int)strtoul(pch, 0, 16);
+                }
+            }
+        }
+    } else if (archLen >= 3) {
+        // last two digits should be the minor * 0x10 + stepping
+        archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
+        archName[archLen - 2] = '\0';
+
+        // only the major version remains
+        archMajor = (int)strtoul(archName, 0, 16);
+    }
+    archNum += archMajor * 0x100;
+    archNum += archMinor;
+    return archNum;
+}
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
     // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
-    rocblas_initialize();
-    CUDA_CHECK(cudaDeviceSynchronize());
+    {
+        int major_version = 0;
+        size_t version_length = 0;
+        if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
+            std::string version(version_length, '\0');
+            if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
+                version.resize(::strlen(version.c_str()));
+                int parsed_value = 0;
+                if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) {
+                    major_version = parsed_value;
+                }
+            }
+        }
+        if (major_version < 4) {
+            GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
+            rocblas_initialize();
+            CUDA_CHECK(cudaDeviceSynchronize());
+        }
+    }
 #endif

     ggml_cuda_device_info info = {};
@@ -169,7 +236,6 @@ static ggml_cuda_device_info ggml_cuda_init() {

         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");

         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
@@ -178,10 +244,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].smpb    = prop.sharedMemPerBlock;
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
+
+        info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
+        if ((info.devices[id].cc & 0xff00) == 0x0) {
+            GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s  cc %d.%d\n",
+                          id, prop.name, prop.gcnArchName, prop.major, prop.minor);
+
+            // Fallback to prop.major and prop.minor
+            if (prop.major > 0) {
+                info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
+                info.devices[id].cc += prop.minor * 0x10;
+            }
+        }
+        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s\n",
+                      id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no");
 #else
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
+                      id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     }
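`ggml_cuda_parse_id` above strips the leading `gfx`, drops `:xnack±`/`:sramecc±` feature statuses, and reads the remaining digits as hex. A simplified standalone equivalent using `std::string` (the patch's `-generic` target handling is omitted here; the name `parse_gfx_id` and the driver are illustrative):

```cpp
#include <iostream>
#include <string>

// Illustration of the target-ID parsing: strip "gfx", drop the
// ":xnack..."/":sramecc..." feature statuses, read what remains as hex.
static int parse_gfx_id(std::string dev_name) {
    constexpr int OFFSET_AMD = 0x1000000;
    if (dev_name.rfind("gfx", 0) == 0) {
        dev_name.erase(0, 3);
    }
    dev_name = dev_name.substr(0, dev_name.find(':'));  // trim feature statuses
    return OFFSET_AMD + (int) std::stoul(dev_name, nullptr, 16);
}

int main() {
    std::cout << std::hex
              << parse_gfx_id("gfx906:sramecc+:xnack-") - 0x1000000 << "\n"  // 906
              << parse_gfx_id("gfx1030") - 0x1000000 << "\n";                // 1030
}
```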
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif
 template <bool vals_smem, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
         const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@@ -118,6 +124,9 @@ static __global__ void soft_max_f32(
             dst[col] = vals[col] * inv_sum;
         }
     }
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif

 static __global__ void soft_max_back_f32(
         const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index ed4d8bb8b..2984ed82e 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }

-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
-}
-
 static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            ggml_sycl_soft_max(ctx, dst);
+            ggml_sycl_op_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp
index a9b3fce0d..563e0655f 100644
--- a/ggml/src/ggml-sycl/softmax.cpp
+++ b/ggml/src/ggml-sycl/softmax.cpp
@@ -1,7 +1,7 @@
-#include "norm.hpp"
+#include "softmax.hpp"

-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         slope = sycl::pow(base, float(exp));
     }

-    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
     float max_val = -INFINITY;

     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;

-        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);

         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
             item_ct1.barrier(sycl::access::fence_space::local_space);
             max_val = buf[lane_id];
             for (size_t i = 1; i < nreduce; i += 1) {
-                max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+                max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
             }
             max_val = warp_reduce_max(max_val, item_ct1);
         }
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
     }
 }

-template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums,
                                    sycl::range<3> block_dims, const size_t n_local_scratch, queue_ptr stream) {
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
         });
 }

-static void soft_max_f32_sycl(const float * x, const float * mask,
+template <typename T>
+static void soft_max_f32_sycl(const float * x, const T * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               queue_ptr stream, int device) {
@@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
     }
 }

-void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream) {
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
-#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional

-    const int64_t ne00    = src0->ne[0];
-    const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src0->ne[1];
+    const int64_t ne00    = dst->src[0]->ne[0];
+    const int64_t nrows_x = ggml_nrows(dst->src[0]);
+    const int64_t nrows_y = dst->src[0]->ne[1];

     float scale = 1.0f;
     float max_bias = 0.0f;
@@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));

-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
-                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+    float *       dst_dd  = static_cast<float *>(dst->data);
+
+    ggml_sycl_set_device(ctx.device);
+    dpct::queue_ptr main_stream = ctx.stream();
+
+    if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
+        const sycl::half * src1_dd = static_cast<const sycl::half *>(dst->src[1]->data);
+        soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
+                                      main_stream, ctx.device);
+    } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
+        const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
+        soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    } else {
+        /* mask unavailable */
+        soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+    }
 }
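The reworked SYCL entry point above now receives only `dst`, pulls its operands from `dst->src[...]`, and dispatches on the mask's dtype at runtime before instantiating the matching `soft_max_f32_sycl<T>`. A reduced sketch of that runtime-value-to-compile-time-template dispatch pattern (plain C++ with stand-in types, not SYCL code):

```cpp
#include <cstdio>

enum class dtype { f16, f32 };

struct half16 { unsigned short bits; };  // stand-in for sycl::half

template <typename T>
void soft_max_impl(const float * x, const T * mask, int n) {
    // ... kernel launch would go here; T fixes how mask values are loaded ...
    std::printf("n=%d mask elem size: %zu\n", n, mask ? sizeof(T) : (size_t) 0);
    (void) x;
}

// One type-erased entry point, branching once on the runtime dtype tag.
void soft_max(const float * x, const void * mask_data, dtype mask_type, int n) {
    if (mask_data && mask_type == dtype::f16) {
        soft_max_impl(x, static_cast<const half16 *>(mask_data), n);
    } else if (mask_data) {
        soft_max_impl(x, static_cast<const float *>(mask_data), n);
    } else {
        soft_max_impl<float>(x, nullptr, n);  // explicit arg: nullptr deduces nothing
    }
}
```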
diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp
index bdb8f712e..2cf8582ec 100644
--- a/ggml/src/ggml-sycl/softmax.hpp
+++ b/ggml/src/ggml-sycl/softmax.hpp
@@ -15,10 +15,6 @@

 #include "common.hpp"

-void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
-                           const ggml_tensor *src1, ggml_tensor *dst,
-                           const float *src0_dd, const float *src1_dd,
-                           float *dst_dd,
-                           const queue_ptr &main_stream);
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);

 #endif // GGML_SYCL_SOFTMAX_HPP
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 92c4294c5..3b4861542 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -128,6 +128,10 @@ static void ggml_print_backtrace_symbols(void) {
 #endif

 static void ggml_print_backtrace(void) {
+    const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
+    if (GGML_NO_BACKTRACE) {
+        return;
+    }
     char attach[32];
     snprintf(attach, sizeof(attach), "attach %d", getpid());
     int pid = fork();
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index cfba59d32..ddb9d817e 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-d92321c0d151fe73a47d89738c7c3091ac904297
+32f0b85987396945afea2291d5f4c5862434292b
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 75073bf61..05d58ad90 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -819,7 +819,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
     for (const auto & file : files) {
         auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
         auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
-        std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn()));
+        std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
         mmaps_used.emplace_back(mapping->size(), 0);
         if (mlock_mmaps) {
             std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 0782d3a41..561f8bdb8 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1245,8 +1245,13 @@ struct llama_vocab::impl {
     std::vector<llama_token> cache_special_tokens;
     std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
-
-    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
+    struct pair_hash {
+        size_t operator()(const std::pair<std::string, std::string> & p) const {
+            return std::hash<std::string>{}(p.first) ^  //create some hash for pair
+                   (std::hash<std::string>{}(p.second) << 1);
+        }
+    };
+    std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;

     // set of all tokens that cause "end of generation"
     std::set<llama_token> special_eog_ids;
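On the `llama-vocab.cpp` change above: `std::unordered_map` has no default `std::hash` for `std::pair`, which is why the patch must supply `pair_hash`; the XOR-plus-shift combiner is cheap, and the `<< 1` keeps symmetric pairs like `("a","b")`/`("b","a")` from always colliding the way a plain commutative XOR would. A self-contained version of the same idea:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// std::unordered_map has no std::hash specialization for std::pair,
// so a custom combiner is required (same shape as the patch's pair_hash).
struct pair_hash {
    size_t operator()(const std::pair<std::string, std::string> & p) const {
        return std::hash<std::string>{}(p.first) ^ (std::hash<std::string>{}(p.second) << 1);
    }
};

int main() {
    std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;
    bpe_ranks[{"h", "e"}]  = 0;  // O(1) average lookup vs O(log n) for std::map
    bpe_ranks[{"he", "l"}] = 1;
    std::cout << bpe_ranks.at({"h", "e"}) << "\n";
}
```

The switch from `std::map` trades ordered iteration, which `bpe_ranks` never relied on, for O(1) average lookups on the hot BPE merge path.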
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 468016403..4c5c4dd9c 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2347,11 +2347,12 @@ struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const bool mask;
+    const ggml_type m_prec;
     const float scale;
     const float max_bias;

     std::string vars() override {
-        return VARS_TO_STR5(type, ne, mask, scale, max_bias);
+        return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
     }

     // the 1024 test with bias occasionally fails:
@@ -2363,9 +2364,10 @@ struct test_soft_max : public test_case {
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
+            ggml_type m_prec = GGML_TYPE_F32,
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2374,7 +2376,7 @@ struct test_soft_max : public test_case {

         ggml_tensor * mask = nullptr;
         if (this->mask) {
-            mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
+            mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
             ggml_set_name(mask, "mask");
         }

@@ -4150,17 +4152,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, scale, max_bias));
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
+                        if (mask) {
+                            for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+                            }
+                        } else {
+                            /* The precision of mask here doesn't matter as boolean mask is false */
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                        }
                     }
                 }
            }
        }
    }

-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F16, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, 0.1f, 8.0f));

     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
@@ -4296,13 +4309,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));

-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));

     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));