Simplify load logic

This commit is contained in:
Howard Su 2023-06-26 19:57:04 +08:00
parent 76752668de
commit 712c127773

View file

@ -372,12 +372,6 @@ struct llama_load_tensor {
size_t size; size_t size;
struct ggml_tensor * ggml_tensor = NULL; struct ggml_tensor * ggml_tensor = NULL;
uint8_t * data; uint8_t * data;
llama_load_tensor(const std::string & name) : name(name) {}
void calc_all() {
size = llama_calc_tensor_size(ne, type);
}
}; };
struct llama_load_tensors_map { struct llama_load_tensors_map {
@ -465,17 +459,17 @@ struct llama_file_loader {
} }
void read_tensor_metadata(llama_load_tensors_map & tensors_map) { void read_tensor_metadata(llama_load_tensors_map & tensors_map) {
while (file.tell() < file.size) { while (file.tell() < file.size) {
llama_load_tensor tensor;
uint32_t n_dims = file.read_u32(); uint32_t n_dims = file.read_u32();
uint32_t name_len = file.read_u32(); uint32_t name_len = file.read_u32();
ggml_type type = (enum ggml_type) file.read_u32(); tensor.type = (enum ggml_type) file.read_u32();
std::vector<uint32_t> ne; tensor.ne.resize(n_dims);
ne.resize(n_dims); file.read_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * n_dims);
file.read_raw(ne.data(), sizeof(ne[0]) * n_dims);
std::string name = file.read_string(name_len); std::string name = file.read_string(name_len);
if (n_dims < 1 || n_dims > 2) { if (n_dims < 1 || n_dims > 2) {
throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims)); throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims));
} }
switch (type) { switch (tensor.type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
case GGML_TYPE_F16: case GGML_TYPE_F16:
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
@ -490,7 +484,7 @@ struct llama_file_loader {
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
break; break;
default: { default: {
throw std::runtime_error(format("unrecognized tensor type %u\n", type)); throw std::runtime_error(format("unrecognized tensor type %u\n", tensor.type));
} }
} }
@ -499,23 +493,13 @@ struct llama_file_loader {
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR); file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
} }
auto it = tensors_map.name_to_idx.find(name);
size_t idx;
if (it != tensors_map.name_to_idx.end()) {
idx = it->second;
} else {
tensors_map.tensors.emplace_back(name);
idx = tensors_map.tensors.size() - 1;
tensors_map.name_to_idx.emplace(name, idx);
}
auto tensor = tensors_map.tensors.at(idx);
tensor.ne = ne;
tensor.type = type;
tensor.file_off = file.tell(); tensor.file_off = file.tell();
tensor.name = name;
tensor.calc_all(); tensor.size = llama_calc_tensor_size(tensor.ne, tensor.type);
file.seek(tensor.size, SEEK_CUR); file.seek(tensor.size, SEEK_CUR);
tensors_map.tensors.push_back(tensor);
tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1;
} }
} }
}; };
@ -602,9 +586,6 @@ struct llama_model_loader {
use_mmap = false; use_mmap = false;
} }
this->use_mmap = use_mmap; this->use_mmap = use_mmap;
for (llama_load_tensor & lt : tensors_map.tensors) {
lt.calc_all();
}
} }
bool alignment_prevents_mmap() { bool alignment_prevents_mmap() {