文本生成推理
文本生成推理 (TGI) 是一个用于部署和提供大型语言模型 (LLM) 服务的工具包。TGI 为最流行的开源 LLM 提供了高性能文本生成能力。它具有量化、张量并行、token 流式传输、连续批处理、Flash Attention、引导等多种功能。
使用官方 Docker 容器是开始使用 TGI 的最简单方法。
部署
- Mistral-7B
- Mixtral-8X7B
- Mixtral-8X22B
model=mistralai/Mistral-7B-Instruct-v0.3
model=mistralai/Mixtral-8x22B-Instruct-v0.1
model=mistralai/Mixtral-8x22B-Instruct-v0.1
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
-e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
ghcr.io/huggingface/text-generation-inference:2.0.3 \
--model-id $model
这将启动一个 TGI 实例,暴露一个类似 OpenAI 的 API,详见 API 部分 的文档。
确保将 HUGGING_FACE_HUB_TOKEN
环境变量设置为您的 Hugging Face 用户访问令牌。要使用 Mistral 模型,您必须首先访问相应的模型页面并填写小表格。然后您将自动获得模型的访问权限。
如果模型不适合您的 GPU,您还可以使用量化方法(AWQ、GPTQ 等)。您可以在 其文档 中找到所有 TGI 启动选项。
使用 API
使用兼容聊天的端点
TGI 支持 Messages API,它与 Mistral 和 OpenAI Chat Completion API 兼容。
- 使用 MistralClient
- 使用 OpenAI Client
- 使用 cURL
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
# init the client but point it to TGI
client = MistralClient(api_key="-", endpoint="http://127.0.0.1:8080")
chat_response = client.chat(
model="-",
messages=[
ChatMessage(role="user", content="What is the best French cheese?")
]
)
print(chat_response.choices[0].message.content)
from openai import OpenAI
# init the client but point it to TGI
client = OpenAI(api_key="-", base_url="http://127.0.0.1:8080/v1")
chat_response = client.chat.completions.create(
model="-",
messages=[
{"role": "user", "content": "What is deep learning?"}
]
)
print(chat_response)
curl http://127.0.0.1:8080/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",
"messages": [
{
"role": "user",
"content": "What is deep learning?"
}
]
}' \
-H 'Content-Type: application/json'
使用生成端点
如果您想对发送到服务器的内容有更多控制,可以使用 generate
端点。在这种情况下,您负责使用正确的模板和停止 token 格式化 prompt。
- 使用 Python
- 使用 JavaScript
- 使用 cURL
# Make sure to install the huggingface_hub package before
from huggingface_hub import InferenceClient
client = InferenceClient(model="http://127.0.0.1:8080")
client.text_generation(prompt="What is Deep Learning?")
async function query() {
const response = await fetch(
'http://127.0.0.1:8080/generate',
{
method: 'POST',
headers: { 'Content-Type': 'application/json'},
body: JSON.stringify({
'inputs': 'What is Deep Learning?'
})
}
);
}
query().then((response) => {
console.log(JSON.stringify(response));
});
curl 127.0.0.1:8080/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'