Skip to content

Commit

Permalink
bman : remove ubatch member
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Feb 10, 2025
1 parent ef358ee commit d1d8d53
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,9 @@ struct llama_batch_manager_i {

virtual bool is_done() const = 0;
virtual llama_ubatch next() = 0;
virtual bool prepare() = 0;
virtual bool prepare(const llama_ubatch & ubatch) = 0;
virtual void restore() = 0;
virtual void update() = 0;
virtual void update(const llama_ubatch & ubatch) = 0;
virtual void finalize() = 0;

// TODO: might be temporary
Expand Down Expand Up @@ -532,7 +532,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
}

virtual llama_ubatch next() override {
ubatch = llama_ubatch();
llama_ubatch ubatch = llama_ubatch();

const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
Expand All @@ -557,7 +557,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
return ubatch;
}

virtual bool prepare() override {
virtual bool prepare(const llama_ubatch & ubatch) override {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & batch = lctx.sbatch.batch;
Expand Down Expand Up @@ -644,7 +644,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
kv_slot_restorer.restore(lctx.kv_self);
}

virtual void update() override {
virtual void update(const llama_ubatch & ubatch) override {
auto & kv_self = lctx.kv_self;

// update the kv ring buffer
Expand Down Expand Up @@ -682,8 +682,6 @@ struct llama_batch_manager : public llama_batch_manager_i {

const llama_batch & batch;

llama_ubatch ubatch;

llama_kv_slot_restorer kv_slot_restorer;
};

Expand Down Expand Up @@ -728,7 +726,7 @@ int llama_context::decode(llama_batch & inp_batch) {
while (!bman->is_done()) {
llama_ubatch ubatch = bman->next();

if (!bman->prepare()) {
if (!bman->prepare(ubatch)) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
bman->restore();
return -3;
Expand Down Expand Up @@ -782,7 +780,7 @@ int llama_context::decode(llama_batch & inp_batch) {
}
}

bman->update();
bman->update(ubatch);

// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
Expand Down

0 comments on commit d1d8d53

Please sign in to comment.