fix: fix alignment

Joan Martinez 2024-05-13 10:27:23 +02:00
parent 0771b175aa
commit 22a0113299
5 changed files with 10 additions and 89 deletions


@@ -8241,9 +8241,6 @@ struct llm_build_context {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
// positions of the tokens in the KV cache
struct ggml_tensor * KQ_pos = build_inp_KQ_pos(false);
// iterate layers
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur = inpL;
@@ -8386,7 +8383,6 @@ struct llm_build_context {
// output layer norm
cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, cb, il);
// input for next layer
inpL = cur;
}
@@ -11506,7 +11502,7 @@ static int llama_decode_internal(
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn || model.arch == LLM_ARCH_JINA_BERT_V2) {
if (hparams.causal_attn) {
llama_kv_cache_update(&lctx);
// if we have enough unused cells before the current head ->
@@ -12350,10 +12346,14 @@ struct llm_tokenizer_bpe {
break;
case LLAMA_VOCAB_PRE_TYPE_JINA_V2_ZH:
//TODO: Apply GPT2 + lowercasing
word_collection = unicode_regex_split(text, {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
});
//TODO: Apply lowercasing
{
std::string lowercase_text = text;
std::transform(lowercase_text.begin(), lowercase_text.end(), lowercase_text.begin(), [](unsigned char c){ return std::tolower(c); });
word_collection = unicode_regex_split(lowercase_text, {
"",
});
}
break;
default:
// default regex for BPE tokenization pre-processing
word_collection = unicode_regex_split(text, {
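
Note: the new LLAMA_VOCAB_PRE_TYPE_JINA_V2_ZH branch above lowercases with std::tolower, which only maps ASCII bytes. A hedged sketch (not part of this commit) of codepoint-level lowercasing, assuming the unicode_map_lowercase table declared elsewhere in this diff covers the characters of interest:

// Illustrative only: lowercase per Unicode codepoint rather than per byte.
// unicode_cpts_from_utf8, unicode_cpt_to_utf8 and unicode_map_lowercase are the
// helpers and table that appear elsewhere in this diff and are assumed in scope;
// how much of Unicode the table covers is an assumption.
static std::string lowercase_utf8(const std::string & text) {
    std::string result;
    for (uint32_t cpt : unicode_cpts_from_utf8(text)) {
        auto it = unicode_map_lowercase.find(cpt);
        result += unicode_cpt_to_utf8(it != unicode_map_lowercase.end() ? (uint32_t) it->second : cpt);
    }
    return result;
}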


@@ -71,7 +71,6 @@ extern "C" {
// pre-tokenization types
enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,


@@ -4,7 +4,6 @@
#include <map>
#include <utility>
#include <vector>
#include <unordered_map>
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
@@ -15,6 +14,4 @@ extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuati
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
extern const std::unordered_map<uint32_t, std::vector<uint32_t>> unicode_decompose_map;
extern const std::unordered_map<uint32_t, uint32_t> unicode_canonical_class;
extern const std::map<char32_t, char32_t> unicode_map_lowercase;


@@ -14,8 +14,6 @@
#include <vector>
#include <locale>
#include <codecvt>
#include <unicode/unistr.h>
#include <unicode/unorm2.h>
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result;
@@ -590,68 +588,6 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
throw std::invalid_argument("invalid codepoint");
}
// Function to recursively decompose a string
std::vector<uint32_t> decompose_cpts(const std::vector<uint32_t> & cpts) {
std::vector<uint32_t> result;
for (const auto& cpt : cpts) {
auto it = unicode_decompose_map.find(cpt);
if (it != unicode_decompose_map.end()) {
for (const auto& decomp: it->second) {
const auto & inner_result = decompose_cpts({decomp});
result.insert(result.end(), inner_result.begin(), inner_result.end());
}
} else {
result.push_back(cpt);
}
}
return result;
}
// Function to sort subsequences based on canonical class
std::vector<uint32_t> sort_by_canonical_class(const std::vector<uint32_t> & cpts) {
std::vector<uint32_t> subsequence;
std::vector<uint32_t> result;
auto compareByCanonicalClass = [&](const uint32_t& a, const uint32_t& b) {
auto cc_a_it = unicode_canonical_class.find(a);
if (cc_a_it != unicode_canonical_class.end()) {
auto cc_b_it = unicode_canonical_class.find(b);
if (cc_b_it != unicode_canonical_class.end()) {
return cc_a_it->second < cc_b_it->second;
}
}
return false;
};
for (const auto& cpt : cpts) {
auto it = unicode_canonical_class.find(cpt);
if (it != unicode_canonical_class.end()) {
if (it->second > 0) {
subsequence.push_back(cpt);
} else {
if (!subsequence.empty()) {
sort(subsequence.begin(), subsequence.end(), compareByCanonicalClass);
for (const auto& codepoint : subsequence) {
result.push_back(codepoint);
}
subsequence.clear();
}
result.push_back(cpt);
}
}
}
if (!subsequence.empty()) {
sort(subsequence.begin(), subsequence.end(), compareByCanonicalClass);
for (const auto& codepoint : subsequence) {
result.push_back(codepoint);
}
}
return result;
}
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
std::vector<uint32_t> result;
result.reserve(cpts.size());
@@ -666,14 +602,6 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
return result;
}
std::vector<uint32_t> unicode_cpts_normalize_nfc(const std::vector<uint32_t> & cpts) {
const auto &decomposed_cpts = decompose_cpts(cpts);
const auto &sorted_sequence = sort_by_canonical_class(decomposed_cpts);
//TODO: Do canonical composition
return sorted_sequence;
}
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
std::vector<uint32_t> result;
size_t offset = 0;
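
Note: the decompose_cpts and sort_by_canonical_class helpers removed above implement the first two steps of NFC normalization: recursive canonical decomposition, then reordering of combining marks by canonical combining class. A minimal self-contained sketch of that ordering step, using a tiny hypothetical class table instead of the project's unicode_canonical_class data:

#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical two-entry combining-class table, for illustration only; the real
// data would come from unicode_canonical_class. Class 0 (not listed) means "starter".
static const std::unordered_map<uint32_t, int> k_ccc = {
    {0x0301, 230}, // COMBINING ACUTE ACCENT
    {0x0323, 220}, // COMBINING DOT BELOW
};

// Stable-sort each run of combining marks (class > 0) by combining class,
// leaving starters where they are.
static std::vector<uint32_t> canonical_order(std::vector<uint32_t> cpts) {
    auto ccc = [](uint32_t cp) { auto it = k_ccc.find(cp); return it == k_ccc.end() ? 0 : it->second; };
    auto run = cpts.begin();
    for (auto it = cpts.begin(); ; ++it) {
        if (it == cpts.end() || ccc(*it) == 0) {
            std::stable_sort(run, it, [&](uint32_t a, uint32_t b) { return ccc(a) < ccc(b); });
            if (it == cpts.end()) break;
            run = it + 1;
        }
    }
    return cpts;
}

// canonical_order({0x0065, 0x0301, 0x0323}) -> {0x0065, 0x0323, 0x0301}:
// the dot below (class 220) is ordered before the acute accent (class 230).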


@@ -17,9 +17,6 @@ std::string unicode_cpt_to_utf8(uint32_t cp);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
std::vector<uint32_t> unicode_cpts_normalize_nfc(const std::vector<uint32_t> & cpts);
std::vector<uint32_t> decompose_cpts(const std::vector<uint32_t> & cpts);
std::vector<uint32_t> sort_by_canonical_class(const std::vector<uint32_t> & cpts);
int unicode_cpt_type(uint32_t cp);
int unicode_cpt_type(const std::string & utf8);
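
Note: with the NFC entry points removed from this header, a hedged usage sketch of the API that remains (names as declared above; the round-trip helper itself is not part of the library):

// Illustrative only: UTF-8 -> codepoints -> NFD -> UTF-8, built from the
// declarations above. The result covers whatever mappings the
// unicode_cpts_normalize_nfd implementation provides.
static std::string normalize_nfd_utf8(const std::string & utf8) {
    std::string out;
    for (uint32_t cpt : unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(utf8))) {
        out += unicode_cpt_to_utf8(cpt);
    }
    return out;
}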