"""Tools to generate from OpenAI prompts.

Adopted from https://github.com/zeno-ml/zeno-build/
"""

import asyncio
import functools
import logging
import os
import random
import time
from typing import Any

import aiolimiter
import openai
from openai import AsyncOpenAI, OpenAI
from tqdm.asyncio import tqdm_asyncio

# Optional custom endpoint; when OPENAI_API_URL is unset this is None and the
# SDK picks the endpoint itself.
base_url = os.environ.get("OPENAI_API_URL")
# NOTE(review): both clients are built at import time and index
# os.environ["OPENAI_API_KEY"] directly, so importing this module without the
# key set raises KeyError before any per-call _check_api_key() guard can run.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
aclient = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)


def _check_api_key() -> None:
    """Raise ValueError if OPENAI_API_KEY is not present in the environment."""
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )


def _response_to_dict(response: Any) -> dict[str, Any]:
    """Return an OpenAI SDK response as a plain dict.

    The v1 clients return model objects that are read via attribute access
    (``response.choices[0]...``) rather than dict indexing. The async helpers
    below convert them so that successful responses and the empty fallback
    share one dict shape that callers can index uniformly.
    """
    if isinstance(response, dict):
        return response
    return response.model_dump()


def retry_with_exponential_backoff(  # type: ignore
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 3,
    errors: tuple[type[BaseException], ...] = (
        openai.RateLimitError,
        openai.BadRequestError,
        openai.InternalServerError,
    ),
):
    """Retry a function with exponential backoff.

    Args:
        func: Synchronous callable to wrap.
        initial_delay: Base delay in seconds. It is multiplied by the backoff
            factor *before* the first sleep, so the first wait is already
            ``initial_delay * exponential_base * (1 + jitter * U[0, 1))``.
        exponential_base: Multiplicative growth factor applied on each retry.
        jitter: When True, each step is scaled by an extra random factor in
            ``[1, 2)``.
        max_retries: Number of retries allowed after the first attempt.
        errors: Exception types that trigger a retry; any other exception
            propagates immediately. NOTE(review): BadRequestError (HTTP 400)
            is usually deterministic, so retrying it mostly delays failure.

    Returns:
        A wrapper that preserves ``func``'s name and docstring.

    Raises:
        Exception: When more than ``max_retries`` retries were needed; the
            last retried error is attached as ``__cause__``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # type: ignore
        num_retries = 0
        delay = initial_delay
        # Loop until success, a non-retryable error, or max_retries is hit.
        while True:
            try:
                return func(*args, **kwargs)
            except errors as e:
                num_retries += 1
                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    ) from e
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)

    return wrapper


async def _throttled_openai_completion_acreate(
    engine: str,
    prompt: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    """Call the Completion API once under ``limiter``, with up to 3 attempts.

    Rate-limit and timeout errors sleep 10 s and retry; any other API error
    stops immediately. On failure a dict with an empty ``text`` is returned,
    matching the shape read by agenerate_from_openai_completion.
    """
    async with limiter:
        for _ in range(3):
            try:
                response = await aclient.completions.create(
                    model=engine,
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
                return _response_to_dict(response)
            except openai.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            # APITimeoutError subclasses APIError, so it must be caught first.
            except (asyncio.TimeoutError, openai.APITimeoutError):
                logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
                await asyncio.sleep(10)
            except openai.APIError as e:
                logging.warning("OpenAI API error: %s", e)
                break
        return {"choices": [{"text": ""}]}


async def agenerate_from_openai_completion(
    prompts: list[str],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Completion API.

    Args:
        prompts: list of prompts
        engine: Model name passed to the API as ``model``.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use (accepted for interface
            compatibility; not used here).
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses, in the same order as ``prompts``;
        failed requests yield "".
    """
    _check_api_key()

    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_completion_acreate(
            engine=engine,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for prompt in prompts
    ]
    responses = await tqdm_asyncio.gather(*async_responses)
    return [x["choices"][0]["text"] for x in responses]


@retry_with_exponential_backoff
def generate_from_openai_completion(
    prompt: str,
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """Generate a single completion synchronously (retried with backoff).

    Args:
        prompt: Prompt text.
        engine: Model name passed to the API as ``model``.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Accepted for interface compatibility; not used here.
        stop_token: Optional stop sequence; omitted from the request when None.

    Returns:
        The text of the first choice.

    Raises:
        ValueError: If OPENAI_API_KEY is not set.
    """
    _check_api_key()

    # Only send `stop` when given; `[None]` is not a valid stop sequence.
    stop_kwargs: dict[str, Any] = (
        {} if stop_token is None else {"stop": [stop_token]}
    )
    response = client.completions.create(
        prompt=prompt,
        model=engine,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        **stop_kwargs,
    )
    answer: str = response.choices[0].text
    return answer


async def _throttled_openai_chat_completion_acreate(
    model: str,
    messages: list[dict[str, str]],
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    """Call the Chat Completion API once under ``limiter``, with up to 3 attempts.

    Rate-limit and timeout errors sleep 10 s and retry; any other API error
    stops immediately. On failure a dict with empty message content is
    returned, matching the shape read by agenerate_from_openai_chat_completion.
    """
    async with limiter:
        for _ in range(3):
            try:
                response = await aclient.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
                return _response_to_dict(response)
            except openai.RateLimitError:
                logging.warning(
                    "OpenAI API rate limit exceeded. Sleeping for 10 seconds."
                )
                await asyncio.sleep(10)
            # APITimeoutError subclasses APIError, so it must be caught first.
            except (asyncio.TimeoutError, openai.APITimeoutError):
                logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
                await asyncio.sleep(10)
            except openai.APIError as e:
                logging.warning("OpenAI API error: %s", e)
                break
        return {"choices": [{"message": {"content": ""}}]}


async def agenerate_from_openai_chat_completion(
    messages_list: list[list[dict[str, str]]],
    engine: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    requests_per_minute: int = 300,
) -> list[str]:
    """Generate from OpenAI Chat Completion API.

    Args:
        messages_list: list of message list
        engine: Model name passed to the API as ``model``.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Length of context to use (accepted for interface
            compatibility; not used here).
        requests_per_minute: Number of requests per minute to allow.

    Returns:
        List of generated responses, in the same order as ``messages_list``;
        failed requests yield "".
    """
    _check_api_key()

    limiter = aiolimiter.AsyncLimiter(requests_per_minute)
    async_responses = [
        _throttled_openai_chat_completion_acreate(
            model=engine,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
        )
        for message in messages_list
    ]
    responses = await tqdm_asyncio.gather(*async_responses)
    return [x["choices"][0]["message"]["content"] for x in responses]


@retry_with_exponential_backoff
def generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """Generate a single chat completion synchronously (retried with backoff).

    Args:
        messages: Chat messages to send.
        model: Model name.
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        context_length: Accepted for interface compatibility; not used here.
        stop_token: Optional stop sequence; omitted from the request when None.

    Returns:
        The message content of the first choice.

    Raises:
        ValueError: If OPENAI_API_KEY is not set.
    """
    _check_api_key()

    # Previously stop_token was silently ignored here; honour it when given.
    stop_kwargs: dict[str, Any] = (
        {} if stop_token is None else {"stop": [stop_token]}
    )
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        **stop_kwargs,
    )
    answer: str = response.choices[0].message.content
    return answer


@retry_with_exponential_backoff
# debug only
def fake_generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """Debug stand-in for generate_from_openai_chat_completion.

    Makes no API call; checks the API key and returns a fixed canned answer.

    Raises:
        ValueError: If OPENAI_API_KEY is not set.
    """
    _check_api_key()
    answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"."
    return answer