bman : remove ubatch member

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-02-10 16:50:14 +02:00
parent ef358ee78f
commit d1d8d53008
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@@ -460,9 +460,9 @@ struct llama_batch_manager_i {
 virtual bool is_done() const = 0;
 virtual llama_ubatch next() = 0;
-virtual bool prepare() = 0;
+virtual bool prepare(const llama_ubatch & ubatch) = 0;
 virtual void restore() = 0;
-virtual void update() = 0;
+virtual void update(const llama_ubatch & ubatch) = 0;
 virtual void finalize() = 0;
 // TODO: might be temporary
@@ -532,7 +532,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
 }
 virtual llama_ubatch next() override {
-ubatch = llama_ubatch();
+llama_ubatch ubatch = llama_ubatch();
 const auto & cparams = lctx.cparams;
 const auto & kv_self = lctx.kv_self;
@@ -557,7 +557,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
 return ubatch;
 }
-virtual bool prepare() override {
+virtual bool prepare(const llama_ubatch & ubatch) override {
 const auto & cparams = lctx.cparams;
 const auto & hparams = lctx.model.hparams;
 const auto & batch = lctx.sbatch.batch;
@@ -644,7 +644,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
 kv_slot_restorer.restore(lctx.kv_self);
 }
-virtual void update() override {
+virtual void update(const llama_ubatch & ubatch) override {
 auto & kv_self = lctx.kv_self;
 // update the kv ring buffer
@@ -682,8 +682,6 @@ struct llama_batch_manager : public llama_batch_manager_i {
 const llama_batch & batch;
-llama_ubatch ubatch;
 llama_kv_slot_restorer kv_slot_restorer;
 };
@@ -728,7 +726,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 while (!bman->is_done()) {
 llama_ubatch ubatch = bman->next();
-if (!bman->prepare()) {
+if (!bman->prepare(ubatch)) {
 LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
 bman->restore();
 return -3;
@@ -782,7 +780,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 }
 }
-bman->update();
+bman->update(ubatch);
 // plot the computation graph in dot format (for debugging purposes)
 //if (n_past%100 == 0) {