Merge branch 'master' into feature/server-logs-improvment

This commit is contained in:
Pierrick HYMBERT 2024-02-24 20:47:58 +01:00
commit b9bebddcef
17 changed files with 1379 additions and 162 deletions

View file

@ -27,6 +27,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", },
{ "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", },
{ "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
{ "Q3_K_XS",LLAMA_FTYPE_MOSTLY_Q3_K_XS,"3-bit extra small quantization" , }, { "Q3_K_XS",LLAMA_FTYPE_MOSTLY_Q3_K_XS,"3-bit extra small quantization" , },
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", }, { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },

View file

@ -1919,7 +1919,7 @@ struct llama_server_context
send_embedding(slot); send_embedding(slot);
slot.release(); slot.release();
slot.i_batch = -1; slot.i_batch = -1;
return true; continue;
} }
completion_token_output result; completion_token_output result;

View file

@ -1,36 +1,4 @@
# List of ongoing issues # List of ongoing issues
@bug @bug
Feature: Issues Feature: Issues
# Issue #5655 # No confirmed issue at the moment
Scenario: Multi users embeddings
Given a server listening on localhost:8080
And a model file stories260K.gguf
And a model alias tinyllama-2
And 42 as server seed
And 64 KV cache size
And 2 slots
And continuous batching
And embeddings extraction
Then the server is starting
Then the server is healthy
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated

View file

@ -8,6 +8,7 @@ Feature: Parallel
And 42 as server seed And 42 as server seed
And 64 KV cache size And 64 KV cache size
And 2 slots And 2 slots
And embeddings extraction
And continuous batching And continuous batching
Then the server is starting Then the server is starting
Then the server is healthy Then the server is healthy
@ -75,3 +76,48 @@ Feature: Parallel
Then the server is busy Then the server is busy
Then the server is idle Then the server is idle
Then all prompts are predicted Then all prompts are predicted
Scenario: Multi users embeddings
Given a prompt:
"""
Write a very long story about AI.
"""
And a prompt:
"""
Write another very long music lyrics.
"""
And a prompt:
"""
Write a very long poem.
"""
And a prompt:
"""
Write a very long joke.
"""
Given concurrent embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated
Scenario: Multi users OAI compatibility embeddings
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
And a prompt:
"""
What is the biggest US city ?
"""
And a prompt:
"""
What is the capital of Bulgaria ?
"""
And a model tinyllama-2
Given concurrent OAI embedding requests
Then the server is busy
Then the server is idle
Then all embeddings are generated

View file

