add openai LMMs evaluation

This commit is contained in:
mistyreed63849 2024-08-28 07:40:38 +08:00
parent 7e334299f6
commit 8cac24e27d
15 changed files with 134 additions and 37 deletions

View File

@ -1,6 +1,6 @@
module: src.client.agents.HTTPAgent module: src.client.agents.HTTPAgent
parameters: parameters:
url: https://one-api.glm.ai/v1/chat/completions # https://api.openai.com/v1/chat/completions url: https://api.openai.com/v1/chat/completions
headers: headers:
Content-Type: application/json Content-Type: application/json
Authorization: Bearer Authorization: Bearer

View File

@ -3,22 +3,20 @@ import: definition.yaml
concurrency: concurrency:
task: task:
# css-std: 5 # css-std: 5
omnigibson-std: 6 # omnigibson-std: 4
# minecraft-std: 6 minecraft-std: 4
agent: agent:
# gpt-3.5-turbo-0613: 1 # use this gpt-4o-2024-05-13: 4
gpt-4o-2024-05-13: 6
assignments: # List[Assignment] | Assignment assignments: # List[Assignment] | Assignment
- agent: # "task": List[str] | str , "agent": List[str] | str - agent: # "task": List[str] | str , "agent": List[str] | str
# - gpt-3.5-turbo-0613
- gpt-4o-2024-05-13 - gpt-4o-2024-05-13
task: task:
# - css-std # - css-std
- omnigibson-std # - omnigibson-std
# - minecraft-std - minecraft-std
# output: "outputs/{TIMESTAMP}" # output: "outputs/{TIMESTAMP}"
# output: "outputs/css_test" # output: "outputs/css_test"
output: "outputs/omnigibson" # output: "outputs/omnigibson"
# output: "outputs/minecraft" output: "outputs/minecraft"

View File

@ -3,5 +3,5 @@ definition:
start: start:
# css-std: 1 # css-std: 1
omnigibson-std: 6 # omnigibson-std: 4
# minecraft-std: 6 minecraft-std: 4

View File

@ -31,3 +31,4 @@ minecraft-std:
"7": 2 "7": 2
data_dir: "data/minecraft" data_dir: "data/minecraft"
output_dir: "outputs/minecraft" output_dir: "outputs/minecraft"
docker_image: "tianjiezhang/vab_minecraft:latest"

View File

@ -23,3 +23,4 @@ omnigibson-std:
"7": 1 "7": 1
data_dir: "data/omnigibson" data_dir: "data/omnigibson"
output_dir: "outputs/omnigibson" output_dir: "outputs/omnigibson"
docker_image: "tianjiezhang/vab_omnigibson:latest"

View File

@ -0,0 +1,11 @@
# Download the pretrained weights used by the Minecraft task.
# All paths are relative to the current working directory; requires wget and curl.

# One sub-directory per group of weights.
mkdir -p data/minecraft/mineclip
mkdir -p data/minecraft/steve1
mkdir -p data/minecraft/vpt

# 2x.model goes under vpt/ (wget -P keeps the remote file name).
wget https://openaipublic.blob.core.windows.net/minecraft-rl/models/2x.model -P data/minecraft/vpt

# The remaining files are hosted on Google Drive; -L follows the redirects.
# NOTE(review): Google Drive may answer with an HTML confirmation page instead of
# the file for large downloads -- check the sizes of the downloaded files.
curl -L -o data/minecraft/mineclip/attn.pth "https://drive.google.com/uc?export=download&id=1uaZM1ZLBz2dZWcn85rZmjP7LV6Sg5PZW"
curl -L -o data/minecraft/steve1/steve1.weights "https://drive.google.com/uc?id=1E3fd_-H1rRZqMkUKHfiMhx-ppLLehQPI"
# Store the prior next to steve1.weights: the previous target, data/weights/steve1/,
# is never created by this script, so curl could not write the file there.
curl -L -o data/minecraft/steve1/steve1_prior.pt "https://drive.google.com/uc?id=1OdX5wiybK8jALVfP5_dEo0CWm9BQbDES"

View File

