add --check-tensors command line argument

tensor validation is disabled by default and can be enabled by adding
`--check-tensors` to the command line arguments.

quantize always validates tensors.
slaren 2024-04-25 15:41:36 +02:00
parent c806db318d
commit 145d315127
4 changed files with 22 additions and 11 deletions
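
At the API level the switch is the new `check_tensors` field of `llama_model_params`; on the command line it is simply `--check-tensors` (e.g. `./main -m model.gguf --check-tensors` with the `main` example). A minimal sketch of opting in programmatically (the model path and error handling are illustrative; only the `check_tensors` assignment comes from this commit):

```cpp
#include "llama.h"
#include <cstdio>

int main() {
    llama_model_params mparams = llama_model_default_params();
    mparams.check_tensors = true; // new field added by this commit; defaults to false

    // with check_tensors set, a model whose tensor data fails validation does not load
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model (or tensor data was invalid)\n");
        return 1;
    }

    llama_free_model(model);
    return 0;
}
```

When validation fails the loader throws, `llama_model_load` returns -1, and the failure surfaces to the caller as an ordinary load error (NULL return), so no new error handling is required on the application side.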

@@ -1089,6 +1089,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.n_print = std::stoi(argv[i]);
         return true;
     }
+    if (arg == "--check-tensors") {
+        params.check_tensors = true;
+        return true;
+    }
     if (arg == "--ppl-output-type") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1554,6 +1558,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf(" -ptc N, --print-token-count N\n");
     printf(" print token count every N tokens (default: %d)\n", params.n_print);
+    printf(" --check-tensors check model tensor data for invalid values\n");
     printf("\n");
 #ifndef LOG_DISABLE_LOGS
     log_print_usage();
@@ -1774,6 +1779,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.tensor_split = params.tensor_split;
     mparams.use_mmap = params.use_mmap;
     mparams.use_mlock = params.use_mlock;
+    mparams.check_tensors = params.check_tensors;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {

@@ -161,6 +161,7 @@ struct gpt_params {
     bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
     bool no_kv_offload = false; // disable KV offloading
     bool warmup = true; // warmup run
+    bool check_tensors = false; // validate tensor data
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V

@@ -2985,6 +2985,7 @@ struct llama_model_loader {
     size_t n_bytes = 0;
     bool use_mmap = false;
+    bool check_tensors;
     llama_files files;
     llama_ftype ftype;
@@ -3014,7 +3015,7 @@
     std::string arch_name;
     LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
-    llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) {
+    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
         int trace = 0;
         if (getenv("LLAMA_TRACE")) {
             trace = atoi(getenv("LLAMA_TRACE"));
@@ -3218,6 +3219,7 @@
         }
         this->use_mmap = use_mmap;
+        this->check_tensors = check_tensors;
     }
     ~llama_model_loader() {
@@ -3473,7 +3475,7 @@
             file->read_raw(cur->data, ggml_nbytes(cur));
         }
-        if (!ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
+        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
             throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
         }
     }
@@ -3514,7 +3516,7 @@
                     buf_mmap = bufs_mmap.at(weight->idx);
                 }
-                if (!ggml_validate_row_data(cur->type, (uint8_t *) mapping->addr + weight->offs, n_size)) {
+                if (check_tensors && !ggml_validate_row_data(cur->type, (uint8_t *) mapping->addr + weight->offs, n_size)) {
                     throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
                 }
@@ -3538,7 +3540,7 @@
                 if (ggml_backend_buffer_is_host(cur->buffer)) {
                     file->seek(weight->offs, SEEK_SET);
                     file->read_raw(cur->data, n_size);
-                    if (!ggml_validate_row_data(cur->type, cur->data, n_size)) {
+                    if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, n_size)) {
                         throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
                     }
                 } else {
@@ -3546,7 +3548,7 @@
                     file->seek(weight->offs, SEEK_SET);
                     file->read_raw(read_buf.data(), n_size);
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
-                    if (!ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
+                    if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
                         throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
                     }
                 }
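
The four hunks above gate the existing `ggml_validate_row_data` calls behind the new `check_tensors` flag. The validator itself is not part of this diff; conceptually, for plain f32 data it amounts to rejecting non-finite values. A rough illustrative sketch (not ggml's actual implementation, which also understands f16 and the quantized block formats):

```cpp
#include <cmath>
#include <cstddef>

// illustrative only: a buffer of f32 values is "valid" if every element is finite
// (ggml_validate_row_data additionally handles f16 and quantized tensor types)
static bool validate_f32_data(const float * data, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        if (!std::isfinite(data[i])) {
            return false; // NaN or +/-Inf found
        }
    }
    return true;
}
```
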
@@ -5981,7 +5983,7 @@ static bool llm_load_tensors(
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
+        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
         model.hparams.vocab_only = params.vocab_only;
@@ -14459,7 +14461,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, kv_overrides);
+    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
     ml.init_mappings(false); // no prefetching
     llama_model model;
@@ -14780,7 +14782,7 @@ static int llama_apply_lora_from_file_internal(
     std::unique_ptr<llama_model_loader> ml;
     if (path_base_model) {
         LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
-        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ nullptr));
+        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
         ml->init_mappings(/*prefetch*/ false); // no prefetching
     }
@@ -15039,6 +15041,7 @@ struct llama_model_params llama_model_default_params() {
         /*.vocab_only =*/ false,
         /*.use_mmap =*/ true,
         /*.use_mlock =*/ false,
+        /*.check_tensors =*/ false,
     };
 #ifdef GGML_USE_METAL

@@ -235,6 +235,7 @@ extern "C" {
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap; // use mmap if possible
         bool use_mlock; // force system to keep model in RAM
+        bool check_tensors; // validate model tensor data
     };
     struct llama_context_params {