diff --git a/configs/agents/api_agents.yaml b/configs/agents/api_agents.yaml
index c54ab46..25f70c2 100644
--- a/configs/agents/api_agents.yaml
+++ b/configs/agents/api_agents.yaml
@@ -6,26 +6,8 @@ gpt-4o-2024-05-13:
             model: "gpt-4o-2024-05-13"
             max_tokens: 512
 
-gpt-3.5-turbo-0613:
-    import: "./openai-chat.yaml"
+glm-4v:
+    import: "./finetuned_agent.yaml"
     parameters:
-        name: "gpt-3.5-turbo-0613"
-        body:
-            model: "gpt-3.5-turbo-0613"
-            max_tokens: 512
-
-text-davinci-003:
-    import: "./openai-text.yaml"
-    parameters:
-        name: "text-davinci-003"
-        body:
-            model: "text-davinci-003"
-            max_tokens: 512
-
-text-davinci-002:
-    import: "./openai-text.yaml"
-    parameters:
-        name: "text-davinci-002"
-        body:
-            model: "text-davinci-002"
-            max_tokens: 512
+        name: "glm-4v"
+        url: "http://"
diff --git a/configs/agents/finetuned_agent.yaml b/configs/agents/finetuned_agent.yaml
new file mode 100644
index 0000000..7a72691
--- /dev/null
+++ b/configs/agents/finetuned_agent.yaml
@@ -0,0 +1,9 @@
+module: "src.client.agents.HTTPAgent"
+parameters:
+  prompter:
+    name: "prompt_string"
+    args:
+      suffix: "<|assistant|>\n"
+      agent_format: "<|assistant|>\n{content}\n\n"
+      text_context_limit: 9776
+  return_format: "{response[response]}"
\ No newline at end of file
diff --git a/src/client/agents/http_agent.py b/src/client/agents/http_agent.py
index 17bd35e..9f0f18a 100644
--- a/src/client/agents/http_agent.py
+++ b/src/client/agents/http_agent.py
@@ -1,7 +1,8 @@
 import contextlib
+import math
 import time
 import warnings
 
 import requests
 from urllib3.exceptions import InsecureRequestWarning
 
@@ -102,26 +103,87 @@ class Prompter:
     @staticmethod
     def prompt_string(
         prefix: str = "",
-        suffix: str = "AGENT:",
-        system_format: str = "SYSTEM: {content}\n\n###\n\n",
-        user_format: str = "USER: {content}\n\n",
-        agent_format: str = "AGENT: {content}\n\n",
+        suffix: str = "<|agent|>\n",
+        system_format: str = "<|system|>\n{content}\n\n",
+        user_format: str = "<|user|>\n{content}\n\n",
+        agent_format: str = "<|agent|>\n{content}\n\n",
         prompt_key: str = "prompt",
+        image_key: str = "image",
+        text_context_limit: int = 9776
     ):
-        # todo (YG): current it assumes the content is always a string, but it can be a list for multimodal support
-        def prompter(messages: List[Dict[str, str]]):
-            nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key
-            prompt = prefix
-            for item in messages:
+        # Returns a prompter that flattens the history into one tagged string and
+        # pairs it with exactly one image: [{prompt_key: str}, {image_key: binary file}].
+        def prompter(messages: List[Dict]):
+            nonlocal prefix, suffix, system_format, user_format, agent_format, prompt_key, image_key, text_context_limit
+
+            def text_token_estimation(text: str):
+                # Tokenizer-free estimate: one token per non-alphanumeric, non-space
+                # character, plus ceil(alphanumeric_chars / 6) tokens per word.
+                token_count = 0
+                text = text.replace("\n\n", "\n").replace("```", "`").replace("##", "#")
+                for char in text:
+                    if not char.isalnum() and char != " ":
+                        token_count += 1
+                words = text.split()
+                for word in words:
+                    char_count = 0
+                    for char in word:
+                        if char.isalnum():
+                            char_count += 1
+                    token_count += math.ceil(char_count / 6)
+                return token_count
+
+            def item_to_prompt(item: Dict):
+                # Render one message; returns (prompt_text, image_paths).
+                prompt, images = "", []
                 if item["role"] == "system":
                     prompt += system_format.format(content=item["content"])
                 elif item["role"] == "user":
-                    prompt += user_format.format(content=item["content"])
+                    if isinstance(item["content"], str):
+                        prompt += user_format.format(content=item["content"])
+                    elif isinstance(item["content"], list):
+                        text_str = ""
+                        for content in item["content"]:
+                            if content["type"] == "text":
+                                text_str += content["text"]
+                            else:
+                                images.append(content["image_url"]["url"].split("file://")[-1])
+                        prompt += user_format.format(content=text_str)
                 else:
                     prompt += agent_format.format(content=item["content"])
+                return prompt, images
+
+            # The first five messages are always kept in full.
+            prompt = prefix
+            images = []
+            for item in messages[:5]:
+                item_prompt, item_images = item_to_prompt(item)
+                prompt += item_prompt
+                images += item_images
+            if len(messages) > 5:
+                # Keep the newest message, then prepend older message pairs (assumed user
+                # then agent), newest first, until the estimated budget would be exceeded.
+                prompt_tail, images_tail = item_to_prompt(messages[-1])
+                for index in range(len(messages)-2, 4, -2):
+                    agent_item = messages[index]
+                    agent_prompt, agent_images = item_to_prompt(agent_item)
+                    user_item = messages[index-1]
+                    user_prompt, user_images = item_to_prompt(user_item)
+                    if text_token_estimation(prompt + user_prompt + agent_prompt + prompt_tail) > text_context_limit:
+                        prompt_tail = "\n\n** Earlier trajectory has been truncated **\n\n" + prompt_tail
+                        break
+                    prompt_tail = user_prompt + agent_prompt + prompt_tail
+                    images_tail = user_images + agent_images + images_tail
+                prompt += prompt_tail
+                images += images_tail
             prompt += suffix
-            print(prompt)
-            return {prompt_key: prompt}
+            if len(images) != 1:
+                raise Exception("Only one image is supported")
+            # The image is handed back as an open binary file; the caller must close it.
+            return [
+                {prompt_key: prompt},
+                {image_key: open(images[0], "rb")}
+            ]
 
         return prompter
 
@@ -175,6 +237,7 @@ class HTTPAgent(AgentClient):
         self,
         url,
         proxies=None,
+        data=None,
         body=None,
         headers=None,
         return_format="{response}",
@@ -185,9 +248,12 @@ class HTTPAgent(AgentClient):
         self.url = url
         self.proxies = proxies or {}
         self.headers = headers or {}
+        self.data = data or {}
         self.body = body or {}
         self.return_format = return_format
         self.prompter = Prompter.get_prompter(prompter)
+        # Guard against a missing prompter config (None) so .get() cannot raise AttributeError.
+        self.prompter_type = (prompter or {}).get("name", "role_content_dict")
         if not self.url:
             raise Exception("Please set 'url' parameter")
 
@@ -195,15 +261,35 @@ class HTTPAgent(AgentClient):
         return self.prompter(history)
 
     def inference(self, history: List[dict]) -> str:
-        history = replace_image_url(history, keep_path=False, throw_details=False)
+        # Only "prompt_string" yields the [form fields, files] pair for a multipart
+        # upload; any other prompter keeps the original JSON-body request path.
+        use_json_body = self.prompter_type != "prompt_string"
+        if use_json_body:
+            history = replace_image_url(history, keep_path=False, throw_details=False)
+        else:
+            history = replace_image_url(history, keep_path=True, throw_details=True)
         for _ in range(5):
             try:
-                body = self.body.copy()
-                body.update(self._handle_history(history))
-                with no_ssl_verification():
-                    resp = requests.post(
-                        self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=120
-                    )
+                if use_json_body:
+                    body = self.body.copy()
+                    body.update(self._handle_history(history))
+                    with no_ssl_verification():
+                        resp = requests.post(
+                            self.url, json=body, headers=self.headers, proxies=self.proxies, timeout=180
+                        )
+                else:
+                    messages = self._handle_history(history)
+                    data = self.data.copy()
+                    data.update(messages[0])
+                    try:
+                        with no_ssl_verification():
+                            resp = requests.post(
+                                self.url, data=data, files=messages[1], headers=self.headers, proxies=self.proxies, timeout=180
+                            )
+                    finally:
+                        # The prompter opens the image on every attempt; always close it.
+                        for handle in messages[1].values():
+                            handle.close()
                 # print(resp.status_code, resp.text)
                 if resp.status_code != 200:
                     # print(resp.text)