Merge branch 'master' into sycl_readme_update

OuadiElfarouki 2024-03-20 12:33:23 +00:00
commit fa7c6ddd30
8 changed files with 143 additions and 107 deletions


@@ -37,6 +37,7 @@ When targeting **Intel CPUs**, it is recommended to use llama.cpp for [x86_64]
## News
- 2024.3
+  - New base line is ready: [tag b2437](https://github.com/ggerganov/llama.cpp/tree/b2437).
  - Support multiple cards: **--split-mode**: [none|layer]; [row] is not supported yet and still in development.
  - Support assigning the main GPU with **--main-gpu**, replacing $GGML_SYCL_DEVICE.
  - Support detecting all level-zero GPUs with the same top **Max compute units**.
@@ -300,15 +301,16 @@ Similar to the native `sycl-ls`, available SYCL devices can be queried as follows:
```

An example of such a log on a system with 1 *Intel CPU* and 1 *Intel GPU* can look like the following:

```
-found 4 SYCL devices:
-  Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
-    max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
-  Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
-    max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
-  Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
-    max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
-  Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
-    max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
+found 6 SYCL devices:
+|  |                  |                                             |Compute   |Max compute|Max work|Max sub|               |
+|ID|       Device Type|                                         Name|capability|units      |group   |group  |Global mem size|
+|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------|
+| 0|[level_zero:gpu:0]|               Intel(R) Arc(TM) A770 Graphics|       1.3|        512|    1024|     32|    16225243136|
+| 1|[level_zero:gpu:1]|                    Intel(R) UHD Graphics 770|       1.3|         32|     512|     32|    53651849216|
+| 2|    [opencl:gpu:0]|               Intel(R) Arc(TM) A770 Graphics|       3.0|        512|    1024|     32|    16225243136|
+| 3|    [opencl:gpu:1]|                    Intel(R) UHD Graphics 770|       3.0|         32|     512|     32|    53651849216|
+| 4|    [opencl:cpu:0]|         13th Gen Intel(R) Core(TM) i7-13700K|       3.0|         24|    8192|     64|    67064815616|
+| 5|    [opencl:acc:0]|               Intel(R) FPGA Emulation Device|       1.2|         24|67108864|     64|    67064815616|
```

|Attribute|Note|
@@ -318,10 +320,33 @@ found 4 SYCL devices:

4. Launch inference

-For instance, in order to target the SYCL device with *ID*=0 *(log from previous command)*, we simply specify `GGML_SYCL_DEVICE=0`.
+There are two device selection modes:
+- Single device: Use the one device specified by the user.
+- Multiple devices: Automatically select the devices with the same largest Max compute units.
+
+|Device selection|Parameter|
+|-|-|
+|Single device|--split-mode none --main-gpu DEVICE_ID|
+|Multiple devices|--split-mode layer (default)|
+
+Examples:
+- Use device 0:
```sh
-GGML_SYCL_DEVICE=0 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33
+ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
+```
+or run it via the script:
+```sh
+./examples/sycl/run_llama2.sh 0
+```
+- Use multiple devices:
+```sh
+ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer
```

Otherwise, you can run the script:
@@ -333,6 +358,15 @@ Otherwise, you can run the script:

*Notes:*

- By default, `mmap` is used to read the model file. In some cases, it causes runtime hang issues. Please disable it by passing `--no-mmap` to `/bin/main` if you face the issue.
+- Upon execution, verify the selected device(s) ID(s) in the output log, which can, for instance, be displayed as follows:
+```sh
+detect 1 SYCL GPUs: [0] with top Max compute units:512
+```
+Or
+```sh
+use 1 SYCL GPUs: [0] with Max compute units:512
+```
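
Putting the Linux steps above together, a minimal end-to-end session might look like the following sketch; it assumes a oneAPI installation under `/opt/intel/oneapi`, a finished build, and the model path used in the examples above:

```sh
# Enable the oneAPI compiler and runtime environment.
source /opt/intel/oneapi/setvars.sh

# List the available SYCL devices and pick the ID of a level-zero GPU.
./build/bin/ls-sycl-device

# Run on device 0 only; ZES_ENABLE_SYSMAN lets the runtime report free GPU memory.
ZES_ENABLE_SYSMAN=1 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf \
  -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
```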
## Windows
@@ -387,7 +421,7 @@ a. Download & install cmake for Windows: https://cmake.org/download/

b. Download & install mingw-w64 make for Windows provided by w64devkit

-- Download the latest fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases).
+- Download the 1.19.0 version of [w64devkit](https://github.com/skeeto/w64devkit/releases/download/v1.19.0/w64devkit-1.19.0.zip).

- Extract `w64devkit` on your PC.
@@ -439,15 +473,17 @@ build\bin\ls-sycl-device.exe

The output of this command on a system with 1 *Intel CPU* and 1 *Intel GPU* would look like the following:

```
-found 4 SYCL devices:
-  Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
-    max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
-  Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
-    max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
-  Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
-    max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
-  Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
-    max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
+found 6 SYCL devices:
+|  |                  |                                             |Compute   |Max compute|Max work|Max sub|               |
+|ID|       Device Type|                                         Name|capability|units      |group   |group  |Global mem size|
+|--|------------------|---------------------------------------------|----------|-----------|--------|-------|---------------|
+| 0|[level_zero:gpu:0]|               Intel(R) Arc(TM) A770 Graphics|       1.3|        512|    1024|     32|    16225243136|
+| 1|[level_zero:gpu:1]|                    Intel(R) UHD Graphics 770|       1.3|         32|     512|     32|    53651849216|
+| 2|    [opencl:gpu:0]|               Intel(R) Arc(TM) A770 Graphics|       3.0|        512|    1024|     32|    16225243136|
+| 3|    [opencl:gpu:1]|                    Intel(R) UHD Graphics 770|       3.0|         32|     512|     32|    53651849216|
+| 4|    [opencl:cpu:0]|         13th Gen Intel(R) Core(TM) i7-13700K|       3.0|         24|    8192|     64|    67064815616|
+| 5|    [opencl:acc:0]|               Intel(R) FPGA Emulation Device|       1.2|         24|67108864|     64|    67064815616|
```

|Attribute|Note|
@@ -455,13 +491,31 @@ found 4 SYCL devices:
|compute capability 1.3|Level-zero runtime, recommended|
|compute capability 3.0|OpenCL runtime, slower than Level-zero in most cases|

4. Launch inference

-Set device ID=0 with `set GGML_SYCL_DEVICE=0` to target the Level-zero Intel GPU and run the main:
+There are two device selection modes:
+- Single device: Use the one device assigned by the user.
+- Multiple devices: Automatically choose the devices with the same largest Max compute units.
+
+|Device selection|Parameter|
+|-|-|
+|Single device|--split-mode none --main-gpu DEVICE_ID|
+|Multiple devices|--split-mode layer (default)|
+
+Examples:
+- Use device 0:
```
-set GGML_SYCL_DEVICE=0
-build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0
+build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm none -mg 0
+```
+- Use multiple devices:
+```
+build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm layer
```

Otherwise, run the following wrapper script:
@@ -472,9 +526,17 @@ Otherwise, run the following wrapper script:

Note:

- By default, `mmap` is used to read the model file. In some cases, it causes runtime hang issues. Please disable it by passing `--no-mmap` to `main.exe` if you face the issue.
+- Upon execution, verify the selected device(s) ID(s) in the output log, which can, for instance, be displayed as follows:
+```sh
+detect 1 SYCL GPUs: [0] with top Max compute units:512
+```
+Or
+```sh
+use 1 SYCL GPUs: [0] with Max compute units:512
+```
-## Environment Variables
+## Environment Variable

#### Build
@@ -490,7 +552,6 @@ Note:

|Name|Value|Function|
|-|-|-|
-|GGML_SYCL_DEVICE|0 (default) or 1|Set the device id used. Check the device ids by default running output|
|GGML_SYCL_DEBUG|0 (default) or 1|Enable log function by macro: GGML_SYCL_DEBUG|
|ZES_ENABLE_SYSMAN|0 (default) or 1|Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer|
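
As an illustration of how these runtime variables combine with the launch flags described earlier, a debug run might look like the following sketch (the model path, prompt, and token count are placeholders):

```sh
# Hypothetical debug run: GGML_SYCL_DEBUG turns on the SYCL debug log, and
# ZES_ENABLE_SYSMAN lets the layer split-mode query free GPU memory via sysman.
GGML_SYCL_DEBUG=1 ZES_ENABLE_SYSMAN=1 \
  ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "Hello" -n 32 -e -ngl 33 -sm layer
```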
@@ -506,7 +567,7 @@ Note:

## Q&A

- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
  - Potential cause: Unavailable oneAPI installation or unset ENV variables.
  - Solution: Install the *oneAPI base toolkit* and enable its ENV through: `source /opt/intel/oneapi/setvars.sh`.
@@ -525,5 +586,11 @@ Note:
sudo usermod -aG render $USER
sudo usermod -aG video $USER
```

-Otherwise, please double-check the installation GPU steps.
+Otherwise, please double-check the GPU driver installation steps.
+
+### **GitHub contribution**:
+Please add the **[SYCL]** prefix/tag in issue/PR titles to help the SYCL team check and address them without delay.
+
+## Todo
+
+- Support row split for multi-card runs.


@@ -5,15 +5,14 @@ import sys
import time
import traceback
from contextlib import closing
+from subprocess import TimeoutExpired
-import psutil


def before_scenario(context, scenario):
    context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
    if context.debug:
-        print("DEBUG=ON\n")
+        print("DEBUG=ON")
-    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
+    print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
    port = 8080
    if 'PORT' in os.environ:
        port = int(os.environ['PORT'])
@@ -27,60 +26,40 @@ def after_scenario(context, scenario):
            return
        if scenario.status == "failed":
            if 'GITHUB_ACTIONS' in os.environ:
-                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
+                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n")
                if os.path.isfile('llama.log'):
                    with closing(open('llama.log', 'r')) as f:
                        for line in f:
                            print(line)
            if not is_server_listening(context.server_fqdn, context.server_port):
-                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n")
+                print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")

-        if not pid_exists(context.server_process.pid):
+        if context.server_process.poll() is not None:
            assert False, f"Server not running pid={context.server_process.pid} ..."

-        server_graceful_shutdown(context)
+        server_graceful_shutdown(context)  # SIGINT

-        # Wait few for socket to free up
-        time.sleep(0.05)
-
-        attempts = 0
-        while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
-            server_kill(context)
+        try:
+            context.server_process.wait(0.5)
+        except TimeoutExpired:
+            print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...")
+            context.server_process.kill()  # SIGKILL
+            context.server_process.wait()
+
+        while is_server_listening(context.server_fqdn, context.server_port):
            time.sleep(0.1)
-            attempts += 1
-            if attempts > 5:
-                server_kill_hard(context)
-    except:
-        exc = sys.exception()
-        print("error in after scenario: \n")
-        print(exc)
-        print("*** print_tb: \n")
-        traceback.print_tb(exc.__traceback__, file=sys.stdout)
+    except Exception:
+        print("ignoring error in after_scenario:")
+        traceback.print_exc(file=sys.stdout)


def server_graceful_shutdown(context):
-    print(f"shutting down server pid={context.server_process.pid} ...\n")
+    print(f"shutting down server pid={context.server_process.pid} ...")
    if os.name == 'nt':
-        os.kill(context.server_process.pid, signal.CTRL_C_EVENT)
+        interrupt = signal.CTRL_C_EVENT
    else:
-        os.kill(context.server_process.pid, signal.SIGINT)
+        interrupt = signal.SIGINT
+    context.server_process.send_signal(interrupt)


-def server_kill(context):
-    print(f"killing server pid={context.server_process.pid} ...\n")
-    context.server_process.kill()
-
-
-def server_kill_hard(context):
-    pid = context.server_process.pid
-    path = context.server_path
-    print(f"Server dangling exits, hard killing force {pid}={path}...\n")
-    try:
-        psutil.Process(pid).kill()
-    except psutil.NoSuchProcess:
-        return False
-    return True
-
-
def is_server_listening(server_fqdn, server_port):
@@ -88,14 +67,5 @@ def is_server_listening(server_fqdn, server_port):
        result = sock.connect_ex((server_fqdn, server_port))
        _is_server_listening = result == 0
        if _is_server_listening:
-            print(f"server is listening on {server_fqdn}:{server_port}...\n")
+            print(f"server is listening on {server_fqdn}:{server_port}...")
        return _is_server_listening
-def pid_exists(pid):
-    try:
-        psutil.Process(pid)
-    except psutil.NoSuchProcess:
-        return False
-    return True
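
The new shutdown sequence above follows a common pattern: interrupt gracefully, wait a bounded time, force-kill if needed, then poll until the port is released. A rough shell equivalent of the same idea, assuming the server PID is in `$SERVER_PID` (a hypothetical variable, not part of this diff):

```sh
kill -INT "$SERVER_PID"               # graceful interrupt, like send_signal(SIGINT)
sleep 0.5                             # bounded grace period, like wait(0.5)
if kill -0 "$SERVER_PID" 2>/dev/null; then
    echo "server still alive after 500ms, force-killing pid=$SERVER_PID ..."
    kill -KILL "$SERVER_PID"          # last resort, like kill()
fi
```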


@@ -35,9 +35,9 @@ Feature: llama.cpp server
    And metric llamacpp:tokens_predicted is <n_predicted>

    Examples: Prompts
      | prompt                                                                    | n_predict | re_content                                  | n_prompt | n_predicted | truncated |
      | I believe the meaning of life is                                          | 8         | (read\|going)+                              | 18       | 8           | not       |
-      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids)+               | 46       | 64          | not       |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256       | (princesses\|everyone\|kids\|Anna\|forest)+ | 46       | 64          | not       |

  Scenario: Completion prompt truncated
    Given a prompt:
@@ -48,7 +48,7 @@ Feature: llama.cpp server
    Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
    """
    And a completion request with no api error
-    Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry
+    Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
    And the completion is truncated
    And 109 prompt tokens are processed
@@ -65,9 +65,9 @@ Feature: llama.cpp server
    And the completion is <truncated> truncated

    Examples: Prompts
      | model        | system_prompt               | user_prompt                          | max_tokens | re_content                        | n_prompt | n_predicted | enable_streaming | truncated |
      | llama-2      | Book                        | What is the best book                | 8          | (Here\|what)+                     | 77       | 8           | disabled         | not       |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird)+            | -1       | 64          | enabled          |           |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128        | (thanks\|happy\|bird\|Annabyear)+ | -1       | 64          | enabled          |           |

  Scenario: Tokenize / Detokenize


@@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port):
def step_download_hf_model(context, hf_file, hf_repo):
    context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
    if context.debug:
-        print(f"model file: {context.model_file}\n")
+        print(f"model file: {context.model_file}")


@step('a model file {model_file}')
@@ -137,9 +137,12 @@ def step_start_server(context):
    if 'GITHUB_ACTIONS' in os.environ:
        max_attempts *= 2

+    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
+    family, typ, proto, _, sockaddr = addrs[0]
+
    while True:
-        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
-            result = sock.connect_ex((context.server_fqdn, context.server_port))
+        with closing(socket.socket(family, typ, proto)) as sock:
+            result = sock.connect_ex(sockaddr)
            if result == 0:
                print("\x1b[33;46mserver started!\x1b[0m")
                return
@@ -209,7 +212,7 @@ async def step_request_completion(context, api_error):
                                          user_api_key=context.user_api_key)
    context.tasks_result.append(completion)
    if context.debug:
-        print(f"Completion response: {completion}\n")
+        print(f"Completion response: {completion}")
    if expect_api_error:
        assert completion == 401, f"completion must be an 401 status code: {completion}"
@@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos):
    prompt += context.prompt_junk_suffix
    if context.debug:
        passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
-        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
    context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
    context.n_prompts = len(context.prompts)
@@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos):
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
    if context.debug:
-        print(f"Submitting OAI compatible completions request...\n")
+        print(f"Submitting OAI compatible completions request...")
    expect_api_error = api_error == 'raised'
    completion = await oai_chat_completions(context.prompts.pop(),
                                            context.system_prompt,
@@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context):
            embedding1 = np.array(embeddings[i])
            embedding2 = np.array(embeddings[j])
            if context.debug:
-                print(f"embedding1: {embedding1[-8:]}\n")
-                print(f"embedding2: {embedding2[-8:]}\n")
+                print(f"embedding1: {embedding1[-8:]}")
+                print(f"embedding2: {embedding2[-8:]}")
            similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
            msg = f"Similarity between {i} and {j}: {similarity:.10f}"
            if context.debug:
-                print(f"{msg}\n")
+                print(f"{msg}")
            assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
@@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context):
            metrics_raw = await metrics_response.text()
            metric_exported = False
            if context.debug:
-                print(f"/metrics answer:\n{metrics_raw}\n")
+                print(f"/metrics answer:\n{metrics_raw}")
            context.metrics = {}
            for metric in parser.text_string_to_metric_families(metrics_raw):
                match metric.name:
@@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
        last_match = end
    highlighted += content[last_match:]
    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
-        print(f"Checking completion response: {highlighted}\n")
+        print(f"Checking completion response: {highlighted}")
    assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
    if expected_predicted_n and expected_predicted_n > 0:
        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
@@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
async def gather_tasks_results(context):
    n_tasks = len(context.concurrent_tasks)
    if context.debug:
-        print(f"Waiting for all {n_tasks} tasks results...\n")
+        print(f"Waiting for all {n_tasks} tasks results...")
    for task_no in range(n_tasks):
        context.tasks_result.append(await context.concurrent_tasks.pop())
    n_completions = len(context.tasks_result)
@@ -959,7 +962,7 @@ async def wait_for_health_status(context,
                                 slots_processing=None,
                                 expected_slots=None):
    if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
+        print(f"Starting checking for health for expected_health_status={expected_health_status}")
    interval = 0.5
    counter = 0
    if 'GITHUB_ACTIONS' in os.environ:
@@ -1048,8 +1051,6 @@ def start_server_background(context):
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
    server_listen_addr = context.server_fqdn
-    if os.name == 'nt':
-        server_listen_addr = '0.0.0.0'
    server_args = [
        '--host', server_listen_addr,
        '--port', context.server_port,
@@ -1088,7 +1089,7 @@ def start_server_background(context):
        server_args.append('--verbose')
    if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
        server_args.extend(['--log-format', "text"])
-    print(f"starting server with: {context.server_path} {server_args}\n")
+    print(f"starting server with: {context.server_path} {server_args}")
    flags = 0
    if 'nt' == os.name:
        flags |= subprocess.DETACHED_PROCESS


@@ -3,5 +3,4 @@ behave~=1.2.6
huggingface_hub~=0.20.3
numpy~=1.24.4
openai~=0.25.0
-psutil~=5.9.8
prometheus-client~=0.20.0
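
As a sketch of how these test dependencies are typically exercised with the behave-based server tests in the surrounding diff (running from the tests directory is an assumption):

```sh
# Install the test dependencies, then run the behave feature suite.
pip install -r requirements.txt
# DEBUG=ON enables the verbose prints used throughout the steps;
# PORT overrides the default 8080 read in before_scenario.
DEBUG=ON PORT=8080 behave
```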


@@ -371,6 +371,7 @@ static json oaicompat_completion_params_parse(
    llama_params["repeat_last_n"]     = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
    llama_params["ignore_eos"]        = json_value(body, "ignore_eos", false);
    llama_params["tfs_z"]             = json_value(body, "tfs_z", default_sparams.tfs_z);
+    llama_params["n_keep"]            = json_value(body, "n_keep", 0);

    if (body.count("grammar") != 0) {
        llama_params["grammar"] = json_value(body, "grammar", json::object());


@@ -6,8 +6,6 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"

@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force

-set GGML_SYCL_DEVICE=0
-rem set GGML_SYCL_DEBUG=1
.\build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0


@@ -13,7 +13,7 @@
extern "C" {
#endif

-#define GGML_SYCL_MAX_DEVICES 16
+#define GGML_SYCL_MAX_DEVICES 48
#define GGML_SYCL_NAME "SYCL"

GGML_API void ggml_init_sycl(void);