@ -60,6 +60,19 @@ Feature: llama.cpp server
""" """
Then embeddings are generated Then embeddings are generated
Scenario: OAI Embeddings compatibility with multiple inputs
Given a model tinyllama-2
Given a prompt:
"""
In which country Paris is located ?
"""
And a prompt:
"""
Is Madrid the capital of Spain ?
"""
When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated
Scenario: Tokenize / Detokenize Scenario: Tokenize / Detokenize
When tokenizing: When tokenizing:

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
import collections
import json import json
import os import os
import re import re
@ -261,7 +262,7 @@ def step_a_prompt_prompt(context, prompt):
@step(u'concurrent completion requests') @step(u'concurrent completion requests')
@async_run_until_complete() @async_run_until_complete()
async def step_concurrent_completion_requests(context): async def step_concurrent_completion_requests(context):
await concurrent_completion_requests(context, await concurrent_requests(context,
request_completion, request_completion,
# prompt is inserted automatically # prompt is inserted automatically
context.base_url, context.base_url,
@ -275,7 +276,7 @@ async def step_concurrent_completion_requests(context):
@step(u'concurrent OAI completions requests') @step(u'concurrent OAI completions requests')
@async_run_until_complete @async_run_until_complete
async def step_oai_chat_completions(context): async def step_oai_chat_completions(context):
await concurrent_completion_requests(context, oai_chat_completions, await concurrent_requests(context, oai_chat_completions,
# user_prompt is inserted automatically # user_prompt is inserted automatically
context.system_prompt, context.system_prompt,
context.base_url, context.base_url,
@ -316,36 +317,58 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
@step(u'embeddings are computed for') @step(u'embeddings are computed for')
@async_run_until_complete @async_run_until_complete
async def step_compute_embedding(context): async def step_compute_embedding(context):
content = context.text context.embeddings = await request_embedding(context.text, base_url=context.base_url)
base_url = context.base_url
context.embeddings = await request_embedding(content, base_url)
@step(u'embeddings are generated') @step(u'embeddings are generated')
def step_assert_embeddings(context): def step_assert_embeddings(context):
if len(context.prompts) == 0:
assert_embeddings(context.embeddings) assert_embeddings(context.embeddings)
else:
assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
f"context.prompts={context.prompts}\n"
f"context.embeddings={context.embeddings}")
for embedding in context.embeddings:
context.prompts.pop()
assert_embeddings(embedding)
@step(u'an OAI compatible embeddings computation request for') @step(u'an OAI compatible embeddings computation request for')
def step_oai_compute_embedding(context): @async_run_until_complete
openai.api_key = 'nope' # openai client always expects an api_keu async def step_oai_compute_embeddings(context):
if context.user_api_key is not None: context.embeddings = await request_oai_embeddings(context.text,
openai.api_key = context.user_api_key base_url=context.base_url,
openai.api_base = f'{context.base_url}/v1' user_api_key=context.user_api_key,
embeddings = openai.Embedding.create( model=context.model)
model=context.model,
input=context.text,
) @step(u'an OAI compatible embeddings computation request for multiple inputs')
context.embeddings = embeddings @async_run_until_complete
async def step_oai_compute_embeddings_multiple_inputs(context):
context.embeddings = await request_oai_embeddings(context.prompts,
base_url=context.base_url,
user_api_key=context.user_api_key,
model=context.model)
@step(u'concurrent embedding requests') @step(u'concurrent embedding requests')
@async_run_until_complete() @async_run_until_complete()
async def step_concurrent_embedding_requests(context): async def step_concurrent_embedding_requests(context):
await concurrent_completion_requests(context, await concurrent_requests(context,
request_embedding, request_embedding,
# prompt is inserted automatically # prompt is inserted automatically
context.base_url) base_url=context.base_url)
@step(u'concurrent OAI embedding requests')
@async_run_until_complete()
async def step_concurrent_oai_embedding_requests(context):
await concurrent_requests(context,
request_oai_embeddings,
# prompt is inserted automatically
base_url=context.base_url,
async_client=True,
model=context.model)
@step(u'all embeddings are generated') @step(u'all embeddings are generated')
@ -401,7 +424,7 @@ def step_check_options_header_value(context, cors_header, cors_header_value):
assert context.options_response.headers[cors_header] == cors_header_value assert context.options_response.headers[cors_header] == cors_header_value
async def concurrent_completion_requests(context, f_completion, *args, **kwargs): async def concurrent_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts) n_prompts = len(context.prompts)
if context.debug: if context.debug:
print(f"starting {n_prompts} concurrent completion requests...") print(f"starting {n_prompts} concurrent completion requests...")
@ -565,7 +588,7 @@ async def oai_chat_completions(user_prompt,
return completion_response return completion_response
async def request_embedding(content, base_url): async def request_embedding(content, base_url=None):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/embedding', async with session.post(f'{base_url}/embedding',
json={ json={
@ -576,6 +599,46 @@ async def request_embedding(content, base_url):
return response_json['embedding'] return response_json['embedding']
async def request_oai_embeddings(input,
base_url=None, user_api_key=None,
model=None, async_client=False):
# openai client always expects an api_key
user_api_key = user_api_key if user_api_key is not None else 'nope'
if async_client:
origin = 'llama.cpp'
if user_api_key is not None:
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/v1/embeddings',
json={
"input": input,
"model": model,
},
headers=headers) as response:
assert response.status == 200, f"received status code not expected: {response.status}"
assert response.headers['Access-Control-Allow-Origin'] == origin
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
response_json = await response.json()
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
assert response_json['object'] == 'list'
return response_json['data']
else:
openai.api_key = user_api_key
openai.api_base = f'{base_url}/v1'
oai_embeddings = openai.Embedding.create(
model=model,
input=input,
)
if isinstance(input, collections.abc.Sequence):
embeddings = []
for an_oai_embeddings in oai_embeddings.data:
embeddings.append(an_oai_embeddings.embedding)
else:
embeddings = oai_embeddings.data.embedding
return embeddings
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None): def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
content = completion_response['content'] content = completion_response['content']
n_predicted = completion_response['timings']['predicted_n'] n_predicted = completion_response['timings']['predicted_n']

View file

@ -172,6 +172,7 @@
#endif #endif
typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsubss4(const int a, const int b) { static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a); const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b); const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
@ -196,6 +197,18 @@ static __device__ __forceinline__ int __vsub4(const int a, const int b) {
return __vsubss4(a, b); return __vsubss4(a, b);
} }
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
unsigned int c;
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
for (int i = 0; i < 4; ++i) {
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
}
return c;
}
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
c = __builtin_amdgcn_sdot4(a, b, c, false); c = __builtin_amdgcn_sdot4(a, b, c, false);
@ -518,6 +531,17 @@ typedef struct {
} block_iq3_xxs; } block_iq3_xxs;
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
#define QR3_XS 8
#define QI3_XS (QK_K / (4*QR3_XS))
typedef struct {
half d;
uint8_t qs[QK_K/4];
uint8_t qh[QK_K/32];
uint8_t signs[QK_K/8];
uint8_t scales[QK_K/64];
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_s block size/padding");
#define QR1_S 8 #define QR1_S 8
#define QI1_S (QK_K / (4*QR1_S)) #define QI1_S (QK_K / (4*QR1_S))
typedef struct { typedef struct {
@ -1700,6 +1724,74 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
}; };
static const __device__ uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
};
static const __device__ uint64_t iq1s_grid[512] = { static const __device__ uint64_t iq1s_grid[512] = {
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
@ -1973,6 +2065,32 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
} }
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq3_s * x = (const block_iq3_s *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
const uint8_t signs = x[i].signs[4*ib + il];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false);
#endif
}
template<typename dst_t> template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@ -4717,6 +4835,41 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#endif #endif
} }
// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
const int ib32 = iqs;
const uint8_t * qs = bq2->qs + 8*ib32;
const int8_t * q8 = bq8_1[ib32].qs;
int sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
q8 += 8;
}
const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
return d * sumi;
#else
assert(false);
return 0.f;
#endif
#else
assert(false);
return 0.f;
#endif
}
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256 #if QK_K == 256
@ -6849,6 +7002,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y); dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
} }
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t> template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K; const int nb = k / QK_K;
@ -6904,6 +7063,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq1_s_cuda; return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda; return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F32: case GGML_TYPE_F32:
return convert_unary_cuda<float>; return convert_unary_cuda<float>;
default: default:
@ -6943,6 +7104,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq1_s_cuda; return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda; return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16: case GGML_TYPE_F16:
return convert_unary_cuda<half>; return convert_unary_cuda<half>;
default: default:
@ -8688,6 +8851,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_RDNA2 ? 128 : 64; return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
@ -8713,6 +8877,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
return max_compute_capability >= CC_VOLTA ? 128 : 64; return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
return 64; return 64;
@ -8818,6 +8983,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1> mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break; break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
break; break;
@ -11541,7 +11710,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
} }
ggml_type a_type = a->type; ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) { a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) { if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false; return false;
} }

View file

@ -61,6 +61,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@ -85,6 +86,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@ -105,6 +107,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@ -122,6 +125,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@ -139,6 +143,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
GGML_METAL_KERNEL_TYPE_ROPE_F32, GGML_METAL_KERNEL_TYPE_ROPE_F32,
@ -452,6 +457,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@ -476,6 +482,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@ -496,6 +503,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@ -513,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@ -530,6 +539,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@ -1347,6 +1357,7 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@ -1483,6 +1494,12 @@ static bool ggml_metal_graph_compute(
nth1 = 16; nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
} break; } break;
case GGML_TYPE_IQ3_S:
{
nth0 = 4;
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
} break;
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
{ {
nth0 = 4; nth0 = 4;
@ -1537,8 +1554,8 @@ static bool ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_IQ3_XXS) { else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
const int mem_size = 256*4+128; const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
@ -1640,6 +1657,7 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@ -1779,6 +1797,12 @@ static bool ggml_metal_graph_compute(
nth1 = 16; nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
} break; } break;
case GGML_TYPE_IQ3_S:
{
nth0 = 4;
nth1 = 16;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
} break;
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
{ {
nth0 = 4; nth0 = 4;
@ -1849,8 +1873,8 @@ static bool ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src2t == GGML_TYPE_IQ3_XXS) { else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
const int mem_size = 256*4+128; const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
@ -1900,6 +1924,7 @@ static bool ggml_metal_graph_compute(
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;

View file

@ -2525,6 +2525,20 @@ typedef struct {
} block_iq3_xxs; } block_iq3_xxs;
// 98 bytes / block for QK_K = 256, so 3.0625 bpw // 98 bytes / block for QK_K = 256, so 3.0625 bpw
// 3.4375 bpw
#if QK_K == 64
#define IQ3S_N_SCALE 2
#else
#define IQ3S_N_SCALE QK_K/64
#endif
typedef struct {
half d;
uint8_t qs[QK_K/4];
uint8_t qh[QK_K/32];
uint8_t signs[QK_K/8];
uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;
typedef struct { typedef struct {
half d; half d;
uint8_t qs[QK_K/8]; uint8_t qs[QK_K/8];
@ -3795,6 +3809,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
}; };
constexpr constant static uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
};
#define NGRID_IQ1S 512 #define NGRID_IQ1S 512
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = { constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@ -4361,6 +4442,136 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
} }
void kernel_mul_mv_iq3_s_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne10,
constant int64_t & ne12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
const int ix = tiisg;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
const int ibl = ib32 / (QK_K / 32);
const int ib = ib32 % (QK_K / 32);
device const block_iq3_s * xr = x + ibl;
device const uint8_t * qs = xr->qs + 8 * ib;
device const uint8_t * qh = xr->qh + ib;
device const uint8_t * sc = xr->scales + (ib/2);
device const uint8_t * signs = xr->signs + 4 * ib;
device const half * dh = &xr->d;
for (int row = 0; row < N_DST; row++) {
const float db = dh[0];
const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
}
}
sumf[row] += d * (sum[0] + sum[1]);
dh += nb*sizeof(block_iq3_s)/2;
qs += nb*sizeof(block_iq3_s);
qh += nb*sizeof(block_iq3_s);
sc += nb*sizeof(block_iq3_s);
signs += nb*sizeof(block_iq3_s);
}
y4 += 32 * 32;
}
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
}
}
}
[[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_iq1_s_f32_impl( void kernel_mul_mv_iq1_s_f32_impl(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
@ -4952,6 +5163,31 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
} }
} }
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
}
template <typename type4x4> template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@ -5525,6 +5761,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
@ -5566,6 +5803,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
@ -5619,6 +5857,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>; template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
@ -6589,6 +6828,71 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
sgitg); sgitg);
} }
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
kernel void kernel_mul_mv_id_iq3_s_f32(
device const char * ids,
device const char * src1,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
device const char * src00,
device const char * src01,
device const char * src02,
device const char * src03,
device const char * src04,
device const char * src05,
device const char * src06,
device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
kernel_mul_mv_iq3_s_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
ne01,
ne02,
ne10,
ne12,
ne0,
ne1,
r2,
r3,
shared_values,
tgpig,
tiisg,
sgitg);
}
[[host_name("kernel_mul_mv_id_iq1_s_f32")]] [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
kernel void kernel_mul_mv_id_iq1_s_f32( kernel void kernel_mul_mv_id_iq1_s_f32(
device const char * ids, device const char * ids,

View file

@ -3505,6 +3505,73 @@ static const uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
}; };
static const uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
};
#define NGRID_IQ2XXS 512 #define NGRID_IQ2XXS 512
static const uint64_t iq1s_grid[NGRID_IQ2XXS] = { static const uint64_t iq1s_grid[NGRID_IQ2XXS] = {
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
@ -3736,6 +3803,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
} }
} }
// ====================== 3.3125 bpw (de)-quantization
void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * qs = x[i].qs;
const uint8_t * qh = x[i].qh;
const uint8_t * signs = x[i].signs;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f;
const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
y += 8;
}
qs += 8;
signs += 4;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
y += 8;
}
qh += 2;
qs += 8;
signs += 4;
}
}
}
// ====================== 1.5625 bpw (de)-quantization // ====================== 1.5625 bpw (de)-quantization
void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) { void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
@ -8806,6 +8916,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
#endif #endif
#if defined (__AVX2__) || defined (__ARM_NEON)
static const int8_t keven_signs_q2xs[1024] = { static const int8_t keven_signs_q2xs[1024] = {
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
@ -8840,6 +8951,7 @@ static const int8_t keven_signs_q2xs[1024] = {
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
}; };
#endif
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
@ -9327,6 +9439,202 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
#endif #endif
} }
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq3_s * restrict x = vx;
const block_q8_K * restrict y = vy;
const int nb = n / QK_K;
#if defined(__ARM_NEON)
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
};
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
const uint8x16_t mask2 = vld1q_u8(k_mask2);
uint8x16x2_t vs;
ggml_int8x16x4_t q3s;
ggml_int8x16x4_t q8b;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict qs = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
const int8_t * restrict q8 = y[i].qs;
int sumi1 = 0, sumi2 = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)],
iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]};
const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)],
iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]};
const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)],
iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]};
const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)],
iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]};
qs += 16;
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);
q3s.val[0] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_0))), vreinterpretq_s8_u8(vs.val[0]));
q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1]));
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16)));
vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
vs.val[0] = vceqq_u8(vs.val[0], mask2);
vs.val[1] = vceqq_u8(vs.val[1], mask2);
signs += 4;
q3s.val[2] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[0], vreinterpretq_u8_u32(aux32x4_2))), vreinterpretq_s8_u8(vs.val[0]));
q3s.val[3] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_3))), vreinterpretq_s8_u8(vs.val[1]));
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
}
sumf += d*(sumi1 + sumi2);
}
*s = 0.25f * sumf;
#elif defined(__AVX2__)
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
};
static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
};
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict qs = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
const int8_t * restrict q8 = y[i].qs;
__m256i sumi1 = _mm256_setzero_si256();
__m256i sumi2 = _mm256_setzero_si256();
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
const __m256i q2_1 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+0] << 1) & 256)],
iq3xs_grid[qs[6] | ((qh[ib32+0] << 2) & 256)],
iq3xs_grid[qs[5] | ((qh[ib32+0] << 3) & 256)],
iq3xs_grid[qs[4] | ((qh[ib32+0] << 4) & 256)],
iq3xs_grid[qs[3] | ((qh[ib32+0] << 5) & 256)],
iq3xs_grid[qs[2] | ((qh[ib32+0] << 6) & 256)],
iq3xs_grid[qs[1] | ((qh[ib32+0] << 7) & 256)],
iq3xs_grid[qs[0] | ((qh[ib32+0] << 8) & 256)]);
qs += 8;
const __m256i q2_2 = _mm256_set_epi32(iq3xs_grid[qs[7] | ((qh[ib32+1] << 1) & 256)],
iq3xs_grid[qs[6] | ((qh[ib32+1] << 2) & 256)],
iq3xs_grid[qs[5] | ((qh[ib32+1] << 3) & 256)],
iq3xs_grid[qs[4] | ((qh[ib32+1] << 4) & 256)],
iq3xs_grid[qs[3] | ((qh[ib32+1] << 5) & 256)],
iq3xs_grid[qs[2] | ((qh[ib32+1] << 6) & 256)],
iq3xs_grid[qs[1] | ((qh[ib32+1] << 7) & 256)],
iq3xs_grid[qs[0] | ((qh[ib32+1] << 8) & 256)]);
qs += 8;
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
signs += 4;
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
sumi1 = _mm256_add_epi32(sumi1, p1);
sumi2 = _mm256_add_epi32(sumi2, p2);
}
accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
}
*s = 0.25f * hsum_float_8(accumf);
#else
float sumf = 0.f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict qs = x[i].qs;
const uint8_t * restrict qh = x[i].qh;
const uint8_t * restrict signs = x[i].signs;
const int8_t * restrict q8 = y[i].qs;
int32_t bsum = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
int32_t sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
}
q8 += 8;
}
qs += 8;
signs += 4;
bsum += sumi * ls1;
sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
}
q8 += 8;
}
qs += 8;
signs += 4;
bsum += sumi * ls2;
}
sumf += d * bsum;
}
*s = 0.25f * sumf;
#endif
}
#ifdef __AVX2__ #ifdef __AVX2__
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x); const __m256i ax = _mm256_sign_epi8(x, x);
@ -9523,6 +9831,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
float sumf = 0; float sumf = 0;
for (int ib = 0; ib < nb; ib += 2) { for (int ib = 0; ib < nb; ib += 2) {
q4bits.val[0] = vld1q_u8(x[ib+0].qs); q4bits.val[0] = vld1q_u8(x[ib+0].qs);
q4bits.val[1] = vld1q_u8(x[ib+1].qs); q4bits.val[1] = vld1q_u8(x[ib+1].qs);
q8b.val[0] = vld1q_s8(y[ib+0].qs); q8b.val[0] = vld1q_s8(y[ib+0].qs);
@ -10239,14 +10548,15 @@ typedef struct {
uint16_t * neighbours; uint16_t * neighbours;
} iq3_entry_t; } iq3_entry_t;
static iq3_entry_t iq3_data[1] = { static iq3_entry_t iq3_data[2] = {
{NULL, NULL, NULL},
{NULL, NULL, NULL}, {NULL, NULL, NULL},
}; };
static inline int iq3_data_index(int grid_size) { static inline int iq3_data_index(int grid_size) {
(void)grid_size; (void)grid_size;
GGML_ASSERT(grid_size == 256); GGML_ASSERT(grid_size == 256 || grid_size == 512);
return 0; return grid_size == 256 ? 0 : 1;
} }
static int iq3_compare_func(const void * left, const void * right) { static int iq3_compare_func(const void * left, const void * right) {
@ -10278,9 +10588,44 @@ void iq3xs_init_impl(int grid_size) {
3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610, 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992, 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
}; };
static const uint16_t kgrid_512[512] = {
0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34,
37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77,
80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142,
145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210,
217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288,
291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393,
395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514,
516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576,
577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653,
655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727,
728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833,
840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977,
989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047,
1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103,
1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199,
1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296,
1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415,
1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561,
1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648,
1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761,
1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877,
1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068,
2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177,
2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269,
2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520,
2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634,
2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805,
2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083,
3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276,
3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591,
3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729,
3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032,
};
const int kmap_size = 4096; const int kmap_size = 4096;
const int nwant = 2; const int nwant = grid_size == 256 ? 2 : 3;
const uint16_t * kgrid = kgrid_256; const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
uint32_t * kgrid_q3xs; uint32_t * kgrid_q3xs;
int * kmap_q3xs; int * kmap_q3xs;
uint16_t * kneighbors_q3xs; uint16_t * kneighbors_q3xs;
@ -10377,7 +10722,7 @@ void iq3xs_init_impl(int grid_size) {
} }
void iq3xs_free_impl(int grid_size) { void iq3xs_free_impl(int grid_size) {
GGML_ASSERT(grid_size == 256); GGML_ASSERT(grid_size == 256 || grid_size == 512);
const int gindex = iq3_data_index(grid_size); const int gindex = iq3_data_index(grid_size);
if (iq3_data[gindex].grid) { if (iq3_data[gindex].grid) {
free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL; free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
@ -10410,9 +10755,10 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u
return grid_index; return grid_index;
} }
static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n,
const float * restrict quant_weights) {
const int gindex = iq3_data_index(256); const int gindex = iq3_data_index(grid_size);
const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
const int * kmap_q3xs = iq3_data[gindex].map; const int * kmap_q3xs = iq3_data[gindex].map;
@ -10426,9 +10772,23 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
const int kMaxQ = 8; const int kMaxQ = 8;
const int nbl = n/256; const int nbl = n/QK_K;
ggml_fp16_t * dh;
uint8_t * qs;
int block_size;
if (grid_size == 256) {
block_iq3_xxs * y = vy; block_iq3_xxs * y = vy;
dh = &y->d;
qs = y->qs;
block_size = sizeof(block_iq3_xxs);
} else {
block_iq3_s * y = vy;
dh = &y->d;
qs = y->qs;
block_size = sizeof(block_iq3_s);
}
int quant_size = block_size - sizeof(ggml_fp16_t);
float scales[QK_K/32]; float scales[QK_K/32];
float weight[32]; float weight[32];
@ -10439,20 +10799,21 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
bool is_on_grid[8]; bool is_on_grid[8];
bool is_on_grid_aux[8]; bool is_on_grid_aux[8];
uint8_t block_signs[8]; uint8_t block_signs[8];
uint8_t q3[3*(QK_K/8)]; uint8_t q3[3*(QK_K/8)+QK_K/32];
uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4); uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
uint8_t * qh = q3 + 3*(QK_K/8);
for (int ibl = 0; ibl < nbl; ++ibl) { for (int ibl = 0; ibl < nbl; ++ibl) {
y[ibl].d = GGML_FP32_TO_FP16(0.f); dh[0] = GGML_FP32_TO_FP16(0.f);
memset(q3, 0, 3*QK_K/8); memset(q3, 0, 3*QK_K/8+QK_K/32);
float max_scale = 0; float max_scale = 0;
const float * xbl = x + QK_K*ibl; const float * xbl = x + QK_K*ibl;
float sumx2 = 0; float sumx2 = 0;
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
float sigma2 = sumx2/QK_K; float sigma2 = 2*sumx2/QK_K;
for (int ib = 0; ib < QK_K/32; ++ib) { for (int ib = 0; ib < QK_K/32; ++ib) {
const float * xb = xbl + 32*ib; const float * xb = xbl + 32*ib;
@ -10570,7 +10931,13 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
printf("\n"); printf("\n");
GGML_ASSERT(false); GGML_ASSERT(false);
} }
if (grid_size == 256) {
q3[8*ib+k] = grid_index; q3[8*ib+k] = grid_index;
} else {
q3[8*ib+k] = grid_index & 255;
qh[ib] |= ((grid_index >> 8) << k);
}
} }
scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21); scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
GGML_ASSERT(scale >= 0); GGML_ASSERT(scale >= 0);
@ -10579,63 +10946,25 @@ static void quantize_row_iq3_xxs_impl(const float * restrict x, void * restrict
} }
if (!max_scale) { if (!max_scale) {
memset(y[ibl].qs, 0, 3*QK_K/8); memset(qs, 0, quant_size);
dh += block_size/sizeof(ggml_fp16_t);
qs += block_size;
continue; continue;
} }
float d = max_scale/31; float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d); dh[0] = GGML_FP32_TO_FP16(d * 1.0125f); // small improvement via this fudge factor
float id = 1/d; float id = 1/d;
float sumqx = 0, sumq2 = 0;
for (int ib = 0; ib < QK_K/32; ++ib) { for (int ib = 0; ib < QK_K/32; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]-1)); int l = nearest_int(0.5f*(id*scales[ib]-1));
l = MAX(0, MIN(15, l)); l = MAX(0, MIN(15, l));
scales_and_signs[ib] |= ((uint32_t)l << 28); scales_and_signs[ib] |= ((uint32_t)l << 28);
if (false) {
const float * xb = xbl + 32*ib;
if (quant_weights) {
const float * qw = quant_weights + QK_K*ibl + 32*ib;
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
} else {
for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
} }
const float db = 0.25f * d * (1 + 2*l); memcpy(qs, q3, quant_size);
for (int k = 0; k < 8; ++k) {
const int8_t * signs = keven_signs_q2xs + 8*((scales_and_signs[ib] >> 7*(k/2)) & 127) + 4*(k%2); dh += block_size/sizeof(ggml_fp16_t);
const float * xk = xb + 4*k; qs += block_size;
const float * wk = weight + 4*k;
//const uint8_t * grid = (const uint8_t *)(kgrid_q3xs + q3[8*ib+k]);
const uint8_t * grid = (const uint8_t *)(iq3xxs_grid + q3[8*ib+k]);
float best_mse = 0; int best_index = q3[8*ib+k];
for (int j = 0; j < 4; ++j) {
float diff = db * grid[j] * signs[j] - xk[j];
best_mse += wk[j] * diff * diff;
}
for (int idx = 0; idx < 256; ++idx) {
//grid = (const uint8_t *)(kgrid_q3xs + idx);
grid = (const uint8_t *)(iq3xxs_grid + idx);
float mse = 0;
for (int j = 0; j < 4; ++j) {
float diff = db * grid[j] * signs[j] - xk[j];
mse += wk[j] * diff * diff;
}
if (mse < best_mse) {
best_mse = mse; best_index = idx;
}
}
q3[8*ib+k] = best_index;
//grid = (const uint8_t *)(kgrid_q3xs + best_index);
grid = (const uint8_t *)(iq3xxs_grid + best_index);
for (int j = 0; j < 4; ++j) {
float q = db * grid[j] * signs[j];
sumqx += wk[j] * q * xk[j];
sumq2 += wk[j] * q * q;
}
}
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
}
}
memcpy(y[ibl].qs, q3, 3*QK_K/8);
} }
} }
@ -10645,7 +10974,7 @@ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row,
int nblock = n_per_row/QK_K; int nblock = n_per_row/QK_K;
char * qrow = (char *)dst; char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) { for (int row = 0; row < nrow; ++row) {
quantize_row_iq3_xxs_impl(src, qrow, n_per_row, quant_weights); quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
src += n_per_row; src += n_per_row;
qrow += nblock*sizeof(block_iq3_xxs); qrow += nblock*sizeof(block_iq3_xxs);
} }
@ -10660,9 +10989,226 @@ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) {
void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) { void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
quantize_row_iq3_xxs_impl(x, y, k, NULL); quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
} }
static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n,
const float * restrict quant_weights,
float * scales,
float * weight,
float * xval,
int8_t * L,
int8_t * Laux,
float * waux,
bool * is_on_grid,
bool * is_on_grid_aux,
uint8_t * block_signs) {
const int gindex = iq3_data_index(512);
const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
const int * kmap_q3xs = iq3_data[gindex].map;
const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
//GGML_ASSERT(quant_weights && "missing quantization weights");
GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
GGML_ASSERT(n%QK_K == 0);
const int kMaxQ = 8;
const int nbl = n/QK_K;
block_iq3_s * y = vy;
const int bs4 = block_size/4;
const int bs8 = block_size/8;
for (int ibl = 0; ibl < nbl; ++ibl) {
memset(&y[ibl], 0, sizeof(block_iq3_s));
y[ibl].d = GGML_FP32_TO_FP16(0.f);
uint8_t * qs = y[ibl].qs;
uint8_t * qh = y[ibl].qh;
uint8_t * signs = y[ibl].signs;
float max_scale = 0;
const float * xbl = x + QK_K*ibl;
float sumx2 = 0;
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
float sigma2 = 2*sumx2/QK_K;
for (int ib = 0; ib < QK_K/block_size; ++ib) {
const float * xb = xbl + block_size*ib;
if (quant_weights) {
const float * qw = quant_weights + QK_K*ibl + block_size*ib;
for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
} else {
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
}
for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);
for (int k = 0; k < bs8; ++k) {
uint8_t s = 0;
for (int i = 0; i < 8; ++i) {
if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
else {
xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
}
}
block_signs[k] = s;
}
float max = xval[0];
for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
if (!max) {
scales[ib] = 0;
continue;
}
float best = 0;
float scale = max/(2*kMaxQ-1);
for (int is = -15; is <= 15; ++is) {
float id = (2*kMaxQ-1+is*0.2f)/max;
float this_scale = 1/id;
for (int k = 0; k < bs4; ++k) {
for (int i = 0; i < 4; ++i) {
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
}
uint16_t u = 0;
for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
int grid_index = kmap_q3xs[u];
is_on_grid_aux[k] = true;
if (grid_index < 0) {
is_on_grid_aux[k] = false;
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
}
}
float sumqx = 0, sumq2 = 0;
for (int i = 0; i < block_size; ++i) {
float w = weight[i];
float q = 2*Laux[i] + 1;
sumqx += w*xval[i]*q;
sumq2 += w*q*q;
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
scale = sumqx/sumq2; best = scale*sumqx;
for (int i = 0; i < block_size; ++i) L[i] = Laux[i];
for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];
}
}
int n_not_ongrid = 0;
for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
if (n_not_ongrid > 0 && scale > 0) {
float id = 1/scale;
for (int k = 0; k < bs4; ++k) {
if (is_on_grid[k]) continue;
uint16_t u = 0;
for (int i = 0; i < 4; ++i) {
int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
l = MAX(0, MIN(kMaxQ-1, l));
u |= (l << 3*i);
}
int grid_index = kmap_q3xs[u];
if (grid_index < 0) {
const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
}
const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
}
float sumqx = 0, sumq2 = 0;
for (int i = 0; i < block_size; ++i) {
float w = weight[i];
float q = 2*L[i] + 1;
sumqx += w*xval[i]*q;
sumq2 += w*q*q;
}
if (sumq2 > 0) scale = sumqx/sumq2;
}
if (scale < 0) {
// This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
// and correspondingly flip quant signs.
scale = -scale;
for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];
}
for (int k = 0; k < bs4; ++k) {
uint16_t u = 0;
for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
int grid_index = kmap_q3xs[u];
if (grid_index < 0) {
printf("Oops: found point %u not on grid:", u);
for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
printf("\n");
GGML_ASSERT(false);
}
qs[k] = grid_index & 255;
qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
}
qs += bs4;
for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];
signs += bs8;
GGML_ASSERT(scale >= 0);
scales[ib] = scale;
max_scale = MAX(max_scale, scale);
}
if (!max_scale) {
continue;
}
float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d);
float id = 1/d;
for (int ib = 0; ib < QK_K/block_size; ib += 2) {
int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
l1 = MAX(0, MIN(15, l1));
int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
l2 = MAX(0, MIN(15, l2));
y[ibl].scales[ib/2] = l1 | (l2 << 4);
}
}
}
#define IQ3S_BLOCK_SIZE 32
size_t quantize_iq3_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
GGML_ASSERT(n_per_row%QK_K == 0);
int nblock = n_per_row/QK_K;
float scales[QK_K/IQ3S_BLOCK_SIZE];
float weight[IQ3S_BLOCK_SIZE];
float xval[IQ3S_BLOCK_SIZE];
int8_t L[IQ3S_BLOCK_SIZE];
int8_t Laux[IQ3S_BLOCK_SIZE];
float waux[IQ3S_BLOCK_SIZE];
bool is_on_grid[IQ3S_BLOCK_SIZE/4];
bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
src += n_per_row;
qrow += nblock*sizeof(block_iq3_s);
}
return nrow * nblock * sizeof(block_iq3_s);
}
void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_iq3_s * restrict y = vy;
quantize_row_iq3_s_reference(x, y, k);
}
void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) {
assert(k % QK_K == 0);
quantize_iq3_s(x, y, 1, k, NULL, NULL);
}
// =================================== 1.5 bpw =================================================== // =================================== 1.5 bpw ===================================================
static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid, static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,

