diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5d21dd5ef2cb3..4387128fedf15 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -460,9 +460,9 @@ struct llama_batch_manager_i { virtual bool is_done() const = 0; virtual llama_ubatch next() = 0; - virtual bool prepare() = 0; + virtual bool prepare(const llama_ubatch & ubatch) = 0; virtual void restore() = 0; - virtual void update() = 0; + virtual void update(const llama_ubatch & ubatch) = 0; virtual void finalize() = 0; // TODO: might be temporary @@ -532,7 +532,7 @@ struct llama_batch_manager : public llama_batch_manager_i { } virtual llama_ubatch next() override { - ubatch = llama_ubatch(); + llama_ubatch ubatch = llama_ubatch(); const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -557,7 +557,7 @@ struct llama_batch_manager : public llama_batch_manager_i { return ubatch; } - virtual bool prepare() override { + virtual bool prepare(const llama_ubatch & ubatch) override { const auto & cparams = lctx.cparams; const auto & hparams = lctx.model.hparams; const auto & batch = lctx.sbatch.batch; @@ -644,7 +644,7 @@ struct llama_batch_manager : public llama_batch_manager_i { kv_slot_restorer.restore(lctx.kv_self); } - virtual void update() override { + virtual void update(const llama_ubatch & ubatch) override { auto & kv_self = lctx.kv_self; // update the kv ring buffer @@ -682,8 +682,6 @@ struct llama_batch_manager : public llama_batch_manager_i { const llama_batch & batch; - llama_ubatch ubatch; - llama_kv_slot_restorer kv_slot_restorer; }; @@ -728,7 +726,7 @@ int llama_context::decode(llama_batch & inp_batch) { while (!bman->is_done()) { llama_ubatch ubatch = bman->next(); - if (!bman->prepare()) { + if (!bman->prepare(ubatch)) { LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); bman->restore(); return -3; @@ -782,7 +780,7 @@ int llama_context::decode(llama_batch & inp_batch) { } } - bman->update(); + bman->update(ubatch); // plot the computation graph in dot format (for debugging purposes) //if (n_past%100 == 0) {