server: tests: adding OAI compatible embedding with multiple inputs
This commit is contained in:
parent
466987eb7b
commit
04f4cbbd9e
2 changed files with 40 additions and 2 deletions
|
@ -60,6 +60,19 @@ Feature: llama.cpp server
|
|||
"""
|
||||
Then embeddings are generated
|
||||
|
||||
Scenario: OAI Embeddings compatibility with multiple inputs
|
||||
Given a model tinyllama-2
|
||||
Given a prompt:
|
||||
"""
|
||||
In which country Paris is located ?
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Is Madrid the capital of Spain ?
|
||||
"""
|
||||
When an OAI compatible embeddings computation request for multiple inputs
|
||||
Then embeddings are generated
|
||||
|
||||
|
||||
Scenario: Tokenize / Detokenize
|
||||
When tokenizing:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
@ -321,7 +322,15 @@ async def step_compute_embedding(context):
|
|||
|
||||
@step(u'embeddings are generated')
|
||||
def step_assert_embeddings(context):
|
||||
assert_embeddings(context.embeddings)
|
||||
if len(context.prompts) == 0:
|
||||
assert_embeddings(context.embeddings)
|
||||
else:
|
||||
assert len(context.embeddings) == len(context.prompts), (f"unexpected response:\n"
|
||||
f"context.prompts={context.prompts}\n"
|
||||
f"context.embeddings={context.embeddings}")
|
||||
for embedding in context.embeddings:
|
||||
context.prompts.pop()
|
||||
assert_embeddings(embedding)
|
||||
|
||||
|
||||
@step(u'an OAI compatible embeddings computation request for')
|
||||
|
@ -333,6 +342,15 @@ async def step_oai_compute_embeddings(context):
|
|||
model=context.model)
|
||||
|
||||
|
||||
@step(u'an OAI compatible embeddings computation request for multiple inputs')
|
||||
@async_run_until_complete
|
||||
async def step_oai_compute_embeddings_multiple_inputs(context):
|
||||
context.embeddings = await request_oai_embeddings(context.prompts,
|
||||
base_url=context.base_url,
|
||||
user_api_key=context.user_api_key,
|
||||
model=context.model)
|
||||
|
||||
|
||||
@step(u'concurrent embedding requests')
|
||||
@async_run_until_complete()
|
||||
async def step_concurrent_embedding_requests(context):
|
||||
|
@ -607,10 +625,17 @@ async def request_oai_embeddings(input,
|
|||
else:
|
||||
openai.api_key = user_api_key
|
||||
openai.api_base = f'{base_url}/v1'
|
||||
embeddings = openai.Embedding.create(
|
||||
oai_embeddings = openai.Embedding.create(
|
||||
model=model,
|
||||
input=input,
|
||||
)
|
||||
|
||||
if isinstance(input, collections.abc.Sequence):
|
||||
embeddings = []
|
||||
for an_oai_embeddings in oai_embeddings.data:
|
||||
embeddings.append(an_oai_embeddings.embedding)
|
||||
else:
|
||||
embeddings = oai_embeddings.data.embedding
|
||||
return embeddings
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue