server: tests: add OAI-compatible embeddings test with multiple inputs
This commit is contained in:
parent
466987eb7b
commit
04f4cbbd9e
2 changed files with 40 additions and 2 deletions
|
@ -60,6 +60,19 @@ Feature: llama.cpp server
|
||||||
"""
|
"""
|
||||||
Then embeddings are generated
|
Then embeddings are generated
|
||||||
|
|
||||||
|
Scenario: OAI Embeddings compatibility with multiple inputs
|
||||||
|
Given a model tinyllama-2
|
||||||
|
Given a prompt:
|
||||||
|
"""
|
||||||
|
In which country Paris is located ?
|
||||||
|
"""
|
||||||
|
And a prompt:
|
||||||
|
"""
|
||||||
|
Is Madrid the capital of Spain ?
|
||||||
|
"""
|
||||||
|
When an OAI compatible embeddings computation request for multiple inputs
|
||||||
|
Then embeddings are generated
|
||||||
|
|
||||||
|
|
||||||
Scenario: Tokenize / Detokenize
|
Scenario: Tokenize / Detokenize
|
||||||
When tokenizing:
|
When tokenizing:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import collections
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
@ -321,7 +322,15 @@ async def step_compute_embedding(context):
|
||||||
|
|
||||||
@step(u'embeddings are generated')
def step_assert_embeddings(context):
    """Assert that the server response contains valid embeddings.

    Two cases, depending on what earlier steps recorded on the context:
    - no prompts recorded: context.embeddings holds a single embedding,
      which is validated directly;
    - prompts recorded: one embedding per prompt is expected.  The prompt
      list is drained one entry per validated embedding, leaving it empty
      once every embedding has been checked.
    """
    # Idiomatic emptiness test instead of `len(...) == 0`.
    if not context.prompts:
        assert_embeddings(context.embeddings)
    else:
        assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
                                                                 f"context.prompts={context.prompts}\n"
                                                                 f"context.embeddings={context.embeddings}")
        for embedding in context.embeddings:
            # Consume one prompt per embedding so subsequent steps see an
            # empty prompt list after all embeddings are validated.
            context.prompts.pop()
            assert_embeddings(embedding)
|
||||||
|
|
||||||
|
|
||||||
@step(u'an OAI compatible embeddings computation request for')
|
@step(u'an OAI compatible embeddings computation request for')
|
||||||
|
@ -333,6 +342,15 @@ async def step_oai_compute_embeddings(context):
|
||||||
model=context.model)
|
model=context.model)
|
||||||
|
|
||||||
|
|
||||||
|
@step(u'an OAI compatible embeddings computation request for multiple inputs')
# NOTE(review): other steps in this file use `@async_run_until_complete()`
# with parentheses — confirm the decorator supports both forms.
@async_run_until_complete
async def step_oai_compute_embeddings_multiple_inputs(context):
    """Send every recorded prompt through the OAI-compatible embeddings
    endpoint and store the resulting embeddings on the behave context."""
    embeddings = await request_oai_embeddings(
        context.prompts,
        base_url=context.base_url,
        user_api_key=context.user_api_key,
        model=context.model,
    )
    context.embeddings = embeddings
|
||||||
|
|
||||||
|
|
||||||
@step(u'concurrent embedding requests')
|
@step(u'concurrent embedding requests')
|
||||||
@async_run_until_complete()
|
@async_run_until_complete()
|
||||||
async def step_concurrent_embedding_requests(context):
|
async def step_concurrent_embedding_requests(context):
|
||||||
|
@ -607,10 +625,17 @@ async def request_oai_embeddings(input,
|
||||||
else:
|
else:
|
||||||
openai.api_key = user_api_key
|
openai.api_key = user_api_key
|
||||||
openai.api_base = f'{base_url}/v1'
|
openai.api_base = f'{base_url}/v1'
|
||||||
embeddings = openai.Embedding.create(
|
oai_embeddings = openai.Embedding.create(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(input, collections.abc.Sequence):
|
||||||
|
embeddings = []
|
||||||
|
for an_oai_embeddings in oai_embeddings.data:
|
||||||
|
embeddings.append(an_oai_embeddings.embedding)
|
||||||
|
else:
|
||||||
|
embeddings = oai_embeddings.data.embedding
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue