squash! llama : std::move llm_bigram_bpe from work_queue

Introduced a MovablePriorityQueue class to allow moving elements
out of the priority queue for llm_bigram_bpe.
This commit is contained in:
Daniel Bevenius 2024-08-18 10:44:48 +02:00
parent 6c6db7bcc5
commit 823948cbb8
Failed to extract signature

View file

@ -321,6 +321,21 @@ private:
// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
// Priority queue that hands out its top element by move instead of by copy.
//
// std::priority_queue exposes the top element only as a const reference, so
// taking it out normally costs a copy (or a const_cast). This adapter uses the
// base class's protected container `c` and comparator `comp` to extract the
// element directly.
template<typename T, typename Container = std::vector<T>, typename Compare = std::less<typename Container::value_type>>
class MovablePriorityQueue : public std::priority_queue<T, Container, Compare> {
public:
    using std::priority_queue<T, Container, Compare>::priority_queue;

    // Removes the top element from the queue and returns it by value.
    // Precondition: the queue is not empty (front()/back() on an empty
    // container is undefined behaviour).
    T pop_move() {
        // Run pop_heap while every element is still intact: it relocates the
        // top element to the back and restores the heap property on the rest.
        // Moving out of front() *before* this call would hand the algorithm
        // (and possibly the comparator) a moved-from element.
        std::pop_heap(this->c.begin(), this->c.end(), this->comp);
        T item = std::move(this->c.back());
        this->c.pop_back();
        return item;
    }

    // The base-class pop() is hidden and unusable on this type; removal goes
    // through pop_move().
    void pop() = delete;
};
struct llm_bigram_bpe { struct llm_bigram_bpe {
struct comparator { struct comparator {
bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
@ -329,7 +344,7 @@ struct llm_bigram_bpe {
}; };
using queue_storage = std::vector<llm_bigram_bpe>; using queue_storage = std::vector<llm_bigram_bpe>;
using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>; using queue = MovablePriorityQueue<llm_bigram_bpe, queue_storage, comparator>;
llm_symbol::index left; llm_symbol::index left;
llm_symbol::index right; llm_symbol::index right;
std::string text; std::string text;
@ -520,8 +535,7 @@ struct llm_tokenizer_bpe {
// build token(s) // build token(s)
while (!work_queue.empty()) { while (!work_queue.empty()) {
auto bigram = std::move(const_cast<llm_bigram_bpe&>(work_queue.top())); auto bigram = work_queue.pop_move();
work_queue.pop();
auto & left_symbol = symbols[bigram.left]; auto & left_symbol = symbols[bigram.left];
auto & right_symbol = symbols[bigram.right]; auto & right_symbol = symbols[bigram.right];