Minor fixes
This commit is contained in:
parent
b6264542b7
commit
5e9403342b
1 changed files with 14 additions and 15 deletions
|
@ -18,8 +18,6 @@
|
|||
#endif
|
||||
|
||||
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
||||
#define STRINGIFY(x) STRINGIFY2(x)
|
||||
#define STRINGIFY2(x) #x
|
||||
|
||||
#define QK4_0 32
|
||||
#define QR4_0 2
|
||||
|
@ -182,7 +180,7 @@ static const std::string program_source_head = R"(
|
|||
|
||||
|
||||
static const std::string program_dequantize_row_q4_0 =
|
||||
program_source_head+'\n'+MULTILINE_QUOTE(
|
||||
MULTILINE_QUOTE(
|
||||
layout(local_size_x = 1, local_size_y = 1) in;
|
||||
layout(binding = 0) buffer tensorBlockQ4_0D { float16_t x_d[]; };
|
||||
layout(binding = 1) buffer tensorBlockQ4_0QS { uint8_t x_qs[]; };
|
||||
|
@ -209,7 +207,7 @@ void ggml_vk_dequantize_row_q4_0(const void *x_, float *y, int k) {
|
|||
static const int qk = QK4_0;
|
||||
const unsigned nb = k / qk;
|
||||
const unsigned y_size = nb*qk;
|
||||
const static auto spirv = compileSource(program_dequantize_row_q4_0);
|
||||
const static auto spirv = compileSource(program_source_head+program_dequantize_row_q4_0);
|
||||
|
||||
const auto x = reinterpret_cast<const block_q4_0*>(x_);
|
||||
|
||||
|
@ -230,7 +228,7 @@ void ggml_vk_dequantize_row_q4_0(const void *x_, float *y, int k) {
|
|||
|
||||
|
||||
static const std::string program_dequantize_row_q4_1 =
|
||||
program_source_head+'\n'+MULTILINE_QUOTE(
|
||||
MULTILINE_QUOTE(
|
||||
layout(local_size_x = 1, local_size_y = 1) in;
|
||||
layout(binding = 0) buffer tensorBlockQ4_0D { float16_t x_d[]; };
|
||||
layout(binding = 1) buffer tensorBlockQ4_0M { float16_t x_m[]; };
|
||||
|
@ -259,7 +257,7 @@ void ggml_vk_dequantize_row_q4_1(const void *x_, float *y, int k) {
|
|||
static const int qk = QK4_1;
|
||||
const unsigned nb = k / qk;
|
||||
const unsigned y_size = nb*qk;
|
||||
const static auto spirv = compileSource(program_dequantize_row_q4_1);
|
||||
const static auto spirv = compileSource(program_source_head+program_dequantize_row_q4_1);
|
||||
|
||||
const auto x = reinterpret_cast<const block_q4_1*>(x_);
|
||||
|
||||
|
@ -281,7 +279,7 @@ void ggml_vk_dequantize_row_q4_1(const void *x_, float *y, int k) {
|
|||
|
||||
|
||||
static const std::string program_abmath =
|
||||
program_source_head+'\n'+MULTILINE_QUOTE(
|
||||
MULTILINE_QUOTE(
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint inAOff;
|
||||
uint inBOff;
|
||||
|
@ -293,24 +291,25 @@ layout(push_constant) uniform PushConstants {
|
|||
layout(local_size_x = 1) in;
|
||||
layout(binding = 0) buffer tensorInA { float inA[]; };
|
||||
layout(binding = 1) buffer tensorInB { float inB[]; };
|
||||
layout(binding = 2) buffer tensorout { float out[]; };
|
||||
layout(binding = 2) buffer tensorOut { float out_[]; };
|
||||
|
||||
|
||||
void main() {
|
||||
const int i = int(gl_GlobalInvocationID.x);
|
||||
|
||||
out[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+(i ROW_OP)];
|
||||
out_[pcs.outOff+i] = inA[pcs.inAOff+i] MATH_OP inB[pcs.inBOff+(i ROW_OP)];
|
||||
}
|
||||
);
|
||||
|
||||
template<char mathOP>
|
||||
void ggml_vk_abmath(const std::shared_ptr<kp::Tensor>& inA, uint32_t inAOff,
|
||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
uint32_t row = 0) {
|
||||
const static auto spirv = compileSource("#define MATH_OP "+std::string(1, mathOP)+"\n"
|
||||
"#define ROW_OP "+(row?"% pcs.row":"")+"\n"
|
||||
+program_abmath);
|
||||
const std::shared_ptr<kp::Tensor>& inB, uint32_t inBOff,
|
||||
const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
|
||||
uint32_t row = 0) {
|
||||
const static auto spirv = compileSource(program_source_head+
|
||||
"#define MATH_OP "+std::string(1, mathOP)+"\n"
|
||||
"#define ROW_OP "+(row?"% pcs.row":"")+'\n'+
|
||||
program_abmath);
|
||||
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff, row;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue