openai_api.py: 52 changes (46 additions, 6 deletions)
@@ -12,12 +12,12 @@
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Literal, Optional, Union
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, BertModel, BertTokenizer
from sse_starlette.sse import ServerSentEvent, EventSourceResponse


@asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
@@ -34,6 +34,7 @@ async def lifespan(app: FastAPI): # collects GPU memory
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
@@ -109,8 +110,8 @@ async def create_chat_completion(request: ChatCompletionRequest):
    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
-                history.append([prev_messages[i].content, prev_messages[i+1].content])
+            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i + 1].content])

    if request.stream:
        generate = predict(query, history, request.model)
@@ -154,7 +155,6 @@ async def predict(query: str, history: List[List[str]], model_id: str):
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))


    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
@@ -165,6 +165,43 @@ async def predict(query: str, history: List[List[str]], model_id: str):
    yield '[DONE]'


class EmbeddingsRequest(BaseModel):
    model: str
    input: Union[str, List[str]]


class EmbeddingsData(BaseModel):
    object: str = "embedding"
    embedding: List[float]
    index: int = 0


class EmbeddingsResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingsData]
    model: str
    usage: dict


@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
async def create_chat_completion(request: EmbeddingsRequest):
global bert_model, bert_tokenizer
encoded_input = bert_tokenizer(request.input, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = bert_model(**encoded_input)
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
return EmbeddingsResponse(
data=[EmbeddingsData(embedding=_.tolist(), index=i) for i, _ in enumerate(sentence_embeddings)],
model=request.model,
usage={},
)
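
Since the new route follows the OpenAI embeddings schema, it can be smoke-tested with any HTTP client once the server at the bottom of this file is running. A minimal sketch using `requests` (the payload values are illustrative and not part of this PR; the server echoes the model string back without validating it):

import requests

# hypothetical smoke test for the new /v1/embeddings route
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "text2vec-large-chinese", "input": ["你好", "hello"]},
)
body = resp.json()
print(len(body["data"]))                  # one entry per input string
print(len(body["data"][0]["embedding"]))  # encoder hidden size (1024 for a BERT-large-style model)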


def mean_pooling(model_output, attention_mask):
    # average the last hidden states over real tokens only; padding positions are masked out
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
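
To see what `mean_pooling` computes, here is a standalone sketch with made-up tensors (not part of the PR): masked positions contribute to neither the sum nor the token count, so padded sequences are averaged only over their real tokens.

import torch

token_embeddings = torch.ones(2, 3, 4)      # batch=2, seq_len=3, hidden=4
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])  # second sequence ends in one pad token
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled.shape)  # torch.Size([2, 4]); each row is the mean over real tokens only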


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
@@ -173,5 +210,8 @@ async def predict(query: str, history: List[List[str]], model_id: str):
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model.eval()

    # text2vec encoder that backs the new /v1/embeddings route (loaded on CPU here)
    bert_name = 'GanymedeNil/text2vec-large-chinese'
    bert_model = BertModel.from_pretrained(bert_name)
    bert_tokenizer = BertTokenizer.from_pretrained(bert_name, model_max_length=512)
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
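
Because the route mirrors the OpenAI API, a pre-1.0 `openai` Python client pointed at this server should also work; a sketch under that assumption (1.x clients would use OpenAI(base_url=...) instead):

import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"  # the server performs no key check

result = openai.Embedding.create(model="text2vec-large-chinese", input="今天天气不错")
print(result["data"][0]["embedding"][:5])  # first few dimensions of the sentence vector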