llama : fix shape prints

This commit is contained in:
Georgi Gerganov 2023-08-16 11:38:17 +03:00
parent 5339b859ec
commit 31fb56e1d3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -993,9 +993,19 @@ static std::string llama_format_tensor_shape(const std::vector<uint32_t> & ne) {
return buf;
}
// Builds a printable description of a tensor's shape for log and error messages.
//
// The first dimension (t->ne[0]) is written right-aligned in a field of width 5;
// each further dimension, up to t->n_dims, is appended with an " x " separator
// in front of it and the same field width.
//
// The text is assembled in a fixed 256-byte stack buffer. Every write is bounded
// by the space still left in that buffer, so overly long output is cut off
// instead of overflowing. The result is returned as a std::string copy.
static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
    char text[256];

    // leading dimension goes in without a separator
    snprintf(text, sizeof(text), "%5" PRId64, t->ne[0]);

    int dim = 1;
    while (dim < t->n_dims) {
        // continue writing right after what has been produced so far
        const size_t used = strlen(text);
        snprintf(text + used, sizeof(text) - used, " x %5" PRId64, t->ne[dim]);
        ++dim;
    }

    return text;
}
struct llama_model_loader {
int n_tensors = 0;
int n_created = 0;
bool use_mmap = false;
llama_file file;
@ -1068,7 +1078,6 @@ struct llama_model_loader {
struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str());
// TODO: simplify
{
bool is_ok = true;
for (size_t i = 0; i < ne.size(); ++i) {
@ -1079,9 +1088,10 @@ struct llama_model_loader {
}
if (!is_ok) {
throw std::runtime_error(
format("%s: tensor '%s' has wrong shape; expected [%d, %d, %d, %d], got [%d, %d, %d, %d]",
__func__, name.c_str(), ne[0], ne[1], ne[2], ne[3],
(int) cur->ne[0], (int) cur->ne[1], (int) cur->ne[2], (int) cur->ne[3]));
format("%s: tensor '%s' has wrong shape; expected %s, got %s",
__func__, name.c_str(),
llama_format_tensor_shape(ne).c_str(),
llama_format_tensor_shape(cur).c_str()));
}
}
@ -3405,9 +3415,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
tensor->data = read_data.data();
model_loader->load_data_for(tensor);
LLAMA_LOG_INFO("[%4zu/%4zu] %36s - [%5d, %5d], type = %6s, ",
LLAMA_LOG_INFO("[%4zu/%4zu] %36s - [%s], type = %6s, ",
++idx, model_loader->n_tensors,
ggml_get_name(tensor), (int) tensor->ne[0], (int) tensor->ne[1],
ggml_get_name(tensor),
llama_format_tensor_shape(tensor).c_str(),
ggml_type_name(tensor->type));
// This used to be a regex, but <regex> has an extreme cost to compile times.