webrl/VAB-WebArena-Lite/new/openai_utils.py
2024-10-20 00:10:34 +08:00

287 lines
8.4 KiB
Python

"""Tools to generate from OpenAI prompts.
Adopted from https://github.com/zeno-ml/zeno-build/"""
import asyncio
import functools
import logging
import os
import random
import time
from typing import Any

import aiolimiter
import openai
from openai import AsyncOpenAI, OpenAI
base_url = os.environ.get("OPENAI_API_URL")
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
aclient = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
from tqdm.asyncio import tqdm_asyncio
def retry_with_exponential_backoff( # type: ignore
func,
initial_delay: float = 1,
exponential_base: float = 2,
jitter: bool = True,
max_retries: int = 3,
errors: tuple[Any] = (
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
),
):
"""Retry a function with exponential backoff."""
def wrapper(*args, **kwargs): # type: ignore
# Initialize variables
num_retries = 0
delay = initial_delay
# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
return func(*args, **kwargs)
# Retry on specified errors
except errors as e:
# Increment retries
num_retries += 1
# Check if max retries has been reached
if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
# Sleep for the delay
time.sleep(delay)
# Raise exceptions for any errors not specified
except Exception as e:
raise e
return wrapper
async def _throttled_openai_completion_acreate(
engine: str,
prompt: str,
temperature: float,
max_tokens: int,
top_p: float,
limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
async with limiter:
for _ in range(3):
try:
return await aclient.completions.create(
engine=engine,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
except openai.RateLimitError:
logging.warning(
"OpenAI API rate limit exceeded. Sleeping for 10 seconds."
)
await asyncio.sleep(10)
except openai.APIError as e:
logging.warning(f"OpenAI API error: {e}")
break
return {"choices": [{"message": {"content": ""}}]}
async def agenerate_from_openai_completion(
prompts: list[str],
engine: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
requests_per_minute: int = 300,
) -> list[str]:
"""Generate from OpenAI Completion API.
Args:
prompts: list of prompts
temperature: Temperature to use.
max_tokens: Maximum number of tokens to generate.
top_p: Top p to use.
context_length: Length of context to use.
requests_per_minute: Number of requests per minute to allow.
Returns:
List of generated responses.
"""
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
limiter = aiolimiter.AsyncLimiter(requests_per_minute)
async_responses = [
_throttled_openai_completion_acreate(
engine=engine,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
limiter=limiter,
)
for prompt in prompts
]
responses = await tqdm_asyncio.gather(*async_responses)
return [x["choices"][0]["text"] for x in responses]
@retry_with_exponential_backoff
def generate_from_openai_completion(
prompt: str,
engine: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
stop_token: str | None = None,
) -> str:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
response = client.completions.create(
prompt=prompt,
engine=engine,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stop=[stop_token],
)
answer: str = response["choices"][0]["text"]
return answer
async def _throttled_openai_chat_completion_acreate(
model: str,
messages: list[dict[str, str]],
temperature: float,
max_tokens: int,
top_p: float,
limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
async with limiter:
for _ in range(3):
try:
return await aclient.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
except openai.RateLimitError:
logging.warning(
"OpenAI API rate limit exceeded. Sleeping for 10 seconds."
)
await asyncio.sleep(10)
except asyncio.exceptions.TimeoutError:
logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
await asyncio.sleep(10)
except openai.APIError as e:
logging.warning(f"OpenAI API error: {e}")
break
return {"choices": [{"message": {"content": ""}}]}
async def agenerate_from_openai_chat_completion(
messages_list: list[list[dict[str, str]]],
engine: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
requests_per_minute: int = 300,
) -> list[str]:
"""Generate from OpenAI Chat Completion API.
Args:
messages_list: list of message list
temperature: Temperature to use.
max_tokens: Maximum number of tokens to generate.
top_p: Top p to use.
context_length: Length of context to use.
requests_per_minute: Number of requests per minute to allow.
Returns:
List of generated responses.
"""
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
limiter = aiolimiter.AsyncLimiter(requests_per_minute)
async_responses = [
_throttled_openai_chat_completion_acreate(
model=engine,
messages=message,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
limiter=limiter,
)
for message in messages_list
]
responses = await tqdm_asyncio.gather(*async_responses)
return [x["choices"][0]["message"]["content"] for x in responses]
@retry_with_exponential_backoff
def generate_from_openai_chat_completion(
messages: list[dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
stop_token: str | None = None,
) -> str:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
answer: str = response.choices[0].message.content
return answer
@retry_with_exponential_backoff
# debug only
def fake_generate_from_openai_chat_completion(
messages: list[dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
stop_token: str | None = None,
) -> str:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"."
return answer