Batch Inference

Prepare and upload your batch request

A batch is composed of a list of API requests. The structure of an individual request includes:

  • A unique custom_id to identify each request and to reference its result once the batch completes
  • A body object with the message information

Here is an example of how to structure your batch request:

{"custom_id": "0", "body": {"max_tokens": 100, "messages": [{"role": "user", "content": "What is the best French cheese?"}]}}
{"custom_id": "1", "body": {"max_tokens": 100, "messages": [{"role": "user", "content": "What is the best French wine?"}]}}

Save your batch requests into a .jsonl file. Once saved, you can upload your batch input file so that it is referenced correctly when kicking off the batch process:

from mistralai import Mistral
import os

api_key = os.environ["MISTRAL_API_KEY"]

client = Mistral(api_key=api_key)

batch_data = client.files.upload(
    file={
        "file_name": "test.jsonl",
        "content": open("test.jsonl", "rb")
    },
    purpose="batch"
)
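
The id on the returned object (batch_data.id) is what you pass in input_files when creating the job in the next step.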

Create a new batch job

Create a new batch job; it will be queued for processing.

  • input_files: a list of the batch input file IDs.
  • model: you can only use one model (e.g. codestral-latest) per batch request. That said, if you want to compare outputs across models, you can run multiple batch requests on the same input file.
  • endpoint: we currently support `/v1/embeddings`, `/v1/chat/completions`, `/v1/fim/completions`, `/v1/moderations`, `/v1/chat/moderations`.
  • metadata: optional custom metadata for the batch.

created_job = client.batch.jobs.create(
    input_files=[batch_data.id],
    model="mistral-small-latest",
    endpoint="/v1/chat/completions",
    metadata={"job_type": "testing"}
)

Get a batch job's details

retrieved_job = client.batch.jobs.get(job_id=created_job.id)
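
Jobs are processed asynchronously, so a common pattern is to poll this endpoint until the status leaves QUEUED and RUNNING (the full status list appears below); a minimal sketch with an arbitrary polling interval:

import time

while retrieved_job.status in ["QUEUED", "RUNNING"]:
    time.sleep(5)  # arbitrary polling interval
    retrieved_job = client.batch.jobs.get(job_id=created_job.id)

print(f"Job finished with status: {retrieved_job.status}")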

Get batch job results

output_file_stream = client.files.download(file_id=retrieved_job.output_file)

# Write and save the file
with open('batch_results.jsonl', 'wb') as f:
    f.write(output_file_stream.read())
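
The downloaded file is itself a .jsonl file with one result per line. A minimal parsing sketch, assuming each line carries the custom_id from the input file alongside its response (inspect the actual payload shape for your endpoint):

import json

results = {}
with open('batch_results.jsonl') as f:
    for line in f:
        entry = json.loads(line)
        # Index each result by the custom_id supplied in the input file.
        results[entry["custom_id"]] = entry

print(f"Parsed {len(results)} results")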

List batch jobs

You can view a list of your batch jobs and filter them by various criteria, including:

  • Status: QUEUED, RUNNING, SUCCESS, FAILED, TIMEOUT_EXCEEDED, CANCELLATION_REQUESTED and CANCELLED
  • Metadata: custom metadata key and value for the batch

list_job = client.batch.jobs.list(
    status="RUNNING",
    metadata={"job_type": "testing"}
)
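
A sketch of iterating over the returned jobs; that they are exposed under a data attribute is an assumption here, so verify it against the SDK version you are using:

for job in list_job.data:  # `data` is assumed; check your SDK's return type
    print(job.id, job.status)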

Request the cancellation of a batch job

canceled_job = client.batch.jobs.cancel(job_id=created_job.id)
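
A cancelled job typically reports CANCELLATION_REQUESTED first and CANCELLED once processing has actually stopped; you can observe the transition by polling client.batch.jobs.get as shown above.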

End-to-end example

import argparse
import json
import os
import random
import time
from io import BytesIO

from mistralai import File, Mistral


def create_client():
    """
    Create a Mistral client using the API key from environment variables.

    Returns:
        Mistral: An instance of the Mistral client.
    """
    return Mistral(api_key=os.environ["MISTRAL_API_KEY"])


def generate_random_string(start, end):
    """
    Generate a random string of variable length.

    Args:
        start (int): Minimum length of the string.
        end (int): Maximum length of the string.

    Returns:
        str: A randomly generated string.
    """
    length = random.randrange(start, end)
    return ' '.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=length))


def print_stats(batch_job):
    """
    Print the statistics of the batch job.

    Args:
        batch_job: The batch job object containing job statistics.
    """
    print(f"Total requests: {batch_job.total_requests}")
    print(f"Failed requests: {batch_job.failed_requests}")
    print(f"Successful requests: {batch_job.succeeded_requests}")
    done = batch_job.succeeded_requests + batch_job.failed_requests
    print(f"Percent done: {round(100 * done / batch_job.total_requests, 2)}")


def create_input_file(client, num_samples):
    """
    Create an input file for the batch job.

    Args:
        client (Mistral): The Mistral client instance.
        num_samples (int): Number of samples to generate.

    Returns:
        File: The uploaded input file object.
    """
    buffer = BytesIO()
    for idx in range(num_samples):
        request = {
            "custom_id": str(idx),
            "body": {
                "max_tokens": random.randint(10, 1000),
                "messages": [{"role": "user", "content": generate_random_string(100, 5000)}]
            }
        }
        buffer.write(json.dumps(request).encode("utf-8"))
        buffer.write("\n".encode("utf-8"))
    return client.files.upload(file=File(file_name="file.jsonl", content=buffer.getvalue()), purpose="batch")


def run_batch_job(client, input_file, model):
    """
    Run a batch job using the provided input file and model.

    Args:
        client (Mistral): The Mistral client instance.
        input_file (File): The input file object.
        model (str): The model to use for the batch job.

    Returns:
        BatchJob: The completed batch job object.
    """
    batch_job = client.batch.jobs.create(
        input_files=[input_file.id],
        model=model,
        endpoint="/v1/chat/completions",
        metadata={"job_type": "testing"}
    )

    while batch_job.status in ["QUEUED", "RUNNING"]:
        batch_job = client.batch.jobs.get(job_id=batch_job.id)
        print_stats(batch_job)
        time.sleep(1)

    print(f"Batch job {batch_job.id} completed with status: {batch_job.status}")
    return batch_job


def download_file(client, file_id, output_path):
    """
    Download a file from the Mistral server.

    Args:
        client (Mistral): The Mistral client instance.
        file_id (str): The ID of the file to download.
        output_path (str): The path where the file will be saved.
    """
    if file_id is not None:
        print(f"Downloading file to {output_path}")
        output_file = client.files.download(file_id=file_id)
        with open(output_path, "w") as f:
            for chunk in output_file.stream:
                f.write(chunk.decode("utf-8"))
        print(f"Downloaded file to {output_path}")


def main(num_samples, success_path, error_path, model):
    """
    Main function to run the batch job.

    Args:
        num_samples (int): Number of samples to process.
        success_path (str): Path to save successful outputs.
        error_path (str): Path to save error outputs.
        model (str): Model name to use.
    """
    client = create_client()
    input_file = create_input_file(client, num_samples)
    print(f"Created input file {input_file}")

    batch_job = run_batch_job(client, input_file, model)
    print(f"Job duration: {batch_job.completed_at - batch_job.created_at} seconds")
    download_file(client, batch_job.error_file, error_path)
    download_file(client, batch_job.output_file, success_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Mistral AI batch job")
    parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to process")
    parser.add_argument("--success_path", type=str, default="output.jsonl", help="Path to save successful outputs")
    parser.add_argument("--error_path", type=str, default="error.jsonl", help="Path to save error outputs")
    parser.add_argument("--model", type=str, default="codestral-latest", help="Model name to use")

    args = parser.parse_args()

    main(args.num_samples, args.success_path, args.error_path, args.model)
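
Assuming the script above is saved as, say, batch_example.py (the name is illustrative), it can be run as:

python batch_example.py --num_samples 50 --model mistral-small-latest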

FAQ

Is the Batch API available for all models?

Yes, the Batch API is available for all models, including user fine-tuned models.

Does the Batch API affect pricing?

The Batch API offers discounted pricing. See our pricing page for details.

Does the Batch API affect rate limits?

What is the maximum number of requests in a batch?

Currently, each workspace can have up to 1 million pending requests. This means you cannot submit a single job containing more than 1 million requests. It also means you cannot submit two jobs of 600,000 requests each at the same time: you would first need to wait until the first job has processed at least 200,000 requests, bringing its pending count down to 400,000, at which point a new job of 600,000 requests would fit within the limit.

How many batch jobs can a user create?

There is currently no maximum limit.

How long does Batch API processing take?

Processing speed may vary depending on current demand and the volume of your requests. Your batch results will only be accessible once the entire batch has finished processing.

When creating a job, users can set timeout_hours, which specifies the number of hours after which the job should expire. It defaults to 24 hours and must be lower than 7 days. A batch job will expire if processing does not finish within the specified timeout.
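
A sketch of setting this at creation time; passing timeout_hours as a keyword to jobs.create is an assumption here, so verify it against your SDK version:

created_job = client.batch.jobs.create(
    input_files=[batch_data.id],
    model="mistral-small-latest",
    endpoint="/v1/chat/completions",
    timeout_hours=48  # assumed keyword; expires the job after 48 hours
)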

Can I view batch results from my workspace?

Yes, batch requests are specific to a workspace. Within the workspace associated with your API key, you can view all the batch requests you have created and their results.

Do batch results expire?

Currently, results do not expire.

Can batch requests exceed the spend limit?

Yes, due to the high throughput and concurrent processing involved, batch requests may slightly exceed your workspace's configured spend limit.