Merge branch 'master' into gg/llama-kv-cache

ggml-ci
Georgi Gerganov 2025-01-27 14:00:56 +02:00
commit e665b57fa2
6 changed files with 106 additions and 96 deletions

.devops/vulkan.Dockerfile

@@ -1,4 +1,4 @@
-ARG UBUNTU_VERSION=24.04
+ARG UBUNTU_VERSION=22.04

 FROM ubuntu:$UBUNTU_VERSION AS build
@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget

 # Install Vulkan SDK and cURL
 RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
-    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
+    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
     apt update -y && \
     apt-get install -y vulkan-sdk libcurl4-openssl-dev curl

.github/workflows/docker.yml

@@ -32,10 +32,12 @@ jobs:
     env:
       COMMIT_SHA: ${{ github.sha }}
     strategy:
+      fail-fast: false
       matrix:
         config:
           # Multi-stage build
-          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: false}
+          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
+          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}

ggml/src/ggml-metal/ggml-metal.m

@@ -64,7 +64,9 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
     }

+    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -99,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;

     if (ctx->mtl_device_ref_count == 0) {
-        [ctx->mtl_device release];
-        ctx->mtl_device = nil;
+        if (ctx->mtl_device) {
+            [ctx->mtl_device release];
+            ctx->mtl_device = nil;
+        }
     }
 }
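Both Metal hunks apply the same rule: a lazily created device handle may legitimately be null (e.g. no Metal GPU is available), so every probe and every release of it must be guarded. A minimal C++ sketch of that acquire/release discipline, using hypothetical stand-in names (Device, create_default_device) rather than the actual Metal API:

    #include <cstddef>

    // Hypothetical stand-ins for MTLDevice and MTLCreateSystemDefaultDevice().
    struct Device { /* ... */ };
    static Device * create_default_device() { return nullptr; } // stubbed: pretend no GPU is present

    struct device_context {
        Device * device    = nullptr;
        int      ref_count = 0;
    };

    static Device * acquire_device(device_context * ctx) {
        ctx->ref_count++;
        if (ctx->device == nullptr) {
            ctx->device = create_default_device();
        }
        // may still be null - callers (like the feature probes above) must check
        return ctx->device;
    }

    static void release_device(device_context * ctx) {
        ctx->ref_count--;
        if (ctx->ref_count == 0 && ctx->device != nullptr) { // never release a null handle
            delete ctx->device;
            ctx->device = nullptr;
        }
    }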

src/llama-context.cpp

@@ -7,6 +7,7 @@
 #include <cmath>
 #include <cstring>
 #include <stdexcept>
+#include <cinttypes>

 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
@@ -336,12 +337,55 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
 }

 struct llama_batch_manager : public llama_batch_manager_i {
-    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+    llama_batch_manager(llama_context & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+        const auto & model   = lctx.model;
         const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
-        const auto & n_embd  = hparams.n_embd;
         const auto & kv_self = lctx.kv_self;

+        const int64_t n_tokens_all = batch.n_tokens;
+        const int64_t n_embd       = hparams.n_embd;
+
+        GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+        if (batch.token) {
+            for (int64_t i = 0; i < n_tokens_all; ++i) {
+                if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                    LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                    throw std::runtime_error("invalid token");
+                }
+            }
+        }
+
+        GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+        GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+        if (lctx.t_compute_start_us == 0) {
+            lctx.t_compute_start_us = ggml_time_us();
+        }
+        lctx.n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        lctx.embd_seq.clear();
+
+        // count outputs
+        if (batch.logits && !embd_pooled) {
+            for (uint32_t i = 0; i < n_tokens_all; ++i) {
+                n_outputs_all += batch.logits[i] != 0;
+            }
+        } else if (lctx.logits_all || embd_pooled) {
+            n_outputs_all = n_tokens_all;
+        } else {
+            // keep last output only
+            n_outputs_all = 1;
+        }
+
+        const bool logits_all = n_outputs_all == n_tokens_all;
+
         lctx.sbatch.from_batch(batch, n_embd,
             /* simple_split */ !kv_self.recurrent,
             /* logits_all */   logits_all);
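The counting rule that moved into this constructor is worth reading on its own: pooled embeddings ignore the per-token logits flags and output every token; otherwise the set flags are summed; with no flags at all, only the last token produces output. A self-contained C++ sketch of the same rule (a hypothetical helper, not part of the library):

    #include <cstdint>

    // Mirrors the constructor's "count outputs" block above.
    // `flags` may be null: the batch carried no per-token output flags.
    int64_t count_outputs(const int8_t * flags, int64_t n_tokens, bool logits_all, bool embd_pooled) {
        if (flags && !embd_pooled) {
            int64_t n = 0;
            for (int64_t i = 0; i < n_tokens; ++i) {
                n += flags[i] != 0;
            }
            return n;
        }
        if (logits_all || embd_pooled) {
            return n_tokens; // pooled embeddings override the flags: every token is an output
        }
        return 1; // default: keep only the last token's output
    }

For a batch of 8 tokens with only the last flag set this yields 1, so logits_all (defined as n_outputs_all == n_tokens_all) stays false and sbatch.from_batch is told not to keep logits for every token.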
@@ -379,9 +423,29 @@ struct llama_batch_manager : public llama_batch_manager_i {
     virtual bool prepare() override {
         const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
+        const auto & batch   = lctx.sbatch.batch;
+
+        const auto n_tokens_all = batch->n_tokens;

         auto & kv_self = lctx.kv_self;

+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            lctx.n_outputs = n_outputs_new;
+        }
+
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             lctx.kv_self_update();
@@ -459,8 +523,8 @@ struct llama_batch_manager : public llama_batch_manager_i {
     llama_kv_slot_restorer kv_slot_restorer;
 };

-std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
-    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
+std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch) {
+    return std::make_unique<llama_batch_manager>(*this, batch);
 }

 enum ggml_status llama_context::compute_graph(

src/llama-context.h

@@ -28,6 +28,9 @@ struct llama_batch_manager_i {
     virtual void restore() = 0;
     virtual void update() = 0;
     virtual void finalize() = 0;
+
+    // TODO: might be temporary
+    int64_t n_outputs_all = 0;
 };

 // TODO: make implementation details private
@@ -98,7 +101,7 @@ struct llama_context {
     void * abort_callback_data = nullptr;

     // TODO: do not pass logits_all explicitly
-    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);
+    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch);

     // returns the result of ggml_backend_sched_graph_compute_async execution
     enum ggml_status compute_graph(
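Together with the .cpp hunks above, the interface implies a fixed call sequence that llama_decode_impl drives. A condensed C++ sketch of that sequence (hypothetical, error paths trimmed; it assumes the declarations shown in this diff):

    int decode_with_manager(llama_context & lctx, const llama_batch & batch) {
        auto bman = lctx.prepare_batch(batch); // validates the batch, computes n_outputs_all

        // reserve the output buffer once for the whole batch
        if (llama_output_reserve(lctx, bman->n_outputs_all) < (size_t) bman->n_outputs_all) {
            return -2;
        }

        while (lctx.sbatch.n_tokens > 0) {
            llama_ubatch ubatch = bman->next(); // split off the next micro-batch

            if (!bman->prepare()) {             // find KV slots, set lctx.n_outputs
                bman->restore();                // roll back partial KV-cache changes
                return -3;
            }

            // ... build and compute the graph; on failure call bman->restore() ...

            bman->update();                     // commit KV-cache bookkeeping for this ubatch
        }

        bman->finalize();                       // per-batch cleanup
        return 0;
    }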

src/llama.cpp

@@ -23,6 +23,7 @@
 #include <cstring>
 #include <ctime>
 #include <functional>
+#include <cinttypes>

 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -7751,7 +7752,7 @@ static struct ggml_cgraph * llama_build_graph(
 //   (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:      llama context
-//   - batch:     batch to evaluate
+//   - inp_batch: batch to evaluate
 //
 // return 0 on success
 // return positive int on warning
@@ -7774,98 +7775,34 @@ static int llama_decode_impl(
     const llama_batch & batch = batch_allocr.batch;
     const uint32_t n_tokens_all = batch.n_tokens;

     const auto & model   = lctx.model;
     const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
-    const auto & hparams = lctx.model.hparams;

-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-    lctx.n_queued_tokens += n_tokens_all;
-
-    const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();

-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
+    // TODO: try catch
+    auto bman = lctx.prepare_batch(batch);

-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
-        }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
-    }
+    const auto n_outputs_all = bman->n_outputs_all;

     // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+    // TODO: move to batch manager?
+    if (llama_output_reserve(lctx, bman->n_outputs_all) < (size_t) n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };

-    const bool logits_all = n_outputs == n_tokens_all;
-
-    //auto & kv_self = lctx.kv_self;
-    //llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    //lctx.sbatch.from_batch(batch, n_embd,
-    //    /* simple_split */ !kv_self.recurrent,
-    //    /* logits_all */   logits_all);
-
-    auto batch_manager = lctx.prepare_batch(batch, logits_all);
+    int64_t n_outputs_prev = 0;

     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = batch_manager->next();
-
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
-        }
+        llama_ubatch ubatch = bman->next();

-        if (!batch_manager->prepare()) {
+        if (!bman->prepare()) {
             LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-            batch_manager->restore();
+            bman->restore();
             return -3;
         }
@@ -7927,9 +7864,9 @@ static int llama_decode_impl(
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }

-        const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
+        const auto compute_status = lctx.compute_graph(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            batch_manager->restore();
+            bman->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7941,7 +7878,7 @@ static int llama_decode_impl(
             }
         }

-        batch_manager->update();
+        bman->update();

         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
@@ -7958,7 +7895,7 @@ static int llama_decode_impl(
                 const int32_t n_outputs_new = lctx.n_outputs;

                 if (n_outputs_new) {
-                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                     GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
                     ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
                 }
@@ -7978,7 +7915,7 @@ static int llama_decode_impl(
                 const int32_t n_outputs_new = lctx.n_outputs;

                 if (n_outputs_new) {
-                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                     GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
                     ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                 }
@@ -8027,9 +7964,9 @@ static int llama_decode_impl(
     {
         bool sorted_output = true;

-        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
+        GGML_ASSERT(lctx.sbatch.out_ids.size() == (size_t) n_outputs_all);

-        for (size_t i = 0; i < n_outputs; ++i) {
+        for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
             size_t out_id = lctx.sbatch.out_ids[i];
             lctx.output_ids[out_id] = i;
             if (out_id != i) {
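The loop above inverts sbatch.out_ids (for each computed output, the batch position it belongs to) into lctx.output_ids (for each batch position, the slot where its output landed), while checking whether any permutation happened at all. The same inversion as a standalone C++ sketch (hypothetical helper; output_ids must be pre-sized to the batch length):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // out_ids[i] = batch position of the i-th computed output.
    // After the call, output_ids[pos] = slot that holds the output for batch position pos.
    // Returns true if the outputs were already in batch order.
    bool build_output_index(const std::vector<size_t> & out_ids, std::vector<int32_t> & output_ids) {
        bool sorted = true;
        for (size_t i = 0; i < out_ids.size(); ++i) {
            const size_t out_id = out_ids[i];
            output_ids[out_id] = (int32_t) i;
            if (out_id != i) {
                sorted = false;
            }
        }
        return sorted;
    }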
@@ -8043,12 +7980,12 @@ static int llama_decode_impl(
     }

     // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
+    lctx.n_outputs = n_outputs_all;

     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);

-    batch_manager->finalize();
+    bman->finalize();

     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.