@ -0,0 +1,72 @@
import os
import subprocess
def print_user_agreement():
    """Write the BEHAVIOR Data Bundle end-user license agreement to stdout.

    The text is emitted in a single ``print`` call, so it is followed by one
    extra trailing newline.
    """
    # One fragment per paragraph / numbered clause of the agreement.
    agreement_parts = (
        'You are downloading dataset from Behavior-1K and OmniGibson. Here is the user agreement of OmniGibson 0.2.1\n',
        '\n\nBEHAVIOR DATA BUNDLE END USER LICENSE AGREEMENT\n',
        'Last revision: December 8, 2022\n',
        'This License Agreement is for the BEHAVIOR Data Bundle (“Data”). It works with OmniGibson (“Software”) which is a software stack licensed under the MIT License, provided in this repository: https://github.com/StanfordVL/OmniGibson. The license agreements for OmniGibson and the Data are independent. This BEHAVIOR Data Bundle contains artwork and images (“Third Party Content”) from third parties with restrictions on redistribution. It requires measures to protect the Third Party Content which we have taken such as encryption and the inclusion of restrictions on any reverse engineering and use. Recipient is granted the right to use the Data under the following terms and conditions of this License Agreement (“Agreement”):\n\n',
        '1. Use of the Data is permitted after responding "Yes" to this agreement. A decryption key will be installed automatically.\n',
        '2. Data may only be used for non-commercial academic research. You may not use a Data for any other purpose.\n',
        '3. The Data has been encrypted. You are strictly prohibited from extracting any Data from OmniGibson or reverse engineering.\n',
        '4. You may only use the Data within OmniGibson.\n',
        '5. You may not redistribute the key or any other Data or elements in whole or part.\n',
        '6. THE DATA AND SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE DATA OR SOFTWARE OR THE USE OR OTHER DEALINGS IN THE DATA OR SOFTWARE.\n\n',
    )
    agreement_text = ''.join(agreement_parts)
    print(agreement_text)
def print_omniverse_agreement():
    """Write the NVIDIA Omniverse EULA notice, with a link to its terms, to stdout."""
    notice_fragments = [
        "The NVIDIA Omniverse License Agreement (EULA) must be accepted before\n",
        "Omniverse Kit can start. The license terms for this product can be viewed at\n",
        "https://docs.omniverse.nvidia.com/app_isaacsim/common/NVIDIA_Omniverse_License_Agreement.html\n",
    ]
    for fragment in notice_fragments:
        # Each fragment already carries its own line break.
        print(fragment, end="")
    # A single print() of the whole text would append one more newline; emit it here.
    print()
def main():
    """Fetch the data and assets the OmniGibson task expects under ``data/omnigibson``.

    Every step is skipped when its target path already exists, so the function
    only performs the steps that are still missing. The BEHAVIOR dataset is
    downloaded only after the user answers ``y`` to the license prompt.
    External tools invoked: ``git``, ``wget``, ``gzip`` and ``tar``.
    """
    # Only this function needs these modules, so they are imported locally.
    import glob
    import shutil

    datasets_dir = "data/omnigibson/datasets"

    os.makedirs("data/omnigibson", exist_ok=True)

    # Task data repository, cloned directly into data/omnigibson.
    if not os.path.exists("data/omnigibson/activity_definitions"):
        subprocess.run(["git", "clone", "https://www.modelscope.cn/datasets/VisualAgentBench/VAB-OmniGibson-Data.git", "data/omnigibson"])

    # Word vectors: download the gzip archive, then unpack it in place.
    if not os.path.exists("data/omnigibson/GoogleNews-vectors-negative300.bin"):
        subprocess.run(["wget", "https://huggingface.co/LoganKilpatrick/GoogleNews-vectors-negative300/resolve/main/GoogleNews-vectors-negative300.bin.gz", "-O", "data/omnigibson/GoogleNews-vectors-negative300.bin.gz"])
        subprocess.run(["gzip", "-d", "data/omnigibson/GoogleNews-vectors-negative300.bin.gz"])

    os.makedirs("data/omnigibson/isaac-sim", exist_ok=True)
    os.makedirs(datasets_dir, exist_ok=True)

    if not os.path.exists(f"{datasets_dir}/omnigibson.key"):
        # The key's download URL is produced by the obfuscated expression below,
        # which is kept exactly as it was; ``___`` holds the resulting string.
        _ = ((()==())+(()==()));__=(((_<<_)<<_)*_);___=('c%'[::(([]!=[])-(()==()))])*(((_<<_)<<_)+(((_<<_)*_)+((_<<_)+(_+(()==())))))%((__+(((_<<_)<<_)+(_<<_))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_*_)))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_*_)))),(__+(((_<<_)<<_)+((_<<_)*_))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_+(()==()))))),(((_<<_)<<_)+(((_<<_)*_)+((_<<_)+_))),(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==()))))),(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_+(()==()))))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_*_)))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+(((_<<_)*_)+_))),(__+(((_<<_)<<_)+(()==()))),(__+(((_<<_)<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_*_)+(()==())))),(((_<<_)<<_)+((_<<_)+((_*_)+_))),(__+(((_<<_)<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+(_*_)))),(__+(((_<<_)<<_)+((_*_)+(()==())))),(__+(((_<<_)<<_)+(()==()))),(__+(((_<<_)<<_)+((_<<_)*_))),(__+(((_<<_)<<_)+((_<<_)+(()==())))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_+(()==()))))),(((_<<_)<<_)+((_<<_)+((_*_)+_))),(__+(((_<<_)<<_)+(_+(()==())))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(()==()))))),(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+(()==())))),(__+(((_<<_)<<_)+_)),(__+(((_<<_)<<_)+(((_<<_)*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+_)))),(__+(((_<<_)*_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_+(()==()))))),(__+(((_<<_)<<_)+(_+(()==())))),(__+(((_<<_)<<_)+((_*_)+(()==())))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+_)))),(__+(((_<<_)<<_)+((_*_)+(()==())))),(__+(((_<<_)<<_)+(((_<<_)*_)+(_+(()==()))))),(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+_)))),(__+(((_<<_)<<_)+((_<<_)+(()==())))),(__+(((_<<_)<<_)+((_*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+(()==())))),(__+(((_<<_)<<_)+_)),(__+(((_<<_)<<_)+(((_<<_)*_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+(_+(()==())))))),(__+(((_<<_)<<_)+((_<<_)+((_*_)+_)))),(((_<<_)<<_)+((_<<_)+((_*_)+_))),(__+(((_<<_)<<_)+((_<<_)+(_+(()==()))))),(__+(((_<<_)<<_)+((_*_)+(()==())))),(__+(((_<<_)<<_)+(((_<<_)*_)+((_<<_)+(()==()))))))
        path = ___
        subprocess.run(["wget", path, "-O", f"{datasets_dir}/omnigibson.key"])

    if not os.path.exists(f"{datasets_dir}/og_dataset"):
        print("\n")
        print_user_agreement()
        # Keep asking until the user explicitly accepts the license terms.
        while (
            input(
                "Do you agree to the above terms for using OmniGibson dataset? [y/n]"
            )
            != "y"
        ):
            print("You need to agree to the terms for using OmniGibson dataset.")
        subprocess.run(["wget", "https://storage.googleapis.com/gibson_scenes/og_dataset.tar.gz", "-O", f"{datasets_dir}/og_dataset.tar.gz"])
        # Extract from the directory the archive was downloaded to ("datasets";
        # the previous "dataset" path did not exist, so tar always failed).
        subprocess.run(["tar", "-xvf", f"{datasets_dir}/og_dataset.tar.gz", "-C", datasets_dir])

    if not os.path.exists(f"{datasets_dir}/assets"):
        subprocess.run(["wget", "https://storage.googleapis.com/gibson_scenes/og_assets.tar.gz", "-O", f"{datasets_dir}/og_assets.tar.gz"])
        subprocess.run(["tar", "-xvf", f"{datasets_dir}/og_assets.tar.gz", "-C", datasets_dir])

    # When the marker file below is present in og_dataset/scenes, swap the scenes
    # for the ones under data/omnigibson/vab_omnigibson_scenes. The wildcards are
    # expanded with glob because subprocess.run() with an argument list does not
    # go through a shell, so a literal "*" was previously handed to rm and mv.
    scenes_dir = f"{datasets_dir}/og_dataset/scenes"
    marker = f"{scenes_dir}/Beechwood_0_int/json/Beechwood_0_int_best.json"
    replacement_scenes = glob.glob("data/omnigibson/vab_omnigibson_scenes/*")
    # Never wipe the existing scenes unless there is something to replace them with.
    if os.path.exists(marker) and replacement_scenes:
        for entry in glob.glob(f"{scenes_dir}/*"):
            if os.path.isdir(entry) and not os.path.islink(entry):
                shutil.rmtree(entry)
            else:
                os.remove(entry)
        for entry in replacement_scenes:
            shutil.move(entry, scenes_dir)

    print_omniverse_agreement()
# Run the download steps only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

View File

@ -196,7 +196,7 @@ class HTTPAgent(AgentClient):
def inference(self, history: List[dict]) -> str: def inference(self, history: List[dict]) -> str:
history = replace_image_url(history, keep_path=False, throw_details=False) history = replace_image_url(history, keep_path=False, throw_details=False)
for _ in range(3): for _ in range(5):
try: try:
body = self.body.copy() body = self.body.copy()
body.update(self._handle_history(history)) body.update(self._handle_history(history))

View File

@ -15,6 +15,7 @@ class Container():
data_dir = "data/minecraft" data_dir = "data/minecraft"
vab_source_dir = "" vab_source_dir = ""
max_round = 100 max_round = 100
docker_image = ""
def use_port(self): def use_port(self):
time.sleep(random.random()) time.sleep(random.random())
@ -83,7 +84,7 @@ class Container():
"PYTHONPYCACHEPREFIX": "/tmp" "PYTHONPYCACHEPREFIX": "/tmp"
} }
self.container = self.client.containers.run( self.container = self.client.containers.run(
"vab-minecraft:latest", self.docker_image,
f"xvfb-run -a python main.py --task_name {task} --port {self.port} --max_round {self.max_round}", f"xvfb-run -a python main.py --task_name {task} --port {self.port} --max_round {self.max_round}",
environment=environment, environment=environment,
volumes=volumes, volumes=volumes,
@ -101,6 +102,9 @@ class Container():
async def execute(self, session): async def execute(self, session):
while True: while True:
self.container.reload()
if self.container.status == "exited":
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR)
try: try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(("localhost", self.port)) s.connect(("localhost", self.port))
@ -207,7 +211,8 @@ class Container():
def close(self): def close(self):
try: try:
time.sleep(12) time.sleep(12)
self.container.stop(timeout=36) if self.container.status != "exited":
self.container.stop(timeout=24)
except: except:
pass pass

View File

@ -8,7 +8,7 @@ import json
import asyncio import asyncio
class Minecraft(Task): class Minecraft(Task):
def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, **configs): def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, docker_image, **configs):
super().__init__(**configs) super().__init__(**configs)
self.vab_source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vab_minecraft_src") self.vab_source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vab_minecraft_src")
Container.available_devices = available_devices Container.available_devices = available_devices
@ -29,6 +29,7 @@ class Minecraft(Task):
os.makedirs(Container.output_dir) os.makedirs(Container.output_dir)
Container.max_round = max_round Container.max_round = max_round
Container.vab_source_dir = self.vab_source_path Container.vab_source_dir = self.vab_source_path
Container.docker_image = docker_image
with open(f"{self.vab_source_path}/jarvis/assets/tasks/formated_tasks_v1.3.json", "r") as f: with open(f"{self.vab_source_path}/jarvis/assets/tasks/formated_tasks_v1.3.json", "r") as f:
tasks = json.load(f) tasks = json.load(f)
self.tasks = [t["task"] for t in tasks] self.tasks = [t["task"] for t in tasks]
@ -40,14 +41,14 @@ class Minecraft(Task):
try: try:
container = Container(self.tasks[index]) container = Container(self.tasks[index])
session.clear() session.clear()
session.inject({"role": "user", "content": SYSTEM_PROMPT}) session.inject({"role": "system", "content": SYSTEM_PROMPT})
while True: while True:
result = await container.execute(session) result = await container.execute(session)
if result.status != SampleStatus.RUNNING: if result.status != SampleStatus.RUNNING:
return result return result
except Exception as e: except Exception as e:
print(e) print(e)
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR) return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR, result={"error": e})
finally: finally:
try: try:
container.close() container.close()
@ -58,28 +59,29 @@ class Minecraft(Task):
average_reward = 0 average_reward = 0
success_count = 0 success_count = 0
for result in results: for result in results:
if result.result: if isinstance(result.result, Dict) and "success" in result.result:
if result.result["success"]: if result.result["success"]:
success_count += 1 success_count += 1
average_reward += 1 average_reward += 1
else: else:
initial_reward = result.result["initial_reward"]
final_reward = result.result["final_reward"] final_reward = result.result["final_reward"]
reward = (final_reward - initial_reward) / (1.0 - initial_reward)
average_reward += max(0, final_reward) average_reward += max(0, final_reward)
return { return {
"total_count": len(results),
"success_count": success_count,
"success_rate": success_count / len(results), "success_rate": success_count / len(results),
"average_reward": average_reward / len(results) "average_reward": average_reward / len(results)
} }
async def main(): async def main():
available_ports = [12000, 12001, 12002] available_ports = [11000, 11001, 11002]
available_devices = {"0":6, "1":6, "2":6} available_devices = {"9":6, "1":6, "2":6}
data_dir = "data/minecraft" data_dir = "data/minecraft"
output_dir = "outputs/minecraft" output_dir = "outputs/minecraft"
max_round = 100 max_round = 100
task = Minecraft(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, name="Minecraft-std") docker_image = "tianjiezhang/vab_minecraft:latest"
task = Minecraft(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, docker_image=docker_image, name="Minecraft-std")
print(Container.available_devices) print(Container.available_devices)
print(Container.available_ports) print(Container.available_ports)
session = Session() session = Session()

View File

@ -2,7 +2,6 @@ from __future__ import annotations
import os import os
from functools import lru_cache from functools import lru_cache
from jarvis.steveI.path import OPENAI_CLIP_PATH
# disable HuggingFace warning # disable HuggingFace warning
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@ -17,7 +16,7 @@ def get_model_full_name(model_name):
elif model_name in ["distilbert", "distilbert-base-uncased"]: elif model_name in ["distilbert", "distilbert-base-uncased"]:
model_full_name = "distilbert-base-uncased" model_full_name = "distilbert-base-uncased"
elif model_name in ["clip", "openai/clip-vit-base-patch16"]: elif model_name in ["clip", "openai/clip-vit-base-patch16"]:
model_full_name = OPENAI_CLIP_PATH model_full_name = "openai/clip-vit-base-patch16"
else: else:
model_full_name = model_name model_full_name = model_name
return model_full_name return model_full_name

View File

@ -4,4 +4,3 @@ VPT_MODEL_PATH = Path(__file__).parent / "weights" / "vpt" / "2x.model"
VPT_WEIGHT_PATH = Path(__file__).parent / "weights" / "steve1" / "steve1.weights" VPT_WEIGHT_PATH = Path(__file__).parent / "weights" / "steve1" / "steve1.weights"
PRIOR_WEIGHT_PATH = Path(__file__).parent / "weights" / "steve1" / "steve1_prior.pt" PRIOR_WEIGHT_PATH = Path(__file__).parent / "weights" / "steve1" / "steve1_prior.pt"
MINECLIP_WEIGHT_PATH = Path(__file__).parent / "weights" / "mineclip" / "attn.pth" MINECLIP_WEIGHT_PATH = Path(__file__).parent / "weights" / "mineclip" / "attn.pth"
OPENAI_CLIP_PATH = Path(__file__).parent / "weights" / "clip-vit-base-patch16"

View File

@ -23,6 +23,7 @@ class Container():
vab_source_dir = "" vab_source_dir = ""
modified_omnigibson_src = "" modified_omnigibson_src = ""
max_round = 100 max_round = 100
docker_image = ""
def use_port(self): def use_port(self):
time.sleep(random.random()) time.sleep(random.random())
@ -146,7 +147,7 @@ class Container():
"OMNIGIBSON_HEADLESS": "1" "OMNIGIBSON_HEADLESS": "1"
} }
self.container = self.client.containers.run( self.container = self.client.containers.run(
"vab-omnigibson:latest", self.docker_image,
f"python main.py --task {task[0]} --scene {task[1]} --port {self.port} --max_round {self.max_round}", f"python main.py --task {task[0]} --scene {task[1]} --port {self.port} --max_round {self.max_round}",
environment=environment, environment=environment,
volumes=volumes, volumes=volumes,
@ -164,6 +165,9 @@ class Container():
async def execute(self, session): async def execute(self, session):
while True: while True:
self.container.reload()
if self.container.status == "exited":
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR)
try: try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(("localhost", self.port)) s.connect(("localhost", self.port))
@ -219,7 +223,7 @@ class Container():
] ]
} }
) )
print(text_prompt) # print(text_prompt)
message = await session.action() message = await session.action()
print(message) print(message)
if message.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT: if message.status == AgentOutputStatus.AGENT_CONTEXT_LIMIT:
@ -252,6 +256,7 @@ class Container():
def close(self): def close(self):
try: try:
time.sleep(4) time.sleep(4)
if not self.container.status == "exit":
self.container.stop(timeout=24) self.container.stop(timeout=24)
except: except:
pass pass

