diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 92a49eb74..8a0db74b6 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -319,6 +319,10 @@ class HttpClient { public: int init(const std::string & url, const std::vector & headers, const std::string & output_file, const bool progress, std::string * response_str = nullptr) { + if (std::filesystem::exists(output_file)) { + return 0; + } + std::string output_file_partial; curl = curl_easy_init(); if (!curl) { @@ -558,13 +562,14 @@ class LlamaData { } sampler = initialize_sampler(opt); + return 0; } private: #ifdef LLAMA_USE_CURL - int download(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { + int download(const std::string & url, const std::string & output_file, const bool progress, + const std::vector & headers = {}, std::string * response_str = nullptr) { HttpClient http; if (http.init(url, headers, output_file, progress, response_str)) { return 1; @@ -573,48 +578,85 @@ class LlamaData { return 0; } #else - int download(const std::string &, const std::vector &, const std::string &, const bool, + int download(const std::string &, const std::string &, const bool, const std::vector & = {}, std::string * = nullptr) { printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); + return 1; } #endif - int huggingface_dl(const std::string & model, const std::vector headers, const std::string & bn) { - // Find the second occurrence of '/' after protocol string - size_t pos = model.find('/'); - pos = model.find('/', pos + 1); - if (pos == std::string::npos) { - return 1; - } - - const std::string hfr = model.substr(0, pos); - const std::string hff = model.substr(pos + 1); - const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff; - return download(url, headers, bn, true); - } - - int ollama_dl(std::string & model, const std::vector headers, const std::string & bn) { - if (model.find('/') == std::string::npos) { - model = "library/" + model; - } - - std::string model_tag = "latest"; - size_t colon_pos = model.find(':'); + // Helper function to handle model tag extraction and URL construction + std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { + std::string model_tag = "latest"; + const size_t colon_pos = model.find(':'); if (colon_pos != std::string::npos) { model_tag = model.substr(colon_pos + 1); model = model.substr(0, colon_pos); } - std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag; + std::string url = base_url + model + "/manifests/" + model_tag; + + return { model, url }; + } + + // Helper function to download and parse the manifest + int download_and_parse_manifest(const std::string & url, const std::vector & headers, + nlohmann::json & manifest) { std::string manifest_str; - const int ret = download(manifest_url, headers, "", false, &manifest_str); + int ret = download(url, "", false, headers, &manifest_str); if (ret) { return ret; } - nlohmann::json manifest = nlohmann::json::parse(manifest_str); - std::string layer; + manifest = nlohmann::json::parse(manifest_str); + + return 0; + } + + int huggingface_dl(std::string & model, const std::string & bn) { + // Find the second occurrence of '/' after protocol string + size_t pos = model.find('/'); + pos = model.find('/', pos + 1); + std::string hfr, hff; + std::vector headers = { "User-Agent: llama-cpp", "Accept: application/json" }; + std::string url; + + if (pos == std::string::npos) { + auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/"); + hfr = model_name; + + nlohmann::json manifest; + int ret = download_and_parse_manifest(manifest_url, headers, manifest); + if (ret) { + return ret; + } + + hff = manifest["ggufFile"]["rfilename"]; + } else { + hfr = model.substr(0, pos); + hff = model.substr(pos + 1); + } + + url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff; + + return download(url, bn, true, headers); + } + + int ollama_dl(std::string & model, const std::string & bn) { + const std::vector headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" }; + if (model.find('/') == std::string::npos) { + model = "library/" + model; + } + + auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/"); + nlohmann::json manifest; + int ret = download_and_parse_manifest(manifest_url, {}, manifest); + if (ret) { + return ret; + } + + std::string layer; for (const auto & l : manifest["layers"]) { if (l["mediaType"] == "application/vnd.ollama.image.model") { layer = l["digest"]; @@ -622,8 +664,9 @@ class LlamaData { } } - std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer; - return download(blob_url, headers, bn, true); + std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer; + + return download(blob_url, bn, true, headers); } std::string basename(const std::string & path) { @@ -653,22 +696,18 @@ class LlamaData { return ret; } - const std::string bn = basename(model_); - const std::vector headers = { "--header", - "Accept: application/vnd.docker.distribution.manifest.v2+json" }; + const std::string bn = basename(model_); if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) { rm_until_substring(model_, "://"); - ret = huggingface_dl(model_, headers, bn); + ret = huggingface_dl(model_, bn); } else if (string_starts_with(model_, "hf.co/")) { rm_until_substring(model_, "hf.co/"); - ret = huggingface_dl(model_, headers, bn); - } else if (string_starts_with(model_, "ollama://")) { - rm_until_substring(model_, "://"); - ret = ollama_dl(model_, headers, bn); + ret = huggingface_dl(model_, bn); } else if (string_starts_with(model_, "https://")) { - ret = download(model_, headers, bn, true); - } else { - ret = ollama_dl(model_, headers, bn); + ret = download(model_, bn, true); + } else { // ollama:// or nothing + rm_until_substring(model_, "://"); + ret = ollama_dl(model_, bn); } model_ = bn; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index bb6120568..a66322da0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -46,20 +46,20 @@ #define GGML_CUDA_CC_VOLTA 700 #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 -#define GGML_CUDA_CC_OFFSET_AMD 1000000 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 // GCN/CNDA, wave size is 64 -#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16 -#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue -#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a -#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing -#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 -#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000 -#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a -#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA #define GGML_CUDA_CC_QY1 210 #define GGML_CUDA_CC_QY2 220 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a0d6a5496..de3f9c2ca 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -120,6 +120,55 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) #endif } +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +static int ggml_cuda_parse_id(char devName[]) { + // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp + // these values are not stable so this is susceptible to breakage + // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp + int archMajor = 0x0; + int archMinor = 0x0; + int archNum = GGML_CUDA_CC_OFFSET_AMD; + int archLen = strlen(devName); + char archName[archLen + 1]; + + // strip leading 'gfx' while copying into our buffer + if (archLen > 3) { + strcpy(archName, &devName[3]); + archLen -= 3; + } + + // trim trailing :xnack- or :sramecc- statuses + archLen = strcspn(archName, ":"); + archName[archLen] = '\0'; + + // tease out the version information + if (archLen > 8) { + // versions labeled generic use '-' as delimiter + // strip the trailing "-generic" then iterate through what remains + if ((strstr(archName, "-generic"))) { + archName[archLen - 8] = '\0'; + char * pch; + if ((pch = strtok(archName, "-"))) { + archMajor = (int)strtoul(pch, 0, 16); + if ((pch = strtok(NULL, "-"))) { + archMinor = 0x10 * (int)strtoul(pch, 0, 16); + } + } + } + } else if (archLen >= 3) { + // last two digits should be the minor * 0x10 + stepping + archMinor = (int)strtoul(&archName[archLen - 2], 0, 16); + archName[archLen - 2] = '\0'; + + // only the major version remains + archMajor = (int)strtoul(archName, 0, 16); + } + archNum += archMajor * 0x100; + archNum += archMinor; + return archNum; +} +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: @@ -187,7 +236,6 @@ static ggml_cuda_device_info ggml_cuda_init() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; @@ -196,10 +244,25 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].smpb = prop.sharedMemPerBlock; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; - info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD; + + info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName); + if ((info.devices[id].cc & 0xff00) == 0x0) { + GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n", + id, prop.name, prop.gcnArchName, prop.major, prop.minor); + + // Fallback to prop.major and prop.minor + if (prop.major > 0) { + info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100; + info.devices[id].cc += prop.minor * 0x10; + } + } + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n", + id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no"); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c9474345d..76f8e4291 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -64,7 +64,9 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev if (ctx->mtl_device == nil) { ctx->mtl_device = MTLCreateSystemDefaultDevice(); + } + if (ctx->mtl_device) { ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -99,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_device_ref_count--; if (ctx->mtl_device_ref_count == 0) { - [ctx->mtl_device release]; - ctx->mtl_device = nil; + if (ctx->mtl_device) { + [ctx->mtl_device release]; + ctx->mtl_device = nil; + } } } diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 75073bf61..05d58ad90 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -819,7 +819,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps for (const auto & file : files) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn())); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); mmaps_used.emplace_back(mapping->size(), 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 0782d3a41..561f8bdb8 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1245,8 +1245,13 @@ struct llama_vocab::impl { std::vector cache_special_tokens; std::vector cache_token_to_piece; // llama_token_to_piece(special = true); - - std::map, int> bpe_ranks; + struct pair_hash { + size_t operator()(const std::pair & p) const { + return std::hash{}(p.first) ^ //create some hash for pair + (std::hash{}(p.second) << 1); + } + }; + std::unordered_map, int, pair_hash> bpe_ranks; // set of all tokens that cause "end of generation" std::set special_eog_ids; diff --git a/src/llama.cpp b/src/llama.cpp index 094157ccf..12e8f41fc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8432,13 +8432,141 @@ static enum ggml_status llama_graph_compute( return status; } +static int llama_prepare_sbatch( + llama_context & lctx, + const llama_batch & batch, + uint32_t & n_outputs) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const uint32_t n_tokens_all = batch.n_tokens; + const int64_t n_embd = hparams.n_embd; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + if (batch.token) { + for (uint32_t i = 0; i < n_tokens_all; ++i) { + if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + return -1; + } + } + } + GGML_ASSERT(n_tokens_all <= cparams.n_batch); + GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); + + lctx.n_queued_tokens += n_tokens_all; + lctx.embd_seq.clear(); + + // count outputs + if (batch.logits && !embd_pooled) { + for (uint32_t i = 0; i < n_tokens_all; ++i) { + n_outputs += batch.logits[i] != 0; + } + } else if (lctx.logits_all || embd_pooled) { + n_outputs = n_tokens_all; + } else { + // keep last output only + n_outputs = 1; + } + + lctx.sbatch.from_batch(batch, n_embd, + /* simple_split */ !lctx.kv_self.recurrent, + /* logits_all */ n_outputs == n_tokens_all); + + // reserve output buffer + if (llama_output_reserve(lctx, n_outputs) < n_outputs) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); + return -2; + }; + + return 0; +} + +static int llama_prepare_ubatch( + llama_context & lctx, + llama_kv_slot_restorer & kv_slot_restorer, + llama_ubatch & ubatch, + const uint32_t n_outputs, + const uint32_t n_tokens_all) { + GGML_ASSERT(lctx.sbatch.n_tokens > 0); + + auto & kv_self = lctx.kv_self; + const auto & cparams = lctx.cparams; + const auto & hparams = lctx.model.hparams; + + // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens + const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; + + if (lctx.kv_self.recurrent) { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = lctx.sbatch.split_seq(cparams.n_ubatch); + } else { + // recurrent model architectures are easier to implement + // with equal-length sequences + ubatch = lctx.sbatch.split_equal(cparams.n_ubatch); + } + } else { + ubatch = lctx.sbatch.split_simple(cparams.n_ubatch); + } + + // count the outputs in this u_batch + { + int32_t n_outputs_new = 0; + + if (n_outputs == n_tokens_all) { + n_outputs_new = ubatch.n_tokens; + } else { + GGML_ASSERT(ubatch.output); + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + n_outputs_new += int32_t(ubatch.output[i] != 0); + } + } + + // needs to happen before the graph is built + lctx.n_outputs = n_outputs_new; + } + + // non-causal masks do not use the KV cache + if (hparams.causal_attn) { + llama_kv_cache_update(&lctx); + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) { + kv_self.head = 0; + } + + const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); + if (!slot) { + return 1; + } + kv_slot_restorer.save(slot); + + if (!kv_self.recurrent) { + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + const uint32_t pad = llama_kv_cache_get_padding(cparams); + kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); + } + } + + return 0; +} + // decode a batch of tokens by evaluating the transformer // in case of unsuccessful decoding (error or warning), // the kv_cache state will be returned to its original state // (for non-recurrent models) or cleaned (for recurrent models) // // - lctx: llama context -// - batch: batch to evaluate +// - inp_batch: batch to evaluate // // return 0 on success // return positive int on warning @@ -8455,37 +8583,18 @@ static int llama_decode_impl( return -1; } - // temporary allocate memory for the input batch if needed + // temporarily allocate memory for the input batch if needed llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); - const llama_batch & batch = batch_allocr.batch; - const uint32_t n_tokens_all = batch.n_tokens; const auto & model = lctx.model; const auto & vocab = model.vocab; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT - - if (batch.token) { - for (uint32_t i = 0; i < n_tokens_all; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { - LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); - return -1; - } - } - } - - GGML_ASSERT(n_tokens_all <= cparams.n_batch); - - GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); - if (lctx.t_compute_start_us == 0) { lctx.t_compute_start_us = ggml_time_us(); } - lctx.n_queued_tokens += n_tokens_all; - auto & kv_self = lctx.kv_self; llama_kv_slot_restorer kv_slot_restorer(kv_self); @@ -8495,99 +8604,27 @@ static int llama_decode_impl( uint32_t n_outputs = 0; uint32_t n_outputs_prev = 0; - const auto n_ubatch = cparams.n_ubatch; - - // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens - const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; - - lctx.embd_seq.clear(); - - // count outputs - if (batch.logits && !embd_pooled) { - for (uint32_t i = 0; i < n_tokens_all; ++i) { - n_outputs += batch.logits[i] != 0; + { + const int ret = llama_prepare_sbatch(lctx, batch, n_outputs); + if (ret != 0) { + return ret; } - } else if (lctx.logits_all || embd_pooled) { - n_outputs = n_tokens_all; - } else { - // keep last output only - n_outputs = 1; } - lctx.sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self.recurrent, - /* logits_all */ n_outputs == n_tokens_all); - - // reserve output buffer - if (llama_output_reserve(lctx, n_outputs) < n_outputs) { - LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); - return -2; - }; - while (lctx.sbatch.n_tokens > 0) { llama_ubatch ubatch; - if (kv_self.recurrent) { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - ubatch = lctx.sbatch.split_seq(n_ubatch); - } else { - // recurrent model architectures are easier to implement - // with equal-length sequences - ubatch = lctx.sbatch.split_equal(n_ubatch); - } - } else { - ubatch = lctx.sbatch.split_simple(n_ubatch); - } - const uint32_t n_tokens = ubatch.n_tokens; - - // count the outputs in this u_batch { - int32_t n_outputs_new = 0; - - if (n_outputs == n_tokens_all) { - n_outputs_new = n_tokens; - } else { - GGML_ASSERT(ubatch.output); - for (uint32_t i = 0; i < n_tokens; i++) { - n_outputs_new += (int32_t) (ubatch.output[i] != 0); - } + const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens); + if (ret != 0) { + return ret; } - - // needs to happen before the graph is built - lctx.n_outputs = n_outputs_new; } - int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; - ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch; + const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch; GGML_ASSERT(n_threads > 0); - // non-causal masks do not use the KV cache - if (hparams.causal_attn) { - llama_kv_cache_update(&lctx); - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } - - const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); - if (!slot) { - return 1; - } - kv_slot_restorer.save(slot); - - if (!kv_self.recurrent) { - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = llama_kv_cache_get_padding(cparams); - kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad))); - //kv_self.n = llama_kv_cache_cell_max(kv_self); - } - } - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); ggml_backend_sched_reset(lctx.sched.get()); @@ -8640,7 +8677,7 @@ static int llama_decode_impl( // update the kv ring buffer { - kv_self.head += n_tokens; + kv_self.head += ubatch.n_tokens; // Ensure kv cache head points to a valid index. if (kv_self.head >= kv_self.size) {