Merge branch 'master' into gg/llama-kv-cache
ggml-ci
commit e665b57fa2
6 changed files with 106 additions and 96 deletions
.devops/vulkan.Dockerfile

@@ -1,4 +1,4 @@
-ARG UBUNTU_VERSION=24.04
+ARG UBUNTU_VERSION=22.04
 
 FROM ubuntu:$UBUNTU_VERSION AS build
 
@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget
 
 # Install Vulkan SDK and cURL
 RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
-    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
+    wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
     apt update -y && \
     apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
 
.github/workflows/docker.yml (vendored, 4 lines changed)

@@ -32,10 +32,12 @@ jobs:
     env:
       COMMIT_SHA: ${{ github.sha }}
     strategy:
       fail-fast: false
       matrix:
         config:
           # Multi-stage build
-          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: false}
+          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
+          - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/arm64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
           - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false}
ggml-metal.m

@@ -64,7 +64,9 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
     }
 
+    if (ctx->mtl_device) {
     ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
     ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
@@ -99,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
-        [ctx->mtl_device release];
-        ctx->mtl_device = nil;
+        if (ctx->mtl_device) {
+            [ctx->mtl_device release];
+            ctx->mtl_device = nil;
+        }
     }
 }
 
src/llama-context.cpp

@@ -7,6 +7,7 @@
 #include <cmath>
 #include <cstring>
 #include <stdexcept>
+#include <cinttypes>
 
 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
 
@@ -336,12 +337,55 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
 }
 
 struct llama_batch_manager : public llama_batch_manager_i {
-    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+    llama_batch_manager(llama_context & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+        const auto & model   = lctx.model;
         const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
-        const auto & n_embd  = hparams.n_embd;
 
         const auto & kv_self = lctx.kv_self;
 
+        const int64_t n_tokens_all = batch.n_tokens;
+        const int64_t n_embd       = hparams.n_embd;
+
+        GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+        if (batch.token) {
+            for (int64_t i = 0; i < n_tokens_all; ++i) {
+                if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                    LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                    throw std::runtime_error("invalid token");
+                }
+            }
+        }
+
+        GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+        GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+        if (lctx.t_compute_start_us == 0) {
+            lctx.t_compute_start_us = ggml_time_us();
+        }
+        lctx.n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        lctx.embd_seq.clear();
+
+        // count outputs
+        if (batch.logits && !embd_pooled) {
+            for (uint32_t i = 0; i < n_tokens_all; ++i) {
+                n_outputs_all += batch.logits[i] != 0;
+            }
+        } else if (lctx.logits_all || embd_pooled) {
+            n_outputs_all = n_tokens_all;
+        } else {
+            // keep last output only
+            n_outputs_all = 1;
+        }
+
+        const bool logits_all = n_outputs_all == n_tokens_all;
+
         lctx.sbatch.from_batch(batch, n_embd,
                 /* simple_split */ !kv_self.recurrent,
                 /* logits_all   */ logits_all);
@@ -379,9 +423,29 @@ struct llama_batch_manager : public llama_batch_manager_i {
     virtual bool prepare() override {
         const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
+        const auto & batch   = lctx.sbatch.batch;
+
+        const auto n_tokens_all = batch->n_tokens;
 
         auto & kv_self = lctx.kv_self;
 
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            lctx.n_outputs = n_outputs_new;
+        }
+
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             lctx.kv_self_update();
@@ -459,8 +523,8 @@ struct llama_batch_manager : public llama_batch_manager_i {
     llama_kv_slot_restorer kv_slot_restorer;
 };
 
-std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
-    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
+std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch) {
+    return std::make_unique<llama_batch_manager>(*this, batch);
 }
 
 enum ggml_status llama_context::compute_graph(
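For context, the output-counting rule that the llama_batch_manager constructor above applies (count the tokens flagged in batch.logits, unless pooled embeddings or logits_all force every token, otherwise keep only the last one) can be sketched as a standalone helper. This is an illustrative rewrite, not code from the repository; count_outputs, the logits_mask vector, and the two flags are placeholder names.

    #include <cstdint>
    #include <vector>

    // Illustrative helper mirroring the counting rule in the constructor above;
    // not part of llama.cpp itself.
    static int64_t count_outputs(const std::vector<int8_t> & logits_mask, // per-token output flags, size n_tokens or empty
                                 int64_t n_tokens,
                                 bool    logits_all,    // caller wants logits for every token
                                 bool    embd_pooled) { // pooled embeddings ignore the per-token mask
        if (!logits_mask.empty() && !embd_pooled) {
            int64_t n = 0;
            for (int64_t i = 0; i < n_tokens; ++i) {
                n += logits_mask[i] != 0; // count only the tokens explicitly flagged for output
            }
            return n;
        }
        if (logits_all || embd_pooled) {
            return n_tokens; // every token produces an output row
        }
        return 1; // default: keep only the last token's output
    }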
src/llama-context.h

@@ -28,6 +28,9 @@ struct llama_batch_manager_i {
     virtual void restore() = 0;
     virtual void update() = 0;
     virtual void finalize() = 0;
+
+    // TODO: might be temporary
+    int64_t n_outputs_all = 0;
 };
 
 // TODO: make implementation details private
 
@@ -98,7 +101,7 @@ struct llama_context {
     void * abort_callback_data = nullptr;
 
     // TODO: do not pass logits_all explicitly
-    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch, bool logits_all);
+    std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch);
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
     enum ggml_status compute_graph(
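The interface above splits decoding into prepare/restore/update/finalize steps. Below is a minimal, self-contained sketch of how such an interface is typically driven; the names (ubatch_manager, decode_all) and the bool-returning next() are invented for illustration and are not the llama.cpp API.

    // Illustrative stand-in for a batch-manager style interface; not the real llama_batch_manager_i.
    struct ubatch_manager {
        virtual ~ubatch_manager() = default;
        virtual bool next()     = 0; // split off the next micro-batch; false when exhausted
        virtual bool prepare()  = 0; // find KV-cache slots, count outputs, etc.
        virtual void restore()  = 0; // roll back cache state after a failure
        virtual void update()   = 0; // commit cache bookkeeping after a successful step
        virtual void finalize() = 0; // per-batch cleanup once all micro-batches are done
    };

    // Drive the manager over one logical batch, mirroring the control flow that
    // llama_decode_impl follows in the diff below (error codes are illustrative).
    static int decode_all(ubatch_manager & bman) {
        while (bman.next()) {
            if (!bman.prepare()) {
                bman.restore();
                return -3; // could not prepare this micro-batch
            }
            // ... build and compute the graph here; on failure: bman.restore() and return an error
            bman.update();
        }
        bman.finalize();
        return 0;
    }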
src/llama.cpp (109 lines changed)

@@ -23,6 +23,7 @@
 #include <cstring>
 #include <ctime>
 #include <functional>
+#include <cinttypes>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -7751,7 +7752,7 @@ static struct ggml_cgraph * llama_build_graph(
 // (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:      llama context
-//   - batch:     batch to evaluate
+//   - inp_batch: batch to evaluate
 //
 // return 0 on success
 // return positive int on warning
@@ -7774,98 +7775,34 @@ static int llama_decode_impl(
 
     const llama_batch & batch = batch_allocr.batch;
 
-    const uint32_t n_tokens_all = batch.n_tokens;
-
     const auto & model   = lctx.model;
     const auto & vocab   = model.vocab;
-    const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-    lctx.n_queued_tokens += n_tokens_all;
-
-    const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
 
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
+    // TODO: try catch
+    auto bman = lctx.prepare_batch(batch);
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
-        }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
-    }
+    const auto n_outputs_all = bman->n_outputs_all;
 
     // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+    // TODO: move to batch manager?
+    if (llama_output_reserve(lctx, bman->n_outputs_all) < (size_t) n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };
 
-    const bool logits_all = n_outputs == n_tokens_all;
-
-    //auto & kv_self = lctx.kv_self;
-    //llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    //lctx.sbatch.from_batch(batch, n_embd,
-    //        /* simple_split */ !kv_self.recurrent,
-    //        /* logits_all   */ logits_all);
-
-    auto batch_manager = lctx.prepare_batch(batch, logits_all);
+    int64_t n_outputs_prev = 0;
 
     while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = batch_manager->next();
+        llama_ubatch ubatch = bman->next();
 
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
-        }
-
-        if (!batch_manager->prepare()) {
+        if (!bman->prepare()) {
             LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-            batch_manager->restore();
+            bman->restore();
            return -3;
        }
 
@@ -7927,9 +7864,9 @@ static int llama_decode_impl(
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
 
-        const auto compute_status = lctx.compute_graph(gf, n_tokens > 1);
+        const auto compute_status = lctx.compute_graph(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            batch_manager->restore();
+            bman->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -7941,7 +7878,7 @@ static int llama_decode_impl(
             }
         }
 
-        batch_manager->update();
+        bman->update();
 
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
@@ -7958,7 +7895,7 @@ static int llama_decode_impl(
                 const int32_t n_outputs_new = lctx.n_outputs;
 
                 if (n_outputs_new) {
-                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                     GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
                     ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
                 }
@@ -7978,7 +7915,7 @@ static int llama_decode_impl(
                 const int32_t n_outputs_new = lctx.n_outputs;
 
                 if (n_outputs_new) {
-                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+                    GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all);
                     GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
                     ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                 }
@@ -8027,9 +7964,9 @@ static int llama_decode_impl(
         {
             bool sorted_output = true;
 
-            GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
+            GGML_ASSERT(lctx.sbatch.out_ids.size() == (size_t) n_outputs_all);
 
-            for (size_t i = 0; i < n_outputs; ++i) {
+            for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
                 size_t out_id = lctx.sbatch.out_ids[i];
                 lctx.output_ids[out_id] = i;
                 if (out_id != i) {
@@ -8043,12 +7980,12 @@ static int llama_decode_impl(
         }
 
         // set to total number of outputs in the batch, for use in llama_get_logits_ith
-        lctx.n_outputs = n_outputs;
+        lctx.n_outputs = n_outputs_all;
 
         // wait for the computation to finish (automatically done when obtaining the model output)
         //llama_synchronize(&lctx);
 
-        batch_manager->finalize();
+        bman->finalize();
 
         // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
         // overlap with device computation.
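The hunk at -8027/+7964 above restores batch order for the output rows: sbatch may shuffle tokens, and out_ids records which batch position each output row came from. A small self-contained sketch of that index bookkeeping, using free-standing vectors instead of the context members, is shown below; build_output_ids is an invented name for illustration.

    #include <cstdint>
    #include <vector>

    // Illustrative only: map shuffled output rows back to their position in the
    // original batch, in the spirit of the out_ids/output_ids bookkeeping above.
    static std::vector<int32_t> build_output_ids(const std::vector<size_t> & out_ids, size_t n_batch_tokens) {
        std::vector<int32_t> output_ids(n_batch_tokens, -1); // -1 = this token produced no output
        for (size_t i = 0; i < out_ids.size(); ++i) {
            // row i of the logits/embeddings buffer belongs to batch position out_ids[i]
            output_ids[out_ids[i]] = (int32_t) i;
        }
        return output_ids;
    }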