Remove multiple shards

Howard Su 2023-06-26 16:08:02 +08:00
parent 0be54f75a6
commit d8147f236d

llama.cpp (101 changed lines)

@@ -376,18 +376,11 @@ struct llama_load_tensor_shard {
     }
 };
 
-enum llama_split_type {
-    SPLIT_NONE,
-    SPLIT_BY_COLUMNS,
-    SPLIT_BY_ROWS
-};
-
 struct llama_load_tensor {
-    std::vector<llama_load_tensor_shard> shards;
+    llama_load_tensor_shard first_shard;
     std::string name;
     enum ggml_type type = GGML_TYPE_F32;
-    llama_split_type split_type = SPLIT_NONE;
     std::vector<uint32_t> ne;
     size_t size;
     struct ggml_tensor * ggml_tensor = NULL;
@@ -397,58 +390,16 @@ struct llama_load_tensor {
     void calc_all() {
         calc_type();
-        calc_split_type();
         calc_ne();
         calc_size();
     }
 
     void calc_type() {
-        const auto & first_shard = shards.at(0);
-        for (const auto & shard : shards) {
-            if (shard.type != first_shard.type) {
-                throw std::runtime_error(format("inconsistent tensor shard type in '%s'", name.c_str()));
-            }
-        }
         type = first_shard.type;
     }
 
-    void calc_split_type() {
-        if (shards.at(0).ne.size() == 1 || // 1D tensors are just duplicated in every file
-            shards.size() == 1) { // only one file?
-            split_type = SPLIT_NONE;
-        } else if (name.find("tok_embeddings.") == 0 ||
-            name.find(".attention.wo.weight") != std::string::npos ||
-            name.find(".feed_forward.w2.weight") != std::string::npos) {
-            split_type = SPLIT_BY_COLUMNS;
-        } else {
-            split_type = SPLIT_BY_ROWS;
-        }
-    }
-
     void calc_ne() {
-        const auto & first_shard = shards.at(0);
-        for (const auto & shard : shards) {
-            if (shard.ne != first_shard.ne) {
-                throw std::runtime_error(format("inconsistent tensor shard shape in '%s': first was %s, other was %s",
-                                                name.c_str(), llama_format_tensor_shape(first_shard.ne).c_str(), llama_format_tensor_shape(shard.ne).c_str()));
-            }
-        }
         ne = first_shard.ne;
-        LLAMA_ASSERT(shards.size() <= UINT32_MAX);
-        uint32_t n_shards = (uint32_t) shards.size();
-        switch (split_type) {
-            case SPLIT_NONE:
-                ne = first_shard.ne;
-                break;
-            case SPLIT_BY_COLUMNS:
-                ne = {checked_mul<uint32_t>(first_shard.ne[0], n_shards),
-                      first_shard.ne[1]};
-                break;
-            case SPLIT_BY_ROWS:
-                ne = {first_shard.ne[0],
-                      checked_mul<uint32_t>(first_shard.ne[1], n_shards)};
-                break;
-        }
     }
 
     void calc_size() {
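The deleted calc_ne() is where multi-file shapes were stitched together. A minimal standalone sketch of that removed arithmetic, with a hypothetical shard shape and count; checked_mul is re-implemented here since the repo's template helper is not shown in this diff:

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Stand-in for llama.cpp's checked_mul<uint32_t>: multiply, throw on overflow.
static uint32_t checked_mul_u32(uint32_t a, uint32_t b) {
    const uint64_t p = (uint64_t) a * (uint64_t) b;
    if (p > UINT32_MAX) {
        throw std::runtime_error("tensor dimension overflow");
    }
    return (uint32_t) p;
}

int main() {
    // Hypothetical: every file holds a 4096 x 4096 shard, 4 files total.
    const std::vector<uint32_t> shard_ne = {4096, 4096};
    const uint32_t n_shards = 4;

    // SPLIT_BY_COLUMNS scaled ne[0]; SPLIT_BY_ROWS scaled ne[1];
    // SPLIT_NONE kept the shard shape unchanged.
    const std::vector<uint32_t> by_cols = {checked_mul_u32(shard_ne[0], n_shards), shard_ne[1]};
    const std::vector<uint32_t> by_rows = {shard_ne[0], checked_mul_u32(shard_ne[1], n_shards)};

    printf("by columns: %u x %u\n", by_cols[0], by_cols[1]); // 16384 x 4096
    printf("by rows:    %u x %u\n", by_rows[0], by_rows[1]); // 4096 x 16384
    return 0;
}
```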
@@ -589,7 +540,7 @@ struct llama_file_loader {
                 idx = tensors_map.tensors.size() - 1;
                 tensors_map.name_to_idx.emplace(name, idx);
             }
-            tensors_map.tensors.at(idx).shards.push_back(shard);
+            tensors_map.tensors.at(idx).first_shard = shard;
         }
     }
 };
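read_tensor_metadata() keeps its lookup-or-append pattern, but each entry now stores a single shard record, so a repeated tensor name overwrites instead of accumulating. A toy sketch of that behavior, with a simplified struct and hypothetical values:

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for llama_load_tensor after this commit: one shard slot.
struct tensor_entry {
    std::string name;
    size_t file_off = 0; // metadata of the single (first) shard
};

int main() {
    std::vector<tensor_entry> tensors;
    std::unordered_map<std::string, size_t> name_to_idx;

    auto record = [&](const std::string & name, size_t file_off) {
        size_t idx;
        auto it = name_to_idx.find(name);
        if (it != name_to_idx.end()) {
            idx = it->second;               // name already seen: reuse slot
        } else {
            tensors.push_back({name, 0});   // first sighting: append
            idx = tensors.size() - 1;
            name_to_idx.emplace(name, idx);
        }
        // Post-commit behavior: assignment replaces any earlier record,
        // where shards.push_back() used to accumulate one per file.
        tensors.at(idx).file_off = file_off;
    };

    record("tok_embeddings.weight", 128);
    record("tok_embeddings.weight", 4096); // overwrites the first record
    printf("%zu entry, file_off=%zu\n", tensors.size(), tensors.at(0).file_off);
    return 0;
}
```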
@@ -693,12 +644,10 @@ struct llama_model_loader {
     bool alignment_prevents_mmap() {
         for (const llama_load_tensor & lt : tensors_map.tensors) {
-            for (const llama_load_tensor_shard & shard : lt.shards) {
-                if (shard.file_off & 3) {
-                    return true;
-                }
-            }
+            if (lt.first_shard.file_off & 3) {
+                return true;
+            }
         }
         return false;
     }
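The `file_off & 3` test is the usual power-of-two alignment check: the two low bits are nonzero exactly when the offset is not a multiple of 4, which flags tensor data that cannot be used directly through the mapping. A tiny illustration with made-up offsets:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // An offset is 4-byte aligned iff its two low bits are clear,
    // so (off & 3) != 0 marks a tensor whose placement prevents mmap use.
    const uint64_t offs[] = {4096, 4097, 4098, 4100};
    for (uint64_t off : offs) {
        printf("off=%llu -> misaligned=%d\n",
               (unsigned long long) off, (int) ((off & 3) != 0));
    }
    return 0; // prints 0, 1, 1, 0
}
```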
@@ -708,7 +657,7 @@ struct llama_model_loader {
             throw std::runtime_error(std::string("missing tok_embeddings.weight"));
         }
         const llama_load_tensor & lt = tensors_map.tensors.at(it->second);
-        return file_loaders.at(0)->hparams.n_embd / lt.shards.at(0).ne.at(0);
+        return file_loaders.at(0)->hparams.n_embd / lt.first_shard.ne.at(0);
     }
 
     void calc_sizes(size_t * ctx_size_p, size_t * mmapped_size_p) const {
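This quotient appears to recover the old multi-part count: if tok_embeddings.weight was column-split across N files, each shard's leading dimension was n_embd / N, so the division yields N, and a single-file model yields 1. A sketch with hypothetical numbers, not the repo's exact function:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical 30B-style model: n_embd = 6656, and tok_embeddings.weight
    // was column-split across two files, so each shard's ne[0] is 3328.
    const uint32_t n_embd          = 6656;
    const uint32_t first_shard_ne0 = 3328;
    printf("guessed parts: %u\n", n_embd / first_shard_ne0); // 2

    // Post-commit, with every tensor in one file, ne[0] == n_embd and the
    // quotient is always 1.
    printf("single file:   %u\n", n_embd / n_embd);          // 1
    return 0;
}
```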
@@ -830,45 +779,13 @@ struct llama_model_loader {
     void load_data_for(llama_load_tensor & lt) {
         if (use_mmap) {
-            LLAMA_ASSERT(lt.shards.size() == 1);
-            lt.data = (uint8_t *) mapping->addr + lt.shards.at(0).file_off;
-        } else if (lt.split_type == SPLIT_NONE) {
-            llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file;
-            file.seek(lt.shards.at(0).file_off, SEEK_SET);
+            lt.data = (uint8_t *) mapping->addr + lt.first_shard.file_off;
+        } else {
+            llama_file & file = file_loaders.at(lt.first_shard.file_idx)->file;
+            file.seek(lt.first_shard.file_off, SEEK_SET);
             file.read_raw(lt.data, lt.size);
-        } else if (lt.split_type == SPLIT_BY_ROWS) {
-            size_t offset = 0;
-            for (llama_load_tensor_shard & shard : lt.shards) {
-                llama_file & file = file_loaders.at(shard.file_idx)->file;
-                file.seek(shard.file_off, SEEK_SET);
-                file.read_raw(lt.data + offset, shard.size);
-                offset += shard.size;
-            }
-            LLAMA_ASSERT(offset == lt.size);
-        } else if (lt.split_type == SPLIT_BY_COLUMNS) {
-            // Let's load the data into temporary buffers to ensure the OS performs large loads.
-            std::vector<llama_buffer> tmp_bufs(lt.shards.size());
-            for (size_t i = 0; i < lt.shards.size(); i++) {
-                llama_load_tensor_shard & shard = lt.shards.at(i);
-                llama_file & file = file_loaders.at(shard.file_idx)->file;
-                file.seek(shard.file_off, SEEK_SET);
-                tmp_bufs.at(i).resize(shard.size);
-                file.read_raw(tmp_bufs.at(i).addr, shard.size);
-            }
-            // Then reshape.
-            size_t num_rows = lt.ne.at(1);
-            size_t per_shard_row_size = lt.shards.at(0).size / num_rows;
-            size_t out_offset = 0;
-            for (size_t row = 0; row < num_rows; row++) {
-                for (llama_buffer & tmp_buf : tmp_bufs) {
-                    memcpy(lt.data + out_offset,
-                           tmp_buf.addr + row * per_shard_row_size,
-                           per_shard_row_size);
-                    out_offset += per_shard_row_size;
-                }
-            }
-            LLAMA_ASSERT(out_offset == lt.size);
         }
         if (0) {
             print_checksum(lt);
         }