"""Tools to generate from OpenAI prompts.
|
|
Adopted from https://github.com/zeno-ml/zeno-build/"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import random
|
|
import time
|
|
from typing import Any
|
|
|
|
import aiolimiter
|
|
import openai
|
|
from openai import AsyncOpenAI, OpenAI
|
|
|
|
base_url = os.environ.get("OPENAI_API_URL")
|
|
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
|
|
aclient = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
|
|
from tqdm.asyncio import tqdm_asyncio
|
|
|
|
|
|
def retry_with_exponential_backoff(  # type: ignore
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 3,
    errors: tuple[type[Exception], ...] = (
        openai.RateLimitError,
        openai.BadRequestError,
        openai.InternalServerError,
    ),
):
    """Retry a function with exponential backoff."""

    def wrapper(*args, **kwargs):  # type: ignore
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response, until max_retries is hit, or
        # until an unlisted exception propagates.
        while True:
            try:
                return func(*args, **kwargs)
            # Retry only on the specified errors; anything else is raised
            # to the caller unchanged.
            except errors:
                num_retries += 1
                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    )
                # Increase the delay, with random jitter when enabled
                # (the bool multiplies as 1 or 0).
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)

    return wrapper


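# A minimal sketch of applying the decorator to an arbitrary flaky callable.
# The `classify` helper and the "gpt-3.5-turbo-instruct" model name are
# illustrative assumptions, not part of this module:
#
#     @retry_with_exponential_backoff
#     def classify(prompt: str) -> str:
#         response = client.completions.create(
#             model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=8
#         )
#         return response.choices[0].text
#

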
async def _throttled_openai_completion_acreate(
    engine: str,
    prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    async with limiter:
        for _ in range(3):
            try:
                # The v1 client takes `model` (the old `engine` keyword is
                # gone); dump the response to a dict so callers can index
                # into it the same way as the fallback value below.
                response = await aclient.completions.create(
                    model=engine,
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
                return response.model_dump()
            except openai.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            except openai.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                break
        # Fall back to an empty completion in the same shape as a real one.
        return {"choices": [{"text": ""}]}


async def agenerate_from_openai_completion(
    prompts: list[str],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Completion API.

    Args:
        prompts: List of prompts.
        engine: Name of the completion model to use.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )

    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_completion_acreate(
            engine=engine,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for prompt in prompts
    ]
    responses = await tqdm_asyncio.gather(*async_responses)
    return [x["choices"][0]["text"] for x in responses]


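# Example usage, as a sketch: it assumes OPENAI_API_KEY is set and that
# "gpt-3.5-turbo-instruct" is a valid model name on the configured endpoint.
#
#     completions = asyncio.run(
#         agenerate_from_openai_completion(
#             prompts=["Translate 'hello' into French:"],
#             engine="gpt-3.5-turbo-instruct",
#             temperature=0.0,
#             max_tokens=32,
#             top_p=1.0,
#             context_length=0,
#         )
#     )
#

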
@retry_with_exponential_backoff
def generate_from_openai_completion(
    prompt: str,
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    stop_token: str | None = None,
    api_key: str | None = None,
    base_url: str | None = None,
) -> str:
    """Generate a single response from the OpenAI Completion API."""
    if "OPENAI_API_KEY" not in os.environ and api_key is None:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    # Use the module-level client unless an explicit key is provided.
    local_client = (
        OpenAI(api_key=api_key, base_url=base_url) if api_key is not None else client
    )
    response = local_client.completions.create(
        prompt=prompt,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        # `stop=[None]` is invalid, so only pass a stop list when one is set.
        stop=[stop_token] if stop_token is not None else None,
    )
    answer: str = response.choices[0].text
    return answer


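# Example usage, as a sketch (the model name is an assumption):
#
#     text = generate_from_openai_completion(
#         prompt="Q: What is 2 + 2?\nA:",
#         model="gpt-3.5-turbo-instruct",
#         temperature=0.0,
#         max_tokens=8,
#         top_p=1.0,
#         stop_token="\n",
#     )
#

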
async def _throttled_openai_chat_completion_acreate(
    model: str,
    messages: list[dict[str, str]],
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    async with limiter:
        for _ in range(3):
            try:
                response = await aclient.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
                # Dump the response to a dict so callers can index into it
                # the same way as the fallback value below.
                return response.model_dump()
            except openai.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            except asyncio.exceptions.TimeoutError:
                logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
                await asyncio.sleep(10)
            except openai.APIError as e:
                logging.warning(f"OpenAI API error: {e}")
                break
        # Fall back to an empty message in the same shape as a real one.
        return {"choices": [{"message": {"content": ""}}]}


async def agenerate_from_openai_chat_completion(
    messages_list: list[list[dict[str, str]]],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Chat Completion API.

    Args:
        messages_list: List of message lists, one per request.
        engine: Name of the chat model to use.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use.
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses.
    """
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )

    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_chat_completion_acreate(
            model=engine,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for message in messages_list
    ]
    responses = await tqdm_asyncio.gather(*async_responses)
    return [x["choices"][0]["message"]["content"] for x in responses]


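# Example usage, as a sketch: it assumes OPENAI_API_KEY is set and that
# "gpt-4o-mini" is a valid chat model name on the configured endpoint.
#
#     replies = asyncio.run(
#         agenerate_from_openai_chat_completion(
#             messages_list=[[{"role": "user", "content": "Say hi."}]],
#             engine="gpt-4o-mini",
#             temperature=0.0,
#             max_tokens=16,
#             top_p=1.0,
#             context_length=0,
#         )
#     )
#

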
@retry_with_exponential_backoff
def generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """Generate a single response from the OpenAI Chat Completion API."""
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        # `stop=[None]` is invalid, so only pass a stop list when one is set.
        stop=[stop_token] if stop_token is not None else None,
    )
    answer: str = response.choices[0].message.content
    return answer


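# Example usage, as a sketch (the model name is an assumption):
#
#     answer = generate_from_openai_chat_completion(
#         messages=[{"role": "user", "content": "Summarize this page."}],
#         model="gpt-4o-mini",
#         temperature=0.0,
#         max_tokens=64,
#         top_p=1.0,
#         context_length=0,
#     )
#

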
# Debug only: returns a canned answer without calling the API.
@retry_with_exponential_backoff
def fake_generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )

    answer = (
        "Let's think step-by-step. This page shows a list of links and buttons. "
        "There is a search box with the label 'Search query'. I will click on the "
        "search box to type the query. So the action I will perform is "
        '"click [60]".'
    )
    return answer


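# A minimal smoke test, kept as a sketch: it makes a real API call, so it
# assumes OPENAI_API_KEY is set and that "gpt-4o-mini" is a valid chat model
# name on the configured endpoint.
if __name__ == "__main__":
    print(
        generate_from_openai_chat_completion(
            messages=[{"role": "user", "content": "Reply with the word 'pong'."}],
            model="gpt-4o-mini",
            temperature=0.0,
            max_tokens=8,
            top_p=1.0,
            context_length=0,
        )
    )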