View File

@ -7,7 +7,7 @@ import os
import asyncio import asyncio
class OmniGibson(Task): class OmniGibson(Task):
def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, **configs): def __init__(self, available_ports, available_devices, data_dir, output_dir, max_round, docker_image, **configs):
super().__init__(**configs) super().__init__(**configs)
self.vab_source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vab_omnigibson_src") self.vab_source_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vab_omnigibson_src")
modified_omnigibson_src = os.path.join(os.path.dirname(os.path.abspath(__file__)), "modified_omnigibson_src") modified_omnigibson_src = os.path.join(os.path.dirname(os.path.abspath(__file__)), "modified_omnigibson_src")
@ -30,6 +30,7 @@ class OmniGibson(Task):
Container.max_round = max_round Container.max_round = max_round
Container.vab_source_dir = self.vab_source_path Container.vab_source_dir = self.vab_source_path
Container.modified_omnigibson_src = modified_omnigibson_src Container.modified_omnigibson_src = modified_omnigibson_src
Container.docker_image = docker_image
with open(f"{self.vab_source_path}/task/tasks.txt", "r") as f: with open(f"{self.vab_source_path}/task/tasks.txt", "r") as f:
tasks = f.read() tasks = f.read()
self.tasks = eval(tasks) self.tasks = eval(tasks)
@ -41,14 +42,14 @@ class OmniGibson(Task):
try: try:
container = Container(self.tasks[index]) container = Container(self.tasks[index])
session.clear() session.clear()
session.inject({"role": "user", "content": SYSTEM_PROMPT}) session.inject({"role": "system", "content": SYSTEM_PROMPT})
while True: while True:
result = await container.execute(session) result = await container.execute(session)
if result.status != SampleStatus.RUNNING: if result.status != SampleStatus.RUNNING:
return result return result
except Exception as e: except Exception as e:
print(e) print(e)
return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR) return TaskSampleExecutionResult(status=SampleStatus.TASK_ERROR, result={"error": e})
finally: finally:
try: try:
container.close() container.close()
@ -59,7 +60,7 @@ class OmniGibson(Task):
average_reward = 0 average_reward = 0
success_count = 0 success_count = 0
for result in results: for result in results:
if result.result: if isinstance(result.result, Dict) and "success" in result.result:
if result.result["success"]: if result.result["success"]:
success_count += 1 success_count += 1
average_reward += 1 average_reward += 1
@ -69,6 +70,8 @@ class OmniGibson(Task):
reward = (final_reward - initial_reward) / (1.0 - initial_reward) reward = (final_reward - initial_reward) / (1.0 - initial_reward)
average_reward += max(0, reward) average_reward += max(0, reward)
return { return {
"total_count": len(results),
"success_count": success_count,
"success_rate": success_count / len(results), "success_rate": success_count / len(results),
"average_reward": average_reward / len(results) "average_reward": average_reward / len(results)
} }
@ -76,11 +79,12 @@ class OmniGibson(Task):
async def main(): async def main():
available_ports = [12000, 12001, 12002] available_ports = [12000, 12001, 12002]
available_devices = {"0":1, "1":1, "2":1} available_devices = {"9":1, "1":1, "2":1}
data_dir = "data/omnigibson" data_dir = "data/omnigibson"
output_dir = "outputs/omnigibson" output_dir = "outputs/omnigibson"
max_round = 100 max_round = 100
task = OmniGibson(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, name="OmniGibson-std") docker_image = "vab_omnigibson:latest"
task = OmniGibson(available_ports=available_ports, available_devices=available_devices, max_round=max_round, data_dir=data_dir, output_dir=output_dir, docker_image=docker_image, name="OmniGibson-std")
print(Container.available_devices) print(Container.available_devices)
print(Container.available_ports) print(Container.available_ports)
session = Session() session = Session()

View File

@ -41,6 +41,6 @@ class Assignment(BaseModel):
class ChatHistoryItem(BaseModel): class ChatHistoryItem(BaseModel):
role: Literal["user", "agent"] role: Literal["user", "agent", "system"]
# content: str # content: str
content: Union[str, List[Dict]] content: Union[str, List[Dict]]