webrl/VAB-WebArena-Lite/browser_env/auto_login.py
2025-04-23 17:01:18 +08:00

183 lines
6.2 KiB
Python

"""Script to automatically login each website"""
import argparse
import glob
import os
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import combinations
from pathlib import Path
from playwright.sync_api import sync_playwright
from browser_env.env_config import ACCOUNTS
# Name of the benchmark whose sites should be logged into; read from the
# environment at import time (a missing variable raises KeyError).
DATASET = os.environ["DATASET"]

# Each row below describes one site as
#   (site name, URL probed to verify the login, exact-URL flag, keyword).
# The rows are unpacked into the parallel lists SITES / URLS / EXACT_MATCH /
# KEYWORDS further down, so position i in every list refers to the same site.
if DATASET == "webarena":
    from browser_env.env_config import (
        GITLAB,
        REDDIT,
        SHOPPING,
        SHOPPING_ADMIN,
    )

    _SITE_CHECKS = [
        ("gitlab", f"{GITLAB}/-/profile", True, ""),
        ("shopping", f"{SHOPPING}/wishlist/", True, ""),
        ("shopping_admin", f"{SHOPPING_ADMIN}/dashboard", True, "Dashboard"),
        (
            "reddit",
            f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account",
            True,
            "Delete",
        ),
    ]
elif DATASET == "visualwebarena":
    from browser_env.env_config import (
        CLASSIFIEDS,
        REDDIT,
        SHOPPING,
    )

    _SITE_CHECKS = [
        ("shopping", f"{SHOPPING}/wishlist/", True, ""),
        (
            "reddit",
            f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account",
            True,
            "Delete",
        ),
        ("classifieds", f"{CLASSIFIEDS}/index.php?page=user&action=items", True, "My listings"),
    ]
else:
    raise ValueError(f"Dataset not implemented: {DATASET}")

SITES = [name for name, _, _, _ in _SITE_CHECKS]
URLS = [probe_url for _, probe_url, _, _ in _SITE_CHECKS]
EXACT_MATCH = [exact for _, _, exact, _ in _SITE_CHECKS]
KEYWORDS = [keyword for _, _, _, keyword in _SITE_CHECKS]

# Browser launch settings used by the login / verification helpers.
HEADLESS = True
SLOW_MO = 0

# The four lists are indexed in lockstep, so they must have equal length.
assert len(SITES) == len(URLS) == len(EXACT_MATCH) == len(KEYWORDS)
def is_expired(
    storage_state: Path, url: str, keyword: str, url_exact: bool = True
) -> bool:
    """Test whether the cookie is expired.

    Opens ``url`` in a Chromium context restored from ``storage_state`` and
    decides from the resulting page whether the saved login still works.

    Args:
        storage_state: Path to the saved browser state file.
        url: Page to open in order to probe the login.
        keyword: If non-empty, the login counts as valid only when this text
            occurs in the page content; the URL is then not inspected.
        url_exact: Only used when ``keyword`` is empty. If true, the final
            page URL must equal ``url``; otherwise ``url`` only has to be a
            substring of the final URL.

    Returns:
        ``True`` if the state file does not exist or the check fails,
        ``False`` if the login is considered valid.
    """
    if not storage_state.exists():
        return True

    # Use the context manager through ``with`` so that Playwright is always
    # exited, even when launching the browser or loading the page raises.
    with sync_playwright() as playwright:
        browser = playwright.chromium.launch(headless=True, slow_mo=SLOW_MO)
        context = browser.new_context(storage_state=storage_state)
        page = context.new_page()
        page.goto(url)
        # Brief pause before reading the final URL and content.
        time.sleep(1)
        d_url = page.url
        content = page.content()

    if keyword:
        return keyword not in content
    if url_exact:
        return d_url != url
    return url not in d_url
def renew_comb(comb: list[str], auth_folder: str = "./.auth") -> None:
    """Log in to every site named in ``comb`` and save the browser state.

    All logins are performed in a single browser context, and that context's
    storage state is written once at the end to
    ``{auth_folder}/{'.'.join(comb)}_state.json``.

    Args:
        comb: Site names to log in to. Recognised names are ``"shopping"``,
            ``"reddit"``, ``"classifieds"``, ``"shopping_admin"`` and
            ``"gitlab"``; other names are ignored.
        auth_folder: Directory in which the state file is written.
    """
    # Use the context manager through ``with`` so that Playwright is always
    # exited, even when one of the login steps raises.
    with sync_playwright() as playwright:
        browser = playwright.chromium.launch(headless=HEADLESS)
        context = browser.new_context()
        page = context.new_page()

        if "shopping" in comb:
            username = ACCOUNTS["shopping"]["username"]
            password = ACCOUNTS["shopping"]["password"]
            page.goto(f"{SHOPPING}/customer/account/login/")
            page.get_by_label("Email", exact=True).fill(username)
            page.get_by_label("Password", exact=True).fill(password)
            page.get_by_role("button", name="Sign In").click()

        if "reddit" in comb:
            username = ACCOUNTS["reddit"]["username"]
            password = ACCOUNTS["reddit"]["password"]
            page.goto(f"{REDDIT}/login")
            page.get_by_label("Username").fill(username)
            page.get_by_label("Password").fill(password)
            page.get_by_role("button", name="Log in").click()

        if "classifieds" in comb:
            username = ACCOUNTS["classifieds"]["username"]
            password = ACCOUNTS["classifieds"]["password"]
            page.goto(f"{CLASSIFIEDS}/index.php?page=login")
            page.locator("#email").fill(username)
            page.locator("#password").fill(password)
            page.get_by_role("button", name="Log in").click()

        if "shopping_admin" in comb:
            username = ACCOUNTS["shopping_admin"]["username"]
            password = ACCOUNTS["shopping_admin"]["password"]
            page.goto(f"{SHOPPING_ADMIN}")
            page.get_by_placeholder("user name").fill(username)
            page.get_by_placeholder("password").fill(password)
            page.get_by_role("button", name="Sign in").click()

        if "gitlab" in comb:
            username = ACCOUNTS["gitlab"]["username"]
            password = ACCOUNTS["gitlab"]["password"]
            page.goto(f"{GITLAB}/users/sign_in")
            page.get_by_test_id("username-field").click()
            page.get_by_test_id("username-field").fill(username)
            page.get_by_test_id("username-field").press("Tab")
            page.get_by_test_id("password-field").fill(password)
            page.get_by_test_id("sign-in-button").click()

        # Persist the state of the shared context for all sites in ``comb``.
        context.storage_state(path=f"{auth_folder}/{'.'.join(comb)}_state.json")
def get_site_comb_from_filepath(file_path: str) -> list[str]:
    """Recover the list of site names encoded in a state-file path.

    The directory part is dropped, everything from the last underscore
    onwards is removed (the whole file name is kept when it contains no
    underscore), and the remainder is split on ``"."``.
    """
    file_name = os.path.basename(file_path)
    prefix, underscore, remainder = file_name.rpartition("_")
    # Without an underscore, rpartition puts the whole name in `remainder`.
    site_part = prefix if underscore else remainder
    return site_part.split(".")
def main(auth_folder: str = "./.auth") -> None:
    """Renew the login state for every site and site pair, then verify it.

    First, ``renew_comb`` is submitted for each pair of sites (except pairs
    that combine reddit with shopping / shopping_admin) and for each single
    site. Afterwards every ``*.json`` file in ``auth_folder`` is checked with
    ``is_expired`` once per site encoded in its file name.

    Args:
        auth_folder: Directory where state files are written and checked.

    Raises:
        AssertionError: If any (state file, site) check reports the cookie
            as expired.
    """
    pairs = list(combinations(SITES, 2))

    # NOTE: the futures returned by submit() are discarded here, so an
    # exception raised inside renew_comb is not re-raised by this function.
    with ThreadPoolExecutor(max_workers=8) as executor:
        for pair in pairs:
            # Auth doesn't work on this pair as they share the same cookie
            if "reddit" in pair and (
                "shopping" in pair or "shopping_admin" in pair
            ):
                continue
            executor.submit(renew_comb, sorted(pair), auth_folder=auth_folder)

        for site in SITES:
            executor.submit(renew_comb, [site], auth_folder=auth_folder)

    # parallel checking if the cookies are expired
    # A state file covering several sites yields several futures, so each
    # future is stored together with the file it belongs to; indexing
    # cookie_files by the future's position would name the wrong file.
    checks = []
    cookie_files = glob.glob(f"{auth_folder}/*.json")
    with ThreadPoolExecutor(max_workers=8) as executor:
        for c_file in cookie_files:
            for cur_site in get_site_comb_from_filepath(c_file):
                site_idx = SITES.index(cur_site)
                future = executor.submit(
                    is_expired,
                    Path(c_file),
                    URLS[site_idx],
                    KEYWORDS[site_idx],
                    EXACT_MATCH[site_idx],
                )
                checks.append((c_file, future))

    for c_file, future in checks:
        assert not future.result(), f"Cookie {c_file} expired."
# Command-line entry point.
#   --site_list    sites to log in to as one combination; when omitted, all
#                  sites and site pairs are renewed and verified via main().
#   --auth_folder  directory for the saved state files.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--site_list", nargs="+", default=[])
    parser.add_argument("--auth_folder", type=str, default="./.auth")
    args = parser.parse_args()
    if not args.site_list:
        # Forward --auth_folder so it is honoured in this mode as well.
        main(auth_folder=args.auth_folder)
    else:
        renew_comb(args.site_list, auth_folder=args.auth_folder)