View file

@ -191,6 +191,21 @@ typedef struct {
} block_iq3_xxs; } block_iq3_xxs;
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
// 3.4375 bpw
#if QK_K == 64
#define IQ3S_N_SCALE 2
#else
#define IQ3S_N_SCALE QK_K/64
#endif
typedef struct {
ggml_fp16_t d;
uint8_t qs[QK_K/4];
uint8_t qh[QK_K/32];
uint8_t signs[QK_K/8];
uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
typedef struct { typedef struct {
ggml_fp16_t d; ggml_fp16_t d;
uint8_t qs[QK_K/8]; uint8_t qs[QK_K/8];
@ -226,6 +241,7 @@ void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGM
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k); void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k);
void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k);
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k); void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k);
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
@ -242,6 +258,7 @@ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
// Dequantization // Dequantization
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
@ -262,6 +279,7 @@ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_
void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
// Dot product // Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@ -280,6 +298,7 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// //
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
@ -289,6 +308,7 @@ size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row,
size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_iq3_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);

31
ggml.c
View file

@ -678,6 +678,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_IQ3_S] = {
.type_name = "iq3_s",
.blck_size = QK_K,
.type_size = sizeof(block_iq3_s),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq3_s,
.from_float = quantize_row_iq3_s,
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference,
.vec_dot = ggml_vec_dot_iq3_s_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_IQ1_S] = { [GGML_TYPE_IQ1_S] = {
.type_name = "iq1_s", .type_name = "iq1_s",
.blck_size = QK_K, .blck_size = QK_K,
@ -2304,6 +2316,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
} }
@ -7738,6 +7751,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
{ {
ggml_compute_forward_add_q_f32(params, dst); ggml_compute_forward_add_q_f32(params, dst);
} break; } break;
@ -8017,6 +8031,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
{ {
ggml_compute_forward_add1_q_f32(params, dst); ggml_compute_forward_add1_q_f32(params, dst);
} break; } break;
@ -8141,6 +8156,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
default: default:
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
@ -11039,6 +11055,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
{ {
ggml_compute_forward_out_prod_q_f32(params, dst); ggml_compute_forward_out_prod_q_f32(params, dst);
} break; } break;
@ -11227,6 +11244,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
default: default:
{ {
GGML_ASSERT(false); GGML_ASSERT(false);
@ -11429,6 +11447,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
{ {
ggml_compute_forward_get_rows_q(params, dst); ggml_compute_forward_get_rows_q(params, dst);
} break; } break;
@ -12129,6 +12148,7 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
@ -12212,6 +12232,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K:
case GGML_TYPE_I8: case GGML_TYPE_I8:
case GGML_TYPE_I16: case GGML_TYPE_I16:
@ -19463,6 +19484,7 @@ void ggml_quantize_init(enum ggml_type type) {
case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break; case GGML_TYPE_IQ1_S: iq2xs_init_impl(type); break;
case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
default: // nothing default: // nothing
break; break;
} }
@ -19737,6 +19759,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); result = quantize_iq3_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows); GGML_ASSERT(result == row_size * nrows);
} break; } break;
case GGML_TYPE_IQ3_S:
{
GGML_ASSERT(start % QK_K == 0);
GGML_ASSERT(start % n_per_row == 0);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_iq3_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
{ {
GGML_ASSERT(start % QK_K == 0); GGML_ASSERT(start % QK_K == 0);

2
ggml.h
View file

@ -350,6 +350,7 @@ extern "C" {
GGML_TYPE_IQ3_XXS = 18, GGML_TYPE_IQ3_XXS = 18,
GGML_TYPE_IQ1_S = 19, GGML_TYPE_IQ1_S = 19,
GGML_TYPE_IQ4_NL = 20, GGML_TYPE_IQ4_NL = 20,
GGML_TYPE_IQ3_S = 21,
GGML_TYPE_I8, GGML_TYPE_I8,
GGML_TYPE_I16, GGML_TYPE_I16,
GGML_TYPE_I32, GGML_TYPE_I32,
@ -389,6 +390,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
}; };
// available tensor operations: // available tensor operations:

View file

@ -2545,6 +2545,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break;
default: default:
{ {
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@ -2890,6 +2891,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
default: return "unknown, may not work"; default: return "unknown, may not work";
} }
@ -10544,6 +10547,12 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_Q3_K : GGML_TYPE_IQ3_XXS; new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_Q3_K : GGML_TYPE_IQ3_XXS;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_S && qs.model.hparams.n_gqa() >= 4) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
} }
@ -10575,13 +10584,17 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
new_type = GGML_TYPE_Q8_0; new_type = GGML_TYPE_Q8_0;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {
new_type = GGML_TYPE_Q2_K; new_type = GGML_TYPE_IQ3_XXS;
}
} else if (name.find("attn_q.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) {
new_type = GGML_TYPE_IQ3_XXS;
} }
} else if (name.find("ffn_down") != std::string::npos) { } else if (name.find("ffn_down") != std::string::npos) {
auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
int i_layer = info.first, n_layer = info.second; int i_layer = info.first, n_layer = info.second;
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K; if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) { else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
@ -10592,6 +10605,10 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
: arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
: GGML_TYPE_Q3_K; : GGML_TYPE_Q3_K;
} }
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
(qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
} }
@ -10623,7 +10640,8 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
if (qs.model.hparams.n_expert == 8) { if (qs.model.hparams.n_expert == 8) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL ||
ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
new_type = GGML_TYPE_Q5_K; new_type = GGML_TYPE_Q5_K;
} }
} else { } else {
@ -10631,29 +10649,32 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_Q3_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_Q3_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K;
} }
} else { } else {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
} }
} }
else if (name.find("attn_qkv.weight") != std::string::npos) { else if (name.find("attn_qkv.weight") != std::string::npos) {
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
new_type = GGML_TYPE_Q4_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
} }
else if (name.find("ffn_gate") != std::string::npos) { else if (name.find("ffn_gate") != std::string::npos) {
auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str()); auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
int i_layer = info.first, n_layer = info.second; int i_layer = info.first, n_layer = info.second;
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && !use_more_bits(i_layer, n_layer)) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
new_type = GGML_TYPE_Q2_K; new_type = GGML_TYPE_IQ3_XXS;
} }
++qs.i_ffn_gate; ++qs.i_ffn_gate;
} }
else if (name.find("ffn_up") != std::string::npos) { else if (name.find("ffn_up") != std::string::npos) {
auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str()); auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
int i_layer = info.first, n_layer = info.second; int i_layer = info.first, n_layer = info.second;
if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && !use_more_bits(i_layer, n_layer)) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
new_type = GGML_TYPE_Q2_K; new_type = GGML_TYPE_IQ3_XXS;
} }
++qs.i_ffn_up; ++qs.i_ffn_up;
} }
@ -10673,7 +10694,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K ||
new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS ||
new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) { new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || new_type == GGML_TYPE_IQ3_S) {
int nx = tensor->ne[0]; int nx = tensor->ne[0];
int ny = tensor->ne[1]; int ny = tensor->ne[1];
if (nx % QK_K != 0) { if (nx % QK_K != 0) {
@ -10688,6 +10709,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_Q3_K: new_type = GGML_TYPE_IQ4_NL; break;
@ -10719,7 +10741,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// K-quants // K-quants
case LLAMA_FTYPE_MOSTLY_Q2_K_S: case LLAMA_FTYPE_MOSTLY_Q2_K_S:
case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_XS: case LLAMA_FTYPE_MOSTLY_Q3_K_XS: quantized_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_S:
case LLAMA_FTYPE_MOSTLY_Q3_K_M: case LLAMA_FTYPE_MOSTLY_Q3_K_M:
case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break;
@ -10733,6 +10755,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break;
case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S; break;
case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: quantized_type = GGML_TYPE_IQ4_NL; break;
case LLAMA_FTYPE_MOSTLY_IQ3_S: quantized_type = GGML_TYPE_IQ3_S; break;
case LLAMA_FTYPE_MOSTLY_IQ3_M: quantized_type = GGML_TYPE_IQ3_S; break;
default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
} }

View file

@ -102,6 +102,8 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
}; };

View file

@ -1918,7 +1918,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
GGML_TYPE_Q6_K, GGML_TYPE_Q6_K,
GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S,
}; };
// unary ops // unary ops

View file

@ -151,6 +151,7 @@ int main(int argc, char * argv[]) {
const float max_quantization_error = const float max_quantization_error =
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR; type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
failed = !(total_error < max_quantization_error); failed = !(total_error < max_quantization_error);
num_failed += failed; num_failed += failed;
@ -167,7 +168,8 @@ int main(int argc, char * argv[]) {
const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data()); const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
type == GGML_TYPE_IQ3_XXS ? MAX_DOT_PRODUCT_ERROR_LOWBIT : MAX_DOT_PRODUCT_ERROR; type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT
: MAX_DOT_PRODUCT_ERROR;
failed = !(vec_dot_error < max_allowed_error); failed = !(vec_dot_error < max_allowed_error);
num_failed += failed; num_failed += failed;
if (failed || verbose) { if (failed || verbose) {