Merge 'origin/master' into cistuff

Henri Vasserman 2023-05-06 16:57:21 +03:00
commit 71fac5bbcb
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986
10 changed files with 225 additions and 33 deletions

.github/workflows/build.yml

@@ -262,6 +262,82 @@ jobs:
path: |
llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip
windows-latest-cmake-cublas:
runs-on: windows-latest
strategy:
matrix:
cuda: ['12.1.0', '11.7.1']
build: ['cublas']
steps:
- name: Clone
id: checkout
uses: actions/checkout@v1
- uses: Jimver/cuda-toolkit@v0.2.10
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda }}
# TODO(green-sky): _dev seems to fail, and non dev are not enought
#sub-packages: '["nvcc", "cudart", "cublas", "cudart_dev", "cublas_dev"]'
- name: Build
id: cmake_build
run: |
mkdir build
cd build
cmake .. -DLLAMA_CUBLAS=ON
cmake --build . --config Release
- name: Get commit hash
id: commit
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: pr-mpt/actions-commit-hash@v2
- name: Pack artifacts
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
7z a llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-cu${{ matrix.cuda }}-x64.zip .\build\bin\Release\*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v3
with:
path: |
llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-cu${{ matrix.cuda }}-x64.zip
- name: Copy and pack Cuda runtime
if: ${{ matrix.cuda == '12.1.0' }}
# TODO(green-sky): paths are cuda 12 specific
run: |
echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
mkdir '.\build\bin\cudart\'
cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cudart64_12.dll" '.\build\bin\cudart\'
cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublas64_12.dll" '.\build\bin\cudart\'
cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublasLt64_12.dll" '.\build\bin\cudart\'
7z a cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip .\build\bin\cudart\*
- name: Copy and pack Cuda runtime
if: ${{ matrix.cuda == '11.7.1' }}
# TODO(green-sky): paths are cuda 11 specific
run: |
echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}"
mkdir '.\build\bin\cudart\'
ls "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin"
cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cudart64_110.dll" '.\build\bin\cudart\'
cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublas64_11.dll" '.\build\bin\cudart\'
cp "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin\cublasLt64_11.dll" '.\build\bin\cudart\'
7z a cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip .\build\bin\cudart\*
- name: Upload Cuda runtime
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v3
with:
path: |
cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip
release:
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
@@ -273,6 +349,7 @@ jobs:
- macOS-latest-make
- macOS-latest-cmake
- windows-latest-cmake
- windows-latest-cmake-cublas
steps:
- name: Download artifacts

Makefile

@@ -107,7 +107,11 @@ ifndef LLAMA_NO_ACCELERATE
endif
ifdef LLAMA_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
-LDFLAGS += -lopenblas
+ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),)
+    LDFLAGS += -lopenblas -lcblas
+else
+    LDFLAGS += -lopenblas
+endif
endif
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
@@ -121,7 +125,12 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
endif
ifdef LLAMA_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
-LDFLAGS += -lclblast -lOpenCL
+# Mac provides OpenCL as a framework
+ifeq ($(UNAME_S),Darwin)
+    LDFLAGS += -lclblast -framework OpenCL
+else
+    LDFLAGS += -lclblast -lOpenCL
+endif
OBJS += ggml-opencl.o
ggml-opencl.o: ggml-opencl.c ggml-opencl.h
$(CC) $(CFLAGS) -c $< -o $@

README.md

@@ -18,10 +18,12 @@ The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quant
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via ARM NEON and Accelerate framework
-- AVX2 support for x86 architectures
+- AVX, AVX2 and AVX512 support for x86 architectures
- Mixed F16 / F32 precision
-- 4-bit integer quantization support
+- 4-bit, 5-bit and 8-bit integer quantization support
- Runs on the CPU
- OpenBLAS support
- cuBLAS and CLBlast support
The original implementation of `llama.cpp` was [hacked in an evening](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022).
Since then, the project has improved significantly thanks to many contributions. This project is for educational purposes and serves
@@ -43,6 +45,7 @@ as the main playground for developing new features for the [ggml](https://github
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
- [X] [Vicuna](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5533894)
- [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/)
- [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy)
**Bindings:**
@@ -213,7 +216,6 @@ Building the program with BLAS support may lead to some performance improvements
```bash
make LLAMA_OPENBLAS=1
```
-Note: In order to build on Arch Linux with OpenBLAS support enabled you must edit the Makefile adding at the end of the line 105: `-lcblas`
- On Windows:

convert.py

@@ -67,6 +67,7 @@ FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
{ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
DT_BF16: np.dtype(np.uint16),
DT_F16: np.dtype(np.float16),
DT_F32: np.dtype(np.float32),
DT_I32: np.dtype(np.int32),
@@ -276,6 +277,12 @@ class Tensor(metaclass=ABCMeta):
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray:
assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
fp32_arr = bf16_arr.astype(np.uint32) << 16
return fp32_arr.view(np.float32)
class UnquantizedTensor(Tensor):
def __init__(self, ndarray: NDArray) -> None:
assert isinstance(ndarray, np.ndarray)
@@ -284,6 +291,8 @@ class UnquantizedTensor(Tensor):
def astype(self, data_type: DataType) -> Tensor:
dtype = DATA_TYPE_TO_NUMPY[data_type]
if self.data_type == DT_BF16:
self.ndarray = bf16_to_fp32(self.ndarray)
return UnquantizedTensor(self.ndarray.astype(dtype))
def to_ggml(self) -> 'UnquantizedTensor':
@@ -686,6 +695,7 @@ class LazyUnpickler(pickle.Unpickler):
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
return LazyStorage(load=load, kind=pid[1], description=description)
# @staticmethod
def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, # pyright: ignore[reportSelfClsParameterName]
requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
assert isinstance(storage, LazyStorage)
@@ -696,12 +706,18 @@ class LazyUnpickler(pickle.Unpickler):
description = f'pickled storage_offset={storage_offset} in {storage.description}'
return LazyTensor(load, list(size), storage.kind.data_type, description)
# @staticmethod
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
CLASSES: Dict[Any, Any] = {
('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2,
('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2,
('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
('torch', 'IntStorage'): LazyStorageKind(DT_I32),
('torch', 'Tensor'): LazyTensor,
}
def find_class(self, module: str, name: str) -> Any:
@@ -961,7 +977,7 @@ class OutputFile:
def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
wq_type = model["layers.0.attention.wq.weight"].data_type
-if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
+if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
return GGMLFileType.AllF32
if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16):
return GGMLFileType.MostlyF16
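
The bf16 support added to `convert.py` above rests on a simple bit-level fact: bfloat16 is the top 16 bits of an IEEE-754 float32, so widening a value only requires shifting its 16-bit pattern left by 16 and reinterpreting the result as a float, which is exactly what `bf16_to_fp32` does with NumPy. A minimal standalone sketch of the same trick in C++ (illustration only, not part of the patch; the function name and sample value are made up):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen one bfloat16 value (stored as uint16_t) to float32.
// bf16 keeps float32's sign bit, exponent and top 7 mantissa bits,
// so appending 16 zero bits reproduces a valid float32 bit pattern.
static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f)); // reinterpret the bits as float
    return f;
}

int main() {
    std::printf("%f\n", bf16_to_f32(0x3FC0)); // 0x3FC0 is the bfloat16 encoding of 1.5
    return 0;
}
```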

examples/common.cpp

@@ -324,6 +324,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.input_prefix = argv[i];
} else if (arg == "--in-suffix") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.input_suffix = argv[i];
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params);
@@ -362,6 +368,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " --random-prompt start with a randomized prompt.\n");
fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);

examples/common.h

@@ -43,6 +43,7 @@ struct gpt_params {
std::string prompt = "";
std::string path_session = ""; // path to file for saving/loading model eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string lora_adapter = ""; // lora adapter path

examples/main/README.md

@@ -112,6 +112,14 @@ The `--in-prefix` flag is used to add a prefix to your input, primarily, this is
./main -r "User:" --in-prefix " "
```
### In-Suffix
The `--in-suffix` flag is used to add a suffix after your input. This is useful for adding an "Assistant:" prompt after the user's input. It's added after the new-line character (`\n`) that's automatically added to the end of the user's input. Here's an example of how to use the `--in-suffix` flag in conjunction with the `--reverse-prompt` flag:
```sh
./main -r "User:" --in-prefix " " --in-suffix "Assistant:"
```
### Instruction Mode
Instruction mode is particularly useful when working with Alpaca models, which are designed to follow user instructions for specific tasks:

examples/main/main.cpp

@@ -260,6 +260,10 @@ int main(int argc, char ** argv) {
if (!params.input_prefix.empty()) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
if (!params.input_suffix.empty()) {
fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
}
}
fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
@@ -567,6 +571,11 @@ int main(int argc, char ** argv) {
// Add tokens to embd only if the input buffer is non-empty
// Entering a empty line lets the user pass control back
if (buffer.length() > 1) {
// append input suffix if any
if (!params.input_suffix.empty()) {
buffer += params.input_suffix;
printf("%s", params.input_suffix.c_str());
}
// instruct mode: insert instruction prefix
if (params.instruct && !is_antiprompt) {

examples/quantize/quantize.cpp

@@ -6,23 +6,47 @@
#include <map>
#include <string>
-static const std::map<std::string, enum llama_ftype> LLAMA_FTYPE_MAP = {
+static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
{"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0},
{"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1},
{"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2},
{"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0},
{"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
{"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
};
bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
auto it = LLAMA_FTYPE_MAP.find(ftype_str);
if (it != LLAMA_FTYPE_MAP.end()) {
ftype = it->second;
ftype_str_out = it->first;
return true;
}
// try to parse as an integer
try {
int ftype_int = std::stoi(ftype_str);
for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
if (it->second == ftype_int) {
ftype = it->second;
ftype_str_out = it->first;
return true;
}
}
}
catch (...) {
// stoi failed
}
return false;
}
// usage:
-// ./quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
+// ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
//
int main(int argc, char ** argv) {
ggml_time_init();
-if (argc < 4) {
-fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type [nthread]\n", argv[0]);
+if (argc < 3) {
+fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
}
@@ -36,24 +60,62 @@ int main(int argc, char ** argv) {
ggml_free(ctx);
}
+// parse command line arguments
const std::string fname_inp = argv[1];
-const std::string fname_out = argv[2];
-enum llama_ftype ftype;
-if (argv[3][0] == 'q') {
-    auto it = LLAMA_FTYPE_MAP.find(argv[3]);
-    if (it == LLAMA_FTYPE_MAP.end()) {
-        fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, argv[3]);
-        return 1;
-    }
-    ftype = it->second;
-} else {
-    ftype = (enum llama_ftype)atoi(argv[3]);
-}
+std::string fname_out;
+int nthread;
+llama_ftype ftype;
+int arg_idx = 2;
+std::string ftype_str;
+if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
+    // argv[2] is the ftype
+    std::string fpath;
+    const size_t pos = fname_inp.find_last_of('/');
+    if (pos != std::string::npos) {
+        fpath = fname_inp.substr(0, pos + 1);
+    }
+    // export as [inp path]/ggml-model-[ftype].bin
+    fname_out = fpath + "ggml-model-" + ftype_str + ".bin";
+    arg_idx++;
+}
+else {
+    // argv[2] is the output path
+    fname_out = argv[arg_idx];
+    arg_idx++;
+    if (argc <= arg_idx) {
+        fprintf(stderr, "%s: missing ftype\n", __func__);
+        return 1;
+    }
+    // argv[3] is the ftype
+    if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
+        fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
+        return 1;
+    }
+    arg_idx++;
+}
+// parse nthreads
+if (argc > arg_idx) {
+    try {
+        nthread = std::stoi(argv[arg_idx]);
+    }
+    catch (const std::exception & e) {
+        fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
+        return 1;
+    }
+} else {
+    nthread = 0;
+}
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
-int nthread = argc > 4 ? atoi(argv[4]) : 0;
+fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
+if (nthread > 0) {
+    fprintf(stderr, " using %d threads", nthread);
+}
+fprintf(stderr, "\n");
const int64_t t_main_start_us = ggml_time_us();
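
The reworked argument handling above accepts either `quantize <in> <ftype> [nthreads]` or `quantize <in> <out> <ftype> [nthreads]`, with the ftype given by name or by its integer value; when no output path is supplied, the file is written next to the input as `ggml-model-<ftype>.bin`. A distilled sketch of just that naming rule (illustration only; the helper name and example paths are made up, and only forward slashes are handled, as in the patch):

```cpp
#include <cstdio>
#include <string>

// Derive the default output path the new quantize code builds when only an
// ftype is given: keep the input's directory, name the file ggml-model-<ftype>.bin.
static std::string default_output_name(const std::string & fname_inp, const std::string & ftype_str) {
    std::string fpath;
    const size_t pos = fname_inp.find_last_of('/');
    if (pos != std::string::npos) {
        fpath = fname_inp.substr(0, pos + 1); // directory part, slash included
    }
    return fpath + "ggml-model-" + ftype_str + ".bin";
}

int main() {
    // e.g. "./quantize models/llama/ggml-model-f16.bin q4_0 8" would write to:
    std::printf("%s\n", default_output_name("models/llama/ggml-model-f16.bin", "q4_0").c_str());
    // -> models/llama/ggml-model-q4_0.bin
    return 0;
}
```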

llama-util.h

@@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include <stdexcept>
#ifdef __has_include
#if __has_include(<unistd.h>)
@@ -74,7 +75,7 @@ struct llama_file {
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
-throw format("failed to open %s: %s", fname, std::strerror(errno));
+throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
}
seek(0, SEEK_END);
size = tell();
@@ -107,10 +108,10 @@ struct llama_file {
errno = 0;
std::size_t ret = std::fread(ptr, size, 1, fp);
if (ferror(fp)) {
-throw format("read error: %s", strerror(errno));
+throw std::runtime_error(format("read error: %s", strerror(errno)));
}
if (ret != 1) {
-throw std::string("unexpectedly reached end of file");
+throw std::runtime_error(std::string("unexpectedly reached end of file"));
}
}
@@ -133,7 +134,7 @@ struct llama_file {
errno = 0;
size_t ret = std::fwrite(ptr, size, 1, fp);
if (ret != 1) {
-throw format("write error: %s", strerror(errno));
+throw std::runtime_error(format("write error: %s", strerror(errno)));
}
}
@@ -180,7 +181,7 @@ struct llama_mmap {
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) {
-throw format("mmap failed: %s", strerror(errno));
+throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
}
if (prefetch) {
@@ -207,7 +208,7 @@ struct llama_mmap {
DWORD error = GetLastError();
if (hMapping == NULL) {
-throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str());
+throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
@@ -215,7 +216,7 @@ struct llama_mmap {
CloseHandle(hMapping);
if (addr == NULL) {
-throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str());
+throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
}
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
@@ -245,7 +246,7 @@ struct llama_mmap {
llama_mmap(struct llama_file *, bool prefetch = true) {
(void)prefetch;
-throw std::string("mmap not supported");
+throw std::runtime_error(std::string("mmap not supported"));
}
#endif
};
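
The llama-util.h changes above replace `throw format(...)` and `throw std::string(...)` with `std::runtime_error`, so every failure is thrown as a type derived from `std::exception`. A minimal sketch of why that matters for callers (illustration only; `open_model` and the message are made up, this is not llama.cpp's actual error handling):

```cpp
#include <cstdio>
#include <stdexcept>
#include <string>

// A loader that reports failures by throwing std::runtime_error,
// in the style of the llama_file / llama_mmap changes above.
static void open_model(const std::string & path) {
    // pretend the open failed
    throw std::runtime_error("failed to open " + path + ": No such file or directory");
}

int main() {
    try {
        open_model("models/7B/ggml-model-q4_0.bin");
    } catch (const std::exception & err) {
        // One catch clause sees every failure and can report err.what();
        // a thrown std::string or char* would bypass this handler entirely.
        std::fprintf(stderr, "error: %s\n", err.what());
        return 1;
    }
    return 0;
}
```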