Merge remote-tracking branch 'origin/master' into sl/backend-registry-2
This commit is contained in:
commit
2a60833a01
9 changed files with 289 additions and 546 deletions
13
Makefile
13
Makefile
|
@ -5,7 +5,6 @@ BUILD_TARGETS = \
|
|||
llama-batched \
|
||||
llama-batched-bench \
|
||||
llama-bench \
|
||||
llama-benchmark-matmult \
|
||||
llama-cli \
|
||||
llama-convert-llama2c-to-ggml \
|
||||
llama-embedding \
|
||||
|
@ -68,7 +67,7 @@ TEST_TARGETS = \
|
|||
# Legacy build targets that were renamed in #7809, but should still be removed when the project is cleaned
|
||||
LEGACY_TARGETS_CLEAN = main quantize quantize-stats perplexity imatrix embedding vdot q8dot convert-llama2c-to-ggml \
|
||||
simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli baby-llama \
|
||||
retrieval speculative infill tokenize benchmark-matmult parallel export-lora lookahead lookup passkey gritlm
|
||||
retrieval speculative infill tokenize parallel export-lora lookahead lookup passkey gritlm
|
||||
|
||||
# Legacy build targets that were renamed in #7809, but we want to build binaries that for them that output a deprecation warning if people try to use them.
|
||||
# We don't want to clutter things too much, so we only build replacements for the most commonly used binaries.
|
||||
|
@ -1524,16 +1523,6 @@ common/build-info.o: common/build-info.cpp
|
|||
|
||||
tests: $(TEST_TARGETS)
|
||||
|
||||
llama-benchmark-matmult: examples/benchmark/benchmark-matmult.cpp \
|
||||
$(OBJ_GGML) common/build-info.o
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
run-benchmark-matmult: llama-benchmark-matmult
|
||||
./$@
|
||||
|
||||
.PHONY: run-benchmark-matmult swift
|
||||
|
||||
tests/test-arg-parser: tests/test-arg-parser.cpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
|
|
|
@ -92,6 +92,7 @@ Typically finetunes of the base models below are supported as well.
|
|||
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
|
||||
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
|
||||
- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
|
||||
- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
|
||||
|
||||
(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
### Llama.cpp + SYCL
|
||||
|
||||
The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*).
|
||||
The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it also supports other vendor GPUs: Nvidia and AMD.
|
||||
|
||||
## Recommended Release
|
||||
|
||||
|
@ -111,10 +111,18 @@ SYCL backend supports Intel GPU Family:
|
|||
|
||||
**Verified devices**
|
||||
|
||||
| Nvidia GPU | Status | Verified Model |
|
||||
|--------------------------|---------|----------------|
|
||||
| Ampere Series | Support | A100, A4000 |
|
||||
| Ampere Series *(Mobile)* | Support | RTX 40 Series |
|
||||
| Nvidia GPU | Status | Verified Model |
|
||||
|--------------------------|-----------|----------------|
|
||||
| Ampere Series | Supported | A100, A4000 |
|
||||
| Ampere Series *(Mobile)* | Supported | RTX 40 Series |
|
||||
|
||||
| AMD GPU | Status | Verified Model |
|
||||
|--------------------------|--------------|----------------|
|
||||
| Radeon Pro | Experimental | W6800 |
|
||||
| Radeon RX | Experimental | 6700 XT |
|
||||
|
||||
Note: AMD GPU support is highly experimental and is incompatible with F16.
|
||||
Additionally, it only supports GPUs with a sub_group_size (warp size) of 32.
|
||||
|
||||
## Docker
|
||||
The docker build option is currently limited to *intel GPU* targets.
|
||||
|
@ -186,6 +194,10 @@ Platform #0: Intel(R) OpenCL HD Graphics
|
|||
|
||||
In order to target Nvidia GPUs through SYCL, please make sure the CUDA/CUBLAS native requirements *-found [here](README.md#cuda)-* are installed.
|
||||
|
||||
- **AMD GPU**
|
||||
|
||||
To target AMD GPUs with SYCL, the ROCm stack must be installed first.
|
||||
|
||||
2. **Install Intel® oneAPI Base toolkit**
|
||||
|
||||
- **For Intel GPU**
|
||||
|
@ -212,6 +224,19 @@ cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENAB
|
|||
cmake --build buildWithCublas --config Release
|
||||
```
|
||||
|
||||
- **Adding support to AMD GPUs**
|
||||
|
||||
**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
|
||||
|
||||
**oneMKL for rocBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* doesn't contain the rocBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *rocBLAS* backend enabled is thus required to run it on AMD GPUs.
|
||||
|
||||
```sh
|
||||
git clone https://github.com/oneapi-src/oneMKL
|
||||
cd oneMKL
|
||||
# Find your HIPTARGET with rocminfo, under the key 'Name:'
|
||||
cmake -B buildWithrocBLAS -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_ROCBLAS_BACKEND=ON -DHIPTARGETS=${HIPTARGET} -DTARGET_DOMAINS=blas
|
||||
cmake --build buildWithrocBLAS --config Release
|
||||
```
|
||||
|
||||
3. **Verify installation and environment**
|
||||
|
||||
|
@ -223,22 +248,32 @@ sycl-ls
|
|||
|
||||
- **Intel GPU**
|
||||
|
||||
When targeting an intel GPU, the user should expect one or more level-zero devices among the available SYCL devices. Please make sure that at least one GPU is present, for instance [`ext_oneapi_level_zero:gpu:0`] in the sample output below:
|
||||
When targeting an intel GPU, the user should expect one or more level-zero devices among the available SYCL devices. Please make sure that at least one GPU is present, for instance [`level_zero:gpu`] in the sample output below:
|
||||
|
||||
```
|
||||
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
|
||||
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
|
||||
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50]
|
||||
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918]
|
||||
[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
|
||||
[opencl:cpu][opencl:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
|
||||
[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50]
|
||||
[level_zero:gpu][level_zero:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918]
|
||||
```
|
||||
|
||||
- **Nvidia GPU**
|
||||
|
||||
Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`ext_oneapi_cuda:gpu`] as bellow:
|
||||
Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`cuda:gpu`] as below:
|
||||
|
||||
```
|
||||
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix]
|
||||
[opencl:cpu:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix]
|
||||
[ext_oneapi_cuda:gpu:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.2]
|
||||
[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix]
|
||||
[opencl:cpu][opencl:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix]
|
||||
[cuda:gpu][cuda:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.5]
|
||||
```
|
||||
|
||||
- **AMD GPU**
|
||||
|
||||
For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]:
|
||||
|
||||
```
|
||||
[opencl:cpu][opencl:0] Intel(R) OpenCL, 12th Gen Intel(R) Core(TM) i9-12900K OpenCL 3.0 (Build 0) [2024.18.6.0.02_160000]
|
||||
[hip:gpu][hip:0] AMD HIP BACKEND, AMD Radeon PRO W6800 gfx1030 [HIP 60140.9]
|
||||
```
|
||||
|
||||
### II. Build llama.cpp
|
||||
|
@ -266,6 +301,7 @@ cmake --build build --config Release -j -v
|
|||
```
|
||||
|
||||
#### Nvidia GPU
|
||||
|
||||
```sh
|
||||
# Export relevant ENV variables
|
||||
export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LD_LIBRARY_PATH
|
||||
|
@ -283,7 +319,25 @@ cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -
|
|||
|
||||
# build all binary
|
||||
cmake --build build --config Release -j -v
|
||||
```
|
||||
|
||||
#### AMD GPU
|
||||
|
||||
```sh
|
||||
# Export relevant ENV variables
|
||||
export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LD_LIBRARY_PATH
|
||||
export LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LIBRARY_PATH
|
||||
export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithrocBLAS/include:$CPLUS_INCLUDE_DIR
|
||||
|
||||
# Build LLAMA with rocBLAS acceleration through SYCL
|
||||
|
||||
## AMD
|
||||
# Use FP32, FP16 is not supported
|
||||
# Find your GGML_SYCL_HIP_TARGET with rocminfo, under the key 'Name:'
|
||||
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=AMD -DGGML_SYCL_HIP_TARGET=${GGML_SYCL_HIP_TARGET} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
||||
|
||||
# build all binary
|
||||
cmake --build build --config Release -j -v
|
||||
```
|
||||
|
||||
### III. Run the inference
|
||||
|
@ -586,11 +640,11 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|
|||
|
||||
#### Build
|
||||
|
||||
| Name | Value | Function |
|
||||
|--------------------|-----------------------------------|---------------------------------------------|
|
||||
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
|
||||
| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA | Set the SYCL target device type. |
|
||||
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
|
||||
| Name | Value | Function |
|
||||
|--------------------|---------------------------------------|---------------------------------------------|
|
||||
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.<br>FP32 path - recommended for better perforemance than FP16 on quantized model|
|
||||
| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. |
|
||||
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. |
|
||||
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
|
||||
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ else()
|
|||
add_subdirectory(baby-llama)
|
||||
add_subdirectory(batched-bench)
|
||||
add_subdirectory(batched)
|
||||
add_subdirectory(benchmark)
|
||||
add_subdirectory(convert-llama2c-to-ggml)
|
||||
add_subdirectory(embedding)
|
||||
add_subdirectory(eval-callback)
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
set(TARGET llama-bench-matmult)
|
||||
add_executable(${TARGET} benchmark-matmult.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_include_directories(${TARGET} PRIVATE ../../common)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
|
@ -1,275 +0,0 @@
|
|||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <locale.h>
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include <cinttypes>
|
||||
#include <unordered_map>
|
||||
#include <queue>
|
||||
#include <string.h>
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <iterator>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
|
||||
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
|
||||
|
||||
if (plan.work_size > 0) {
|
||||
buf.resize(plan.work_size);
|
||||
plan.work_data = buf.data();
|
||||
}
|
||||
|
||||
ggml_graph_compute(graph, &plan);
|
||||
}
|
||||
|
||||
static float tensor_sum_elements(const ggml_tensor * tensor) {
|
||||
double sum = 0;
|
||||
if (tensor->type == GGML_TYPE_F32) {
|
||||
for (int j = 0; j < tensor->ne[1]; j++) {
|
||||
for (int k = 0; k < tensor->ne[0]; k++) {
|
||||
sum += ((float *) tensor->data)[j*tensor->ne[0] + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
static void tensor_dump(const ggml_tensor * tensor, const char * name) {
|
||||
printf("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi) - ", name,
|
||||
tensor->type, ggml_type_name(tensor->type),
|
||||
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->nb[0], tensor->nb[1], tensor->nb[2]);
|
||||
float sum = tensor_sum_elements(tensor);
|
||||
printf("Sum of tensor %s is %6.2f\n", name, sum);
|
||||
}
|
||||
|
||||
#define TENSOR_DUMP(tensor) tensor_dump(tensor, #tensor)
|
||||
|
||||
struct benchmark_params_struct {
|
||||
int n_threads = 1;
|
||||
int32_t n_iterations = 10;
|
||||
};
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv, struct benchmark_params_struct params) {
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
||||
fprintf(stderr, " -i N, --iter N number of iterations to use during computation (default: %d)\n", params.n_iterations);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
struct benchmark_params_struct benchmark_params;
|
||||
|
||||
bool invalid_param = false;
|
||||
std::string arg;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
arg = argv[i];
|
||||
|
||||
if (arg == "-t" || arg == "--threads") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
benchmark_params.n_threads = std::stoi(argv[i]);
|
||||
} else if (arg == "-i" || arg == "--iter") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
benchmark_params.n_iterations = std::stoi(argv[i]);
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
print_usage(argc, argv, benchmark_params);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
if (invalid_param) {
|
||||
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
|
||||
print_usage(argc, argv, benchmark_params);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
print_build_info();
|
||||
printf("Starting Test\n");
|
||||
|
||||
// create the ggml context
|
||||
struct ggml_context * ctx;
|
||||
//const int sizex = 4096;
|
||||
//const int sizey = 11008;
|
||||
|
||||
#undef VERBOSE_DEBUGGING
|
||||
#ifndef VERBOSE_DEBUGGING
|
||||
const int sizey = 4096;
|
||||
const int sizex = 11008;
|
||||
const int sizez = 128;
|
||||
#else
|
||||
/* Working - let's increase size */
|
||||
const int sizey = 1;
|
||||
const int sizex = (8*32);
|
||||
const int sizez = 1;
|
||||
|
||||
/*const int sizey = 1;
|
||||
const int sizex = 3*(8*32);
|
||||
const int sizez = 1;*/
|
||||
#endif
|
||||
|
||||
//printf("Memsize required = %i\n", sizex*sizex);
|
||||
|
||||
// TODO: perform the bench for all types or for a user specified type
|
||||
const ggml_type qtype = GGML_TYPE_Q4_1;
|
||||
|
||||
size_t ctx_size = 0;
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizez);
|
||||
ctx_size += ggml_row_size(qtype, sizex*sizey);
|
||||
ctx_size += ggml_row_size(qtype, sizex*sizey);
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
|
||||
ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
|
||||
ctx_size += 1024*1024*16;
|
||||
|
||||
printf("Allocating Memory of size %zi bytes, %zi MB\n",ctx_size, (ctx_size/1024/1024));
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/* no_alloc =*/ 0
|
||||
};
|
||||
|
||||
ctx = ggml_init(params);
|
||||
if (!ctx) {
|
||||
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
printf("Creating new tensors\n");
|
||||
// printf("Creating new tensor m1\n");
|
||||
struct ggml_tensor * m11 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, sizex, sizey);
|
||||
ggml_set_f32(m11, 1.0f);
|
||||
|
||||
// printf("Creating new tensor m1\n");
|
||||
struct ggml_tensor * m12 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, sizex, sizey);
|
||||
ggml_set_f32(m12, 1.5f);
|
||||
|
||||
// printf("Creating new tensor m2\n");
|
||||
struct ggml_tensor * m2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, sizex, sizez);
|
||||
ggml_set_f32(m2, 2.0f);
|
||||
|
||||
printf("\n------ Test 1 - Matrix Mult via F32 code\n");
|
||||
// printf("Creating new tensor m11xm2\n");
|
||||
struct ggml_tensor * m11xm2 = ggml_mul_mat(ctx, m11, m2);
|
||||
|
||||
// printf("Creating compute graph\n");
|
||||
struct ggml_cgraph * gf = ggml_new_graph(ctx);
|
||||
ggml_build_forward_expand(gf, m11xm2);
|
||||
|
||||
printf("n_threads=%i\n", benchmark_params.n_threads);
|
||||
|
||||
TENSOR_DUMP(m11);
|
||||
TENSOR_DUMP(m2);
|
||||
|
||||
std::vector<uint8_t> work_buffer;
|
||||
|
||||
ggml_graph_compute_helper(work_buffer, gf, benchmark_params.n_threads);
|
||||
|
||||
TENSOR_DUMP(ggml_graph_node(gf, 0));
|
||||
|
||||
printf("\n------ Test 2 - Matrix Mult via %s code\n", ggml_type_name(qtype));
|
||||
|
||||
int32_t nelements = sizex*sizey;
|
||||
|
||||
// Set up a the benchmark matrices
|
||||
// printf("Creating new tensor q11 & Running quantize\n");
|
||||
struct ggml_tensor * q11 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey);
|
||||
ggml_quantize_chunk(qtype, (const float *) m11->data, q11->data, 0, nelements/m11->ne[0], m11->ne[0], nullptr);
|
||||
|
||||
// Set up a the compute graph
|
||||
// printf("Creating new tensor q31\n");
|
||||
struct ggml_tensor * q31 = ggml_mul_mat(ctx, q11, m2);
|
||||
|
||||
// printf("Creating compute graph\n");
|
||||
struct ggml_cgraph * gf31 = ggml_new_graph(ctx);
|
||||
ggml_build_forward_expand(gf31, q31);
|
||||
|
||||
// Set up a second graph computation to make sure we override the CPU cache lines
|
||||
// printf("Creating new tensor q12 & Running quantize\n");
|
||||
struct ggml_tensor * q12 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey);
|
||||
ggml_quantize_chunk(qtype, (const float *) m12->data, q12->data, 0, nelements/m12->ne[0], m12->ne[0], nullptr);
|
||||
|
||||
// printf("Creating new tensor q32\n");
|
||||
struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2);
|
||||
|
||||
//printf("Creating compute graph\n");
|
||||
struct ggml_cgraph * gf32 = ggml_new_graph(ctx);
|
||||
ggml_build_forward_expand(gf32, q32);
|
||||
printf("n_threads=%i\n", benchmark_params.n_threads);
|
||||
|
||||
const int dimx = sizex;
|
||||
const int dimy = sizey;
|
||||
const int dimz = sizez;
|
||||
long long int flops_per_dot_product = dimy + dimy;
|
||||
long long int flops_per_matrix = flops_per_dot_product * dimx * dimz; ;
|
||||
printf("Matrix Multiplication of (%i,%i,%i) x (%i,%i,%i) - about %6.2f gFLOPS\n\n", sizex, sizey, 1, sizex, sizez, 1, 1.0f*flops_per_matrix / 1000 / 1000 / 1000);
|
||||
|
||||
|
||||
// Let's use the F32 result from above as a reference for the quantized multiplication
|
||||
float sum_of_F32_reference = tensor_sum_elements(ggml_graph_node(gf, 0));
|
||||
|
||||
printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; gigaFLOPS\n");
|
||||
printf("=====================================================================================\n");
|
||||
|
||||
double gflops_sum = 0;
|
||||
for (int i=0;i<benchmark_params.n_iterations ;i++) {
|
||||
|
||||
long long int start = ggml_time_us();
|
||||
//printf("Running ggml_graph_compute\n");
|
||||
ggml_graph_compute_helper(work_buffer, gf31, benchmark_params.n_threads);
|
||||
|
||||
long long int stop = ggml_time_us();
|
||||
long long int usec = stop-start;
|
||||
double gflops = (double)(flops_per_matrix)/usec/1000.0;
|
||||
gflops_sum += gflops;
|
||||
printf("%9i;%8i;%6i;%6i;%6i;%15lli;%18lli;%10.2f\n",
|
||||
i,
|
||||
benchmark_params.n_threads,
|
||||
sizex, sizey, sizez, flops_per_matrix,
|
||||
usec,gflops);
|
||||
|
||||
#ifdef VERBOSE_DEBUGGING
|
||||
TENSOR_DUMP("res",gf31.nodes[0])
|
||||
#endif
|
||||
|
||||
// Check that the matrix multiplication result is in the right ballpark
|
||||
// We cannot use the exact value from the F32 multiplication because the quantizuation will be slightly different
|
||||
float sum_of_Q4_result = tensor_sum_elements(ggml_graph_node(gf31, 0));
|
||||
float delta = std::abs(sum_of_Q4_result - sum_of_F32_reference);
|
||||
float allowed_delta = (sum_of_F32_reference) / 1000 / 1000; // Let's accept an epsilon of 10^-6
|
||||
|
||||
if (delta > allowed_delta) {
|
||||
printf("\nABORT - ERROR in Matrix Multiplication result - expected %6.2f, got %6.2f (delta %6.2f > allowed_delta %6.2f)\n",
|
||||
sum_of_F32_reference,
|
||||
sum_of_Q4_result,
|
||||
delta,
|
||||
allowed_delta
|
||||
);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// Running a different graph computation to make sure we override the CPU cache lines
|
||||
ggml_graph_compute_helper(work_buffer, gf32, benchmark_params.n_threads);
|
||||
}
|
||||
printf("\n");
|
||||
printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));
|
||||
printf("=====================================================================================\n");
|
||||
}
|
|
@ -22,12 +22,20 @@
|
|||
#endif
|
||||
|
||||
enum split_operation : uint8_t {
|
||||
SPLIT_OP_SPLIT,
|
||||
SPLIT_OP_MERGE,
|
||||
OP_NONE,
|
||||
OP_SPLIT,
|
||||
OP_MERGE,
|
||||
};
|
||||
|
||||
enum split_mode : uint8_t {
|
||||
MODE_NONE,
|
||||
MODE_TENSOR,
|
||||
MODE_SIZE,
|
||||
};
|
||||
|
||||
struct split_params {
|
||||
split_operation operation = SPLIT_OP_SPLIT;
|
||||
split_operation operation = OP_NONE;
|
||||
split_mode mode = MODE_NONE;
|
||||
size_t n_bytes_split = 0;
|
||||
int n_split_tensors = 128;
|
||||
std::string input;
|
||||
|
@ -87,59 +95,52 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
|
|||
}
|
||||
|
||||
bool arg_found = false;
|
||||
bool is_op_set = false;
|
||||
bool is_mode_set = false;
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
split_print_usage(argv[0]);
|
||||
exit(0);
|
||||
}
|
||||
if (arg == "--version") {
|
||||
} else if (arg == "--version") {
|
||||
fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
|
||||
fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
|
||||
exit(0);
|
||||
}
|
||||
if (arg == "--dry-run") {
|
||||
} else if (arg == "--dry-run") {
|
||||
arg_found = true;
|
||||
params.dry_run = true;
|
||||
}
|
||||
if (arg == "--no-tensor-first-split") {
|
||||
} else if (arg == "--no-tensor-first-split") {
|
||||
arg_found = true;
|
||||
params.no_tensor_first_split = true;
|
||||
}
|
||||
|
||||
if (is_op_set) {
|
||||
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
|
||||
}
|
||||
if (arg == "--merge") {
|
||||
} else if (arg == "--merge") {
|
||||
arg_found = true;
|
||||
is_op_set = true;
|
||||
params.operation = SPLIT_OP_MERGE;
|
||||
}
|
||||
if (arg == "--split") {
|
||||
if (params.operation != OP_NONE && params.operation != OP_MERGE) {
|
||||
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
|
||||
}
|
||||
params.operation = OP_MERGE;
|
||||
} else if (arg == "--split") {
|
||||
arg_found = true;
|
||||
is_op_set = true;
|
||||
params.operation = SPLIT_OP_SPLIT;
|
||||
}
|
||||
|
||||
if (is_mode_set) {
|
||||
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
|
||||
}
|
||||
if (arg == "--split-max-tensors") {
|
||||
if (params.operation != OP_NONE && params.operation != OP_SPLIT) {
|
||||
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
|
||||
}
|
||||
params.operation = OP_SPLIT;
|
||||
} else if (arg == "--split-max-tensors") {
|
||||
if (++arg_idx >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
arg_found = true;
|
||||
is_mode_set = true;
|
||||
if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) {
|
||||
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
|
||||
}
|
||||
params.mode = MODE_TENSOR;
|
||||
params.n_split_tensors = atoi(argv[arg_idx]);
|
||||
}
|
||||
if (arg == "--split-max-size") {
|
||||
} else if (arg == "--split-max-size") {
|
||||
if (++arg_idx >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
arg_found = true;
|
||||
is_mode_set = true;
|
||||
if (params.mode != MODE_NONE && params.mode != MODE_SIZE) {
|
||||
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
|
||||
}
|
||||
params.mode = MODE_SIZE;
|
||||
params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
|
||||
}
|
||||
|
||||
|
@ -148,6 +149,15 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
|
|||
}
|
||||
}
|
||||
|
||||
// the operation is split if not specified
|
||||
if (params.operation == OP_NONE) {
|
||||
params.operation = OP_SPLIT;
|
||||
}
|
||||
// the split mode is by tensor if not specified
|
||||
if (params.mode == MODE_NONE) {
|
||||
params.mode = MODE_TENSOR;
|
||||
}
|
||||
|
||||
if (invalid_param) {
|
||||
throw std::invalid_argument("error: invalid parameter for argument: " + arg);
|
||||
}
|
||||
|
@ -265,13 +275,15 @@ struct split_strategy {
|
|||
}
|
||||
|
||||
bool should_split(int i_tensor, size_t next_size) {
|
||||
if (params.n_bytes_split > 0) {
|
||||
if (params.mode == MODE_SIZE) {
|
||||
// split by max size per file
|
||||
return next_size > params.n_bytes_split;
|
||||
} else {
|
||||
} else if (params.mode == MODE_TENSOR) {
|
||||
// split by number of tensors per file
|
||||
return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
|
||||
}
|
||||
// should never happen
|
||||
GGML_ABORT("invalid mode");
|
||||
}
|
||||
|
||||
void print_info() {
|
||||
|
@ -559,9 +571,9 @@ int main(int argc, const char ** argv) {
|
|||
split_params_parse(argc, argv, params);
|
||||
|
||||
switch (params.operation) {
|
||||
case SPLIT_OP_SPLIT: gguf_split(params);
|
||||
case OP_SPLIT: gguf_split(params);
|
||||
break;
|
||||
case SPLIT_OP_MERGE: gguf_merge(params);
|
||||
case OP_MERGE: gguf_merge(params);
|
||||
break;
|
||||
default: split_print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
|
|
|
@ -511,8 +511,8 @@ if (GGML_HIPBLAS)
|
|||
endif()
|
||||
|
||||
if (GGML_SYCL)
|
||||
if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA)$")
|
||||
message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL or NVIDIA")
|
||||
if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
|
||||
message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
|
||||
endif()
|
||||
|
||||
check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
|
||||
|
@ -532,6 +532,9 @@ if (GGML_SYCL)
|
|||
list(APPEND GGML_CDEF_PUBLIC GGML_USE_SYCL)
|
||||
|
||||
if (GGML_SYCL_F16)
|
||||
if (GGML_SYCL_TARGET STREQUAL "AMD")
|
||||
message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
|
||||
endif()
|
||||
add_compile_definitions(GGML_SYCL_F16)
|
||||
endif()
|
||||
|
||||
|
@ -543,6 +546,12 @@ if (GGML_SYCL)
|
|||
|
||||
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
|
||||
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
||||
# INFO: Allowed Sub_group_sizes are not consistent through all
|
||||
# hip targets. For example, 64 is used for certain models, but the backend
|
||||
# does not support it.
|
||||
# Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
|
||||
else()
|
||||
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
|
||||
endif()
|
||||
|
@ -576,6 +585,12 @@ if (GGML_SYCL)
|
|||
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
|
||||
list(APPEND GGML_EXTRA_LIBS_PRIVATE sycl pthread m dl onemkl)
|
||||
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
|
||||
if (GGML_SYCL_HIP_TARGET STREQUAL "")
|
||||
message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_HIP_TARGET has not been set.")
|
||||
endif()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${GGML_SYCL_HIP_TARGET}")
|
||||
list(APPEND GGML_EXTRA_LIBS_PRIVATE sycl pthread m dl onemkl)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
@ -433,16 +433,6 @@ struct vk_context_struct {
|
|||
typedef std::shared_ptr<vk_context_struct> vk_context;
|
||||
typedef std::weak_ptr<vk_context_struct> vk_context_ref;
|
||||
|
||||
struct ggml_tensor_extra_gpu {
|
||||
vk_buffer_ref buffer_gpu;
|
||||
uint64_t offset;
|
||||
|
||||
void reset() {
|
||||
buffer_gpu.reset();
|
||||
offset = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_vk_garbage_collector {
|
||||
std::vector<vk_semaphore> tl_semaphores;
|
||||
std::vector<vk_semaphore> semaphores;
|
||||
|
@ -553,6 +543,31 @@ struct ggml_backend_vk_context {
|
|||
std::vector<vk_context_ref> tensor_ctxs;
|
||||
};
|
||||
|
||||
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
||||
|
||||
static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
|
||||
if (tensor->view_src) {
|
||||
return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
|
||||
}
|
||||
return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
|
||||
}
|
||||
|
||||
struct ggml_backend_vk_buffer_context {
|
||||
vk_device_ref device;
|
||||
vk_buffer dev_buffer;
|
||||
std::string name;
|
||||
|
||||
ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
|
||||
device(device),
|
||||
dev_buffer(dev_buffer),
|
||||
name(name) {
|
||||
}
|
||||
|
||||
~ggml_backend_vk_buffer_context() {
|
||||
ggml_vk_destroy_buffer(dev_buffer);
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
||||
void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
|
||||
std::lock_guard<std::mutex> guard(log_mutex);
|
||||
|
@ -3077,9 +3092,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
const uint64_t r2 = ne12 / ne02;
|
||||
const uint64_t r3 = ne13 / ne03;
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
|
||||
vk_buffer d_Qx;
|
||||
size_t qx_buf_offset = 0;
|
||||
|
@ -3181,8 +3196,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
return;
|
||||
}
|
||||
|
||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03);
|
||||
vk_buffer d_X;
|
||||
|
@ -3190,13 +3205,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
vk_buffer d_Y;
|
||||
uint64_t y_buf_offset = 0;
|
||||
if (!src0_uma) {
|
||||
d_Qx = extra_src0->buffer_gpu.lock();
|
||||
qx_buf_offset = extra_src0->offset + src0->view_offs;
|
||||
d_Qx = src0_buf_ctx->dev_buffer;
|
||||
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
}
|
||||
if (!src1_uma) {
|
||||
d_Qy = extra_src1->buffer_gpu.lock();
|
||||
qy_buf_offset = extra_src1->offset + src1->view_offs;
|
||||
d_Qy = src1_buf_ctx->dev_buffer;
|
||||
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
GGML_ASSERT(d_Qy != nullptr);
|
||||
}
|
||||
if (qx_needs_dequant) {
|
||||
|
@ -3277,9 +3292,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
const uint64_t r2 = ne12 / ne02;
|
||||
const uint64_t r3 = ne13 / ne03;
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
|
||||
vk_buffer d_Qx;
|
||||
size_t qx_buf_offset = 0;
|
||||
|
@ -3358,21 +3373,21 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
return;
|
||||
}
|
||||
|
||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
vk_buffer d_X;
|
||||
uint64_t x_buf_offset = 0;
|
||||
vk_buffer d_Y;
|
||||
uint64_t y_buf_offset = 0;
|
||||
if(!src0_uma) {
|
||||
d_Qx = extra_src0->buffer_gpu.lock();
|
||||
qx_buf_offset = extra_src0->offset + src0->view_offs;
|
||||
d_Qx = src0_buf_ctx->dev_buffer;
|
||||
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
}
|
||||
if(!src1_uma) {
|
||||
d_Qy = extra_src1->buffer_gpu.lock();
|
||||
qy_buf_offset = extra_src1->offset + src1->view_offs;
|
||||
d_Qy = src1_buf_ctx->dev_buffer;
|
||||
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
GGML_ASSERT(d_Qy != nullptr);
|
||||
}
|
||||
if (qx_needs_dequant) {
|
||||
|
@ -3455,9 +3470,9 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
|
|||
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
|
||||
vk_buffer d_Qy;
|
||||
size_t qy_buf_offset = 0;
|
||||
|
@ -3483,15 +3498,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
|
|||
return;
|
||||
}
|
||||
|
||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
vk_buffer d_Qx = extra_src0->buffer_gpu.lock();
|
||||
const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs;
|
||||
vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
|
||||
const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
if (!src1_uma) {
|
||||
d_Qy = extra_src1->buffer_gpu.lock();
|
||||
qy_buf_offset = extra_src1->offset + src1->view_offs;
|
||||
d_Qy = src1_buf_ctx->dev_buffer;
|
||||
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
}
|
||||
|
||||
|
@ -3533,9 +3548,9 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
|
||||
vk_buffer d_Qy = nullptr;
|
||||
size_t qy_buf_offset = 0;
|
||||
|
@ -3562,15 +3577,15 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||
return;
|
||||
}
|
||||
|
||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
vk_buffer d_Qx = extra_src0->buffer_gpu.lock();
|
||||
const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs;
|
||||
vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
|
||||
const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
if (!src1_uma) {
|
||||
d_Qy = extra_src1->buffer_gpu.lock();
|
||||
qy_buf_offset = extra_src1->offset + src1->view_offs;
|
||||
d_Qy = src1_buf_ctx->dev_buffer;
|
||||
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
}
|
||||
|
||||
|
@ -3632,10 +3647,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
|
||||
const uint64_t n_as = ne02;
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
|
||||
|
||||
vk_buffer d_Qx;
|
||||
size_t qx_buf_offset = 0;
|
||||
|
@ -3732,26 +3747,26 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
return;
|
||||
}
|
||||
|
||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
vk_buffer d_X;
|
||||
uint64_t x_buf_offset = 0;
|
||||
vk_buffer d_Y;
|
||||
uint64_t y_buf_offset = 0;
|
||||
if (!src0_uma) {
|
||||
d_Qx = extra_src0->buffer_gpu.lock();
|
||||
qx_buf_offset = extra_src0->offset + src0->view_offs;
|
||||
d_Qx = src0_buf_ctx->dev_buffer;
|
||||
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
}
|
||||
if (!src1_uma) {
|
||||
d_Qy = extra_src1->buffer_gpu.lock();
|
||||
qy_buf_offset = extra_src1->offset + src1->view_offs;
|
||||
d_Qy = src1_buf_ctx->dev_buffer;
|
||||
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
GGML_ASSERT(d_Qy != nullptr);
|
||||
}
|
||||
if (!ids_uma) {
|
||||
d_ids = extra_ids->buffer_gpu.lock();
|
||||
ids_buf_offset = extra_ids->offset + ids->view_offs;
|
||||
d_ids = ids_buf_ctx->dev_buffer;
|
||||
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
|
||||
GGML_ASSERT(d_ids != nullptr);
|
||||
}
|
||||
if (qx_needs_dequant) {
|
||||
|
@ -3837,10 +3852,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
const uint64_t ne22 = dst->ne[2];
|
||||
const uint64_t ne23 = dst->ne[3];
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
|
||||
|
||||
vk_buffer d_Qx;
|
||||
size_t qx_buf_offset = 0;
|
||||
|
@ -3925,26 +3940,26 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
return;
|
||||
}
|
||||
|
||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||
const uint64_t d_buf_offset = extra->offset + dst->view_offs;
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
vk_buffer d_X;
|
||||
uint64_t x_buf_offset = 0;
|
||||
vk_buffer d_Y;
|
||||
uint64_t y_buf_offset = 0;
|
||||
if(!src0_uma) {
|
||||
d_Qx = extra_src0->buffer_gpu.lock();
|
||||
qx_buf_offset = extra_src0->offset + src0->view_offs;
|
||||
d_Qx = src0_buf_ctx->dev_buffer;
|
||||
qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
GGML_ASSERT(d_Qx != nullptr);
|
||||
}
|
||||
if(!src1_uma) {
|
||||
d_Qy = extra_src1->buffer_gpu.lock();
|
||||
qy_buf_offset = extra_src1->offset + src1->view_offs;
|
||||
d_Qy = src1_buf_ctx->dev_buffer;
|
||||
qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
GGML_ASSERT(d_Qy != nullptr);
|
||||
}
|
||||
if(!ids_uma) {
|
||||
d_ids = extra_ids->buffer_gpu.lock();
|
||||
ids_buf_offset = extra_ids->offset + ids->view_offs;
|
||||
d_ids = ids_buf_ctx->dev_buffer;
|
||||
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
|
||||
GGML_ASSERT(d_ids != nullptr);
|
||||
}
|
||||
if (qx_needs_dequant) {
|
||||
|
@ -4251,7 +4266,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
|
||||
GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
|
||||
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
|
||||
GGML_ASSERT(dst->extra != nullptr);
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
const uint64_t ne00 = src0->ne[0];
|
||||
const uint64_t ne01 = src0->ne[1];
|
||||
const uint64_t ne02 = src0->ne[2];
|
||||
|
@ -4297,10 +4312,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
|
||||
const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
|
||||
ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
|
||||
ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
|
||||
|
||||
vk_buffer d_X = nullptr;
|
||||
size_t x_buf_offset = 0;
|
||||
|
@ -4331,7 +4346,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
|
||||
uint64_t d_sz = ggml_type_size(dst->type) * ned;
|
||||
|
||||
vk_buffer d_D = extra->buffer_gpu.lock();
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
|
||||
// Workaround for tiny tensor inputs on ROPE
|
||||
if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
|
||||
|
@ -4339,21 +4354,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
}
|
||||
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
uint64_t d_buf_offset = ((extra->offset + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
|
||||
GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT
|
||||
uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
|
||||
GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY); // NOLINT
|
||||
if(!src0_uma) {
|
||||
d_X = extra_src0->buffer_gpu.lock();
|
||||
x_buf_offset = extra_src0->offset + src0->view_offs;
|
||||
d_X = src0_buf_ctx->dev_buffer;
|
||||
x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
GGML_ASSERT(d_X != nullptr);
|
||||
}
|
||||
if (use_src1 && !src1_uma) {
|
||||
d_Y = extra_src1->buffer_gpu.lock();
|
||||
y_buf_offset = extra_src1->offset + src1->view_offs;
|
||||
d_Y = src1_buf_ctx->dev_buffer;
|
||||
y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
GGML_ASSERT(d_Y != nullptr);
|
||||
}
|
||||
if (use_src2 && !src2_uma) {
|
||||
d_Z = extra_src2->buffer_gpu.lock();
|
||||
z_buf_offset = extra_src2->offset + src2->view_offs;
|
||||
d_Z = src2_buf_ctx->dev_buffer;
|
||||
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
|
||||
GGML_ASSERT(d_Z != nullptr);
|
||||
}
|
||||
|
||||
|
@ -4532,11 +4547,10 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
}
|
||||
|
||||
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
||||
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
|
@ -4725,10 +4739,9 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
}
|
||||
|
||||
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
||||
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
|
@ -5536,14 +5549,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|||
}
|
||||
#endif
|
||||
|
||||
static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
|
||||
VK_LOG_DEBUG("ggml_vk_create_extra(" << tensor << " (" << tensor->name << ", " << ggml_op_name(tensor->op) << "))");
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
|
||||
extra->reset();
|
||||
tensor->extra = extra;
|
||||
return extra;
|
||||
}
|
||||
|
||||
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
||||
#if defined(GGML_VULKAN_RUN_TESTS)
|
||||
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
|
||||
|
@ -5712,9 +5717,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
|
|||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
||||
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
|
||||
|
||||
if (ggml_is_empty(node) || extra == nullptr) {
|
||||
if (ggml_is_empty(node) || !node->buffer) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -5966,7 +5969,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|||
}
|
||||
|
||||
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
|
||||
ggml_tensor_extra_gpu * extra = nullptr;
|
||||
ggml_backend_buffer * buf = nullptr;
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_ADD:
|
||||
|
@ -6002,7 +6005,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_REPEAT:
|
||||
extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
buf = tensor->buffer;
|
||||
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
|
@ -6012,7 +6015,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
buf = tensor->buffer;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
|
@ -6020,14 +6023,14 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
buf = tensor->buffer;
|
||||
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (extra == nullptr) {
|
||||
if (buf == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -6168,42 +6171,6 @@ static void ggml_vk_get_device_description(int device, char * description, size_
|
|||
|
||||
// device backend
|
||||
|
||||
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
|
||||
|
||||
struct ggml_backend_vk_buffer_context {
|
||||
vk_device_ref device;
|
||||
vk_buffer dev_buffer;
|
||||
ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
|
||||
size_t temp_tensor_extra_index = 0;
|
||||
std::string name;
|
||||
|
||||
ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
|
||||
device(device),
|
||||
dev_buffer(dev_buffer),
|
||||
name(name) {
|
||||
}
|
||||
|
||||
~ggml_backend_vk_buffer_context() {
|
||||
ggml_vk_destroy_buffer(dev_buffer);
|
||||
if (temp_tensor_extras != nullptr) {
|
||||
delete[] temp_tensor_extras;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu * ggml_vk_alloc_temp_tensor_extra() {
|
||||
if (temp_tensor_extras == nullptr) {
|
||||
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_VK_MAX_NODES];
|
||||
}
|
||||
|
||||
size_t alloc_index = temp_tensor_extra_index;
|
||||
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_VK_MAX_NODES;
|
||||
ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
|
||||
extra->reset();
|
||||
|
||||
return extra;
|
||||
}
|
||||
};
|
||||
|
||||
static const char * ggml_backend_vk_buffer_get_name(ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
return ctx->name.c_str();
|
||||
|
@ -6228,51 +6195,37 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|||
|
||||
static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
|
||||
ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
|
||||
if (tensor->view_src != nullptr) {
|
||||
GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
|
||||
GGML_ASSERT(tensor->view_src->extra != nullptr);
|
||||
tensor->extra = tensor->view_src->extra;
|
||||
} else {
|
||||
ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra();
|
||||
extra->buffer_gpu = ctx->dev_buffer;
|
||||
extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
|
||||
tensor->extra = extra;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
vk_buffer buf = extra->buffer_gpu.lock();
|
||||
|
||||
ggml_vk_buffer_write(buf, extra->offset + tensor->view_offs + offset, data, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
|
||||
|
||||
vk_buffer buf = extra->buffer_gpu.lock();
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_read(buf, extra->offset + tensor->view_offs + offset, data, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
if (ggml_backend_buffer_is_vk(src->buffer)) {
|
||||
ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
|
||||
vk_buffer src_buf = src_extra->buffer_gpu.lock();
|
||||
vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
|
||||
vk_buffer src_buf = src_buf_ctx->dev_buffer;
|
||||
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_copy(dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
|
||||
ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -6451,7 +6404,7 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
|||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
|
||||
|
@ -6464,9 +6417,9 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
|||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
}
|
||||
|
||||
vk_buffer buf = extra->buffer_gpu.lock();
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
|
||||
ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
|
@ -6474,7 +6427,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
|||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
|
||||
|
@ -6487,17 +6440,17 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
|||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
}
|
||||
|
||||
vk_buffer buf = extra->buffer_gpu.lock();
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size);
|
||||
ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
|
||||
ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
|
||||
|
@ -6510,10 +6463,10 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_
|
|||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
}
|
||||
|
||||
vk_buffer src_buf = src_extra->buffer_gpu.lock();
|
||||
vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
|
||||
vk_buffer src_buf = src_buf_ctx->dev_buffer;
|
||||
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
|
||||
ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -6936,10 +6889,10 @@ static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name)
|
|||
const size_t tensor_size = ggml_nbytes(tensor);
|
||||
tensor_data = malloc(tensor_size);
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_buffer buffer_gpu = extra->buffer_gpu.lock();
|
||||
ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
|
||||
vk_buffer buffer_gpu = buf_ctx->dev_buffer;
|
||||
ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);
|
||||
}
|
||||
|
||||
std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
|
||||
|
@ -7013,9 +6966,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
memcpy(src0_clone->data, src0->data, src0_size);
|
||||
memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
||||
} else if (ggml_backend_buffer_is_vk(src0->buffer)) {
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
vk_buffer buffer_gpu = extra->buffer_gpu.lock();
|
||||
uint64_t offset = extra->offset + src0->view_offs;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
|
||||
uint64_t offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||
if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
|
||||
for (int i3 = 0; i3 < src0->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < src0->ne[2]; i2++) {
|
||||
|
@ -7055,9 +7008,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
memcpy(src1_clone->data, src1->data, src1_size);
|
||||
memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
||||
} else if (ggml_backend_buffer_is_vk(src1->buffer)) {
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
vk_buffer buffer_gpu = extra->buffer_gpu.lock();
|
||||
uint64_t offset = extra->offset + src1->view_offs;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
|
||||
uint64_t offset = vk_tensor_offset(src1) + src1->view_offs;
|
||||
if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
|
||||
for (int i3 = 0; i3 < src1->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < src1->ne[2]; i2++) {
|
||||
|
@ -7097,9 +7050,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
memcpy(src2_clone->data, src2->data, src2_size);
|
||||
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
||||
} else if (ggml_backend_buffer_is_vk(src2->buffer)) {
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
|
||||
vk_buffer buffer_gpu = extra->buffer_gpu.lock();
|
||||
uint64_t offset = extra->offset + src2->view_offs;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context;
|
||||
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
|
||||
uint64_t offset = vk_tensor_offset(src2) + src2->view_offs;
|
||||
if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
|
||||
for (int i3 = 0; i3 < src2->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < src2->ne[2]; i2++) {
|
||||
|
@ -7154,7 +7107,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|||
} else if (tensor->op == GGML_OP_PAD) {
|
||||
tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]);
|
||||
} else if (tensor->op == GGML_OP_REPEAT) {
|
||||
tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone);
|
||||
tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor);
|
||||
} else if (tensor->op == GGML_OP_ADD) {
|
||||
tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
|
||||
} else if (tensor->op == GGML_OP_ACC) {
|
||||
|
@ -7299,14 +7252,15 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
|
|||
size_t tensor_size = ggml_nbytes(tensor);
|
||||
tensor_data = malloc(tensor_size);
|
||||
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_buffer buffer_gpu = extra->buffer_gpu.lock();
|
||||
if (extra->offset + tensor->view_offs + tensor_size >= buffer_gpu->size) {
|
||||
tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs);
|
||||
vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
|
||||
uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;
|
||||
if (offset + tensor_size >= buffer_gpu->size) {
|
||||
tensor_size = buffer_gpu->size - offset;
|
||||
}
|
||||
|
||||
ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
|
||||
ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);
|
||||
}
|
||||
|
||||
float first_error_result = -1.0f;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue