init
This commit is contained in:
parent
dac181a456
commit
4b8dbfd9c5
1448
AgentOccam/AgentOccam.py
Normal file
1448
AgentOccam/AgentOccam.py
Normal file
File diff suppressed because it is too large
Load Diff
2
AgentOccam/__init__.py
Normal file
2
AgentOccam/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .obs_opt import parse_node_descendants, parse_node_ancestors, parse_node_siblings, action_set_invisible, action_set_visible, action_set_visible_if_with_name, translate_node_to_str, construct_new_DOM_with_visible_nodes
|
||||
from .utils import CURRENT_DIR, HOMEPAGE_URL
|
78
AgentOccam/configs/AgentOccam-Judge.yml
Normal file
78
AgentOccam/configs/AgentOccam-Judge.yml
Normal file
|
@ -0,0 +1,78 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "AgentOccam-Judge"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam"
|
||||
others:
|
||||
max_steps: 20
|
||||
logname: "AgentOccam-Judge"
|
||||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: ["url", "plan", "reason", "observation summary", "retained element ids", "observation highlight"]
|
||||
online_interaction_elements: []
|
||||
input: ["step", "objective", "previous plans", "interaction history", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["interaction history summary", "observation description", "action candidates", "observation highlight"]
|
||||
planning_command: ["branch", "prune"]
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response"]
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
character: "normal"
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "mistakes"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
judge:
|
||||
mode: true
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
strict: false
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation", "action choices"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["plan progress assessment", "action assessment", "action selection"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
env:
|
||||
fullpage: true
|
||||
prune: true
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
26
AgentOccam/configs/AgentOccam-SteP.yml
Normal file
26
AgentOccam/configs/AgentOccam-SteP.yml
Normal file
|
@ -0,0 +1,26 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "AgentOccam-SteP"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam-SteP"
|
||||
root_action: "shopping_admin_agent" # Need to be adapted to tasks
|
||||
low_level_action_list: ['click', 'type', 'stop', 'goto', 'hover', 'note', 'go_back']
|
||||
model_name: "gpt-4-turbo"
|
||||
model_host: "openai"
|
||||
prompt_mode: "chat"
|
||||
max_target_len: 100
|
||||
env:
|
||||
fullpage: true
|
||||
prune: true
|
||||
max_env_steps: 20
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
74
AgentOccam/configs/AgentOccam-WebVoyager.yml
Normal file
74
AgentOccam/configs/AgentOccam-WebVoyager.yml
Normal file
|
@ -0,0 +1,74 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "AgentOccam-WebVoyager"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam"
|
||||
others:
|
||||
max_steps: 20
|
||||
logname: "AgentOccam-WebVoyager"
|
||||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: ["url", "plan", "reason", "observation summary", "retained element ids", "observation highlight"]
|
||||
online_interaction_elements: []
|
||||
input: ["step", "objective", "previous plans", "interaction history", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["interaction history summary", "observation description", "reason", "action", "observation highlight"]
|
||||
planning_command: ["branch", "prune"]
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response"]
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
character: "normal"
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "mistakes"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
judge:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
strict: false
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation", "action choices"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["plan progress assessment", "action assessment", "action selection"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
env:
|
||||
fullpage: true
|
||||
prune: true
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
relative_task_dir: "webvoyager"
|
||||
task_ids: ["Allrecipes--3"]
|
78
AgentOccam/configs/AgentOccam.yml
Normal file
78
AgentOccam/configs/AgentOccam.yml
Normal file
|
@ -0,0 +1,78 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "AgentOccam"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam"
|
||||
others:
|
||||
max_steps: 20
|
||||
logname: "AgentOccam"
|
||||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: ["url", "plan", "reason", "observation summary", "retained element ids", "observation highlight"]
|
||||
online_interaction_elements: []
|
||||
input: ["step", "objective", "previous plans", "interaction history", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["interaction history summary", "observation description", "reason", "action", "observation highlight"]
|
||||
planning_command: ["branch", "prune"]
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response"]
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
character: "normal"
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "mistakes"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
judge:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
strict: false
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation", "action choices"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["plan progress assessment", "action assessment", "action selection"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
env:
|
||||
fullpage: true
|
||||
prune: true
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
26
AgentOccam/configs/SteP-replication.yml
Normal file
26
AgentOccam/configs/SteP-replication.yml
Normal file
|
@ -0,0 +1,26 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "SteP-replication"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "SteP-replication"
|
||||
root_action: "shopping_admin_agent" # Need to be adapted to tasks
|
||||
low_level_action_list: ['click', 'type', 'scroll', 'stop', 'goto', 'hover', 'note', 'go_back']
|
||||
model_name: "gpt-4-turbo"
|
||||
model_host: "openai"
|
||||
prompt_mode: "chat"
|
||||
max_target_len: 100
|
||||
env:
|
||||
fullpage: false
|
||||
prune: false
|
||||
max_env_steps: 20
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
|
@ -0,0 +1,78 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "reduced_action-X_scrolling-obs_opt-history"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam"
|
||||
others:
|
||||
max_steps: 20
|
||||
logname: "reduced_action-X_scrolling-obs_opt-history"
|
||||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: ["url", "plan", "reason", "observation summary", "retained element ids", "observation highlight"]
|
||||
online_interaction_elements: []
|
||||
input: ["step", "objective", "interaction history", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["interaction history summary", "observation description", "reason", "action", "observation highlight"]
|
||||
planning_command: []
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response"]
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
character: "normal"
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "mistakes"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
judge:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
strict: false
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation", "action choices"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["plan progress assessment", "action assessment", "action selection"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
env:
|
||||
fullpage: true
|
||||
prune: true
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
78
AgentOccam/configs/reduced_action-X_scrolling-obs_opt.yml
Normal file
78
AgentOccam/configs/reduced_action-X_scrolling-obs_opt.yml
Normal file
|
@ -0,0 +1,78 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "reduced_action-X_scrolling-obs_opt"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam"
|
||||
others:
|
||||
max_steps: 20
|
||||
logname: "reduced_action-X_scrolling-obs_opt"
|
||||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: ["url", "plan", "reason", "observation summary", "retained element ids", "observation highlight"]
|
||||
online_interaction_elements: []
|
||||
input: ["step", "objective", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "reason", "action", "observation highlight"]
|
||||
planning_command: []
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response"]
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
character: "normal"
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "mistakes"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
judge:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
strict: false
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation", "action choices"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["plan progress assessment", "action assessment", "action selection"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
env:
|
||||
fullpage: true
|
||||
prune: true
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
78
AgentOccam/configs/reduced_action-X_scrolling.yml
Normal file
78
AgentOccam/configs/reduced_action-X_scrolling.yml
Normal file
|
@ -0,0 +1,78 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "reduced_action-X_scrolling"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam"
|
||||
others:
|
||||
max_steps: 20
|
||||
logname: "reduced_action-X_scrolling"
|
||||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: ["url", "plan", "reason", "observation summary", "retained element ids", "observation highlight"]
|
||||
online_interaction_elements: []
|
||||
input: ["step", "objective", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "reason", "action", "observation highlight"]
|
||||
planning_command: []
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response"]
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
character: "normal"
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "mistakes"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
judge:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
strict: false
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation", "action choices"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["plan progress assessment", "action assessment", "action selection"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
env:
|
||||
fullpage: true
|
||||
prune: false
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
78
AgentOccam/configs/reduced_action.yml
Normal file
78
AgentOccam/configs/reduced_action.yml
Normal file
|
@ -0,0 +1,78 @@
|
|||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
logname: "reduced_action"
|
||||
max_steps: 20
|
||||
agent:
|
||||
type: "AgentOccam"
|
||||
others:
|
||||
max_steps: 20
|
||||
logname: "reduced_action"
|
||||
logging: True
|
||||
verbose: 1
|
||||
debug: False
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: ["url", "plan", "reason", "observation summary", "retained element ids", "observation highlight"]
|
||||
online_interaction_elements: []
|
||||
input: ["step", "objective", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "reason", "action", "observation highlight"]
|
||||
planning_command: []
|
||||
navigation_command: ["click", "type", "scroll", "stop", "note", "go_back"]
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response"]
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
character: "normal"
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["observation description", "mistakes"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
judge:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
model: "gpt-4-turbo"
|
||||
documented_interaction_elements: []
|
||||
online_interaction_elements: []
|
||||
strict: false
|
||||
input: ["objective", "previous plans", "interaction history", "step", "current observation", "action choices"]
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: "all"
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
output: ["plan progress assessment", "action assessment", "action selection"]
|
||||
trash: ["instruction", "online input", "response"]
|
||||
env:
|
||||
fullpage: false
|
||||
prune: false
|
||||
max_browser_rows: 500
|
||||
headless: True
|
||||
task_ids: ["stanford_cs_head", 65]
|
||||
# a. "SHOPPING_ADMIN": [0, 1, 2, 3, 4, 5, 6, 11, 12, 13, 14, 15, 41, 42, 43, 62, 63, 64, 65, 77, 78, 79, 94, 95, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 119, 120, 121, 122, 123, 127, 128, 129, 130, 131, 157, 183, 184, 185, 186, 187, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 243, 244, 245, 246, 247, 288, 289, 290, 291, 292, 344, 345, 346, 347, 348, 374, 375, 423, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 470, 471, 472, 473, 474, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 676, 677, 678, 679, 680, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 790]
|
||||
# b. "MAP": [7, 8, 9, 10, 16, 17, 18, 19, 20, 32, 33, 34, 35, 36, 37, 38, 39, 40, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 70, 71, 72, 73, 74, 75, 76, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 97, 98, 99, 100, 101, 137, 138, 139, 140, 151, 152, 153, 154, 155, 218, 219, 220, 221, 222, 223, 224, 236, 237, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 265, 266, 267, 268, 287, 356, 363, 364, 365, 366, 367, 369, 370, 371, 372, 373, 377, 378, 379, 380, 381, 382, 383, 424, 425, 426, 427, 428, 429, 430, 737, 738, 739, 740, 741, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767]
|
||||
# c. "SHOPPING": [21, 22, 23, 24, 25, 26, 47, 48, 49, 50, 51, 96, 117, 118, 124, 125, 126, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 188, 189, 190, 191, 192, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 238, 239, 240, 241, 242, 260, 261, 262, 263, 264, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 298, 299, 300, 301, 302, 313, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 351, 352, 353, 354, 355, 358, 359, 360, 361, 362, 368, 376, 384, 385, 386, 387, 388, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 465, 466, 467, 468, 469, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 528, 529, 530, 531, 532, 571, 572, 573, 574, 575, 585, 586, 587, 588, 589, 653, 654, 655, 656, 657, 671, 672, 673, 674, 675, 689, 690, 691, 692, 693, 792, 793, 794, 795, 796, 797, 798]
|
||||
# d. "REDDIT": [27, 28, 29, 30, 31, 66, 67, 68, 69, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 580, 581, 582, 583, 584, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 681, 682, 683, 684, 685, 686, 687, 688, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735]
|
||||
# e. "GITLAB": [44, 45, 46, 102, 103, 104, 105, 106, 132, 133, 134, 135, 136, 156, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 205, 206, 207, 258, 259, 293, 294, 295, 296, 297, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 314, 315, 316, 317, 318, 339, 340, 341, 342, 343, 349, 350, 357, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 522, 523, 524, 525, 526, 527, 533, 534, 535, 536, 537, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 576, 577, 578, 579, 590, 591, 592, 593, 594, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 736, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 783, 784, 785, 786, 787, 788, 789, 791, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811]
|
129
AgentOccam/env.py
Normal file
129
AgentOccam/env.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
import json
|
||||
from browser_env import (
|
||||
create_id_based_action,
|
||||
create_id_based_actions,
|
||||
StateInfo,
|
||||
Trajectory,
|
||||
ActionTypes,
|
||||
ScriptBrowserEnv
|
||||
)
|
||||
from evaluation_harness.evaluators import evaluator_router
|
||||
from AgentOccam.obs_opt import (
|
||||
prune_tree,
|
||||
translate_node_to_str,
|
||||
)
|
||||
|
||||
|
||||
class WebArenaEnvironmentWrapper():
    """Wrapper around WebArena's ScriptBrowserEnv that tracks the episode
    trajectory, enforces a step budget, and scores the episode with the
    WebArena evaluator once a stop action is issued."""

    def __init__(self, config_file, max_browser_rows=300, max_steps=50, slow_mo=1, observation_type="accessibility_tree", current_viewport_only=False, viewport_size={"width": 1280, "height": 720}, headless=False, global_config=None):
        # NOTE(review): the mutable default for viewport_size is only read,
        # never mutated, so the shared-default pitfall does not bite here.
        self.webarena_env = ScriptBrowserEnv(
            headless=headless,
            slow_mo=slow_mo,
            observation_type=observation_type,
            current_viewport_only=current_viewport_only,
            viewport_size=viewport_size,
            global_config=global_config
        )
        self.config_file = config_file
        with open(self.config_file, "r") as f:
            self.config = json.load(f)
        self.global_config = global_config

        self.obs, self.info = self.webarena_env.reset(options={"config_file": self.config_file})
        self.terminated = False
        self.objective = self.config["intent"]
        self.url = self.config["start_url"]
        self.max_browser_rows = max_browser_rows
        self.max_steps = max_steps
        self.steps = 0
        self.is_done = False
        self.reward = 0.0

        self.trajectory: Trajectory = []
        self.update_webarena_metrics()

    def reset(self):
        """Re-load the task defined by ``config_file`` into the browser."""
        self.obs, self.info = self.webarena_env.reset(options={"config_file": self.config_file})

    def close(self):
        """Shut down the underlying browser environment."""
        self.webarena_env.close()

    def get_url(self):
        """Return the most recently recorded page URL."""
        return self.url

    def get_objective(self):
        """Return the natural-language task intent for this episode."""
        return self.objective

    def get_sites(self):
        """Return the list of WebArena sites this task targets."""
        return self.config["sites"]

    def observation(self):
        """Return the current observation.

        In pruned mode returns a dict with the objective-conditioned DOM
        string, the screenshot, and the pruned DOM root node; otherwise
        returns the raw accessibility-tree text truncated to at most
        ``max_browser_rows`` lines.
        """
        self.url = self.webarena_env.page.url
        if self.global_config and self.global_config.env.prune:
            # Pruned mode: obs["text"][1] carries the DOM tree root node.
            root_node = self.obs["text"][1]
            DOM_root_node = prune_tree(objective=self.objective, root_node=root_node, mode="node")
            DOM_str = translate_node_to_str(node=DOM_root_node, mode="concise")
            return {"text": DOM_str, "image": self.obs["image"], "node": DOM_root_node}
        else:
            # Raw mode: obs["text"][0] is the plain accessibility-tree text.
            browser_content = self.obs["text"][0]
            browser_content = browser_content.split("\n")[:self.max_browser_rows]
            browser_content = "\n".join(browser_content)
            return browser_content

    def done(self):
        """Return True once the episode has terminated."""
        if self.is_done:
            return True
        return False

    def status(self):
        """Return a summary dict of the episode outcome so far."""
        return {'done': self.is_done, 'reward': self.reward, 'success': float(self.reward > 0), 'num_actions': self.steps}

    def step(self, action):
        """Execute one agent action string and return ``self.status()``.

        Bug fix: the original returned the bare value ``False`` when the
        action string parsed to no commands, while every other path returned
        the status dict; callers now always receive the status dict, and an
        unparseable action is treated as a no-op step.
        """
        self.steps = self.steps + 1
        print(f"[Step {self.steps}] {action}")
        print("*"*100)
        if self.steps > self.max_steps:
            print(f"Steps {self.steps} exceeded maximum {self.max_steps}")
            self.is_done = True
            action_cmd = create_id_based_action(f"stop [Trajectory failed: Steps {self.steps} exceeded maximum {self.max_steps}.]")
            self.update_webarena_metrics(action_cmd)
            return self.status()

        if action is None or action == "":
            action_cmds = []
        else:
            try:
                action_cmds = create_id_based_actions(action)
                if not action_cmds:
                    action_cmds = []  # unparseable action -> no-op step
            except Exception as e:
                print(f"Invalid action syntax: {e}")
                action_cmds = []

        for action_cmd in action_cmds:
            try:
                self.obs, _, self.terminated, _, self.info = self.webarena_env.step(action_cmd)
                self.update_webarena_metrics(action_cmd)
            except Exception as e:
                print(f"Error occurred while taking step: {e}")

        return self.status()

    def update_webarena_metrics(self, action_cmd=None):
        """Append the latest action/state to the trajectory and, once the
        episode stops, run the WebArena evaluator to compute the reward."""
        # Append action (if any) and resulting state
        if action_cmd:
            self.trajectory.append(action_cmd)
            if action_cmd["action_type"] == ActionTypes.STOP:
                self.is_done = True

        if not self.is_done:  # If we are done, no need to append state
            state_info: StateInfo = {"observation": self.obs, "info": self.info}
            self.trajectory.append(state_info)

        if self.is_done:
            try:
                evaluator = evaluator_router(self.config_file)
                self.reward = evaluator(trajectory=self.trajectory, config_file=self.config_file, page=self.webarena_env.page, client=self.webarena_env.get_page_client(self.webarena_env.page))
            except Exception as e:
                print(f"Got exception: {e}")
                self.reward = 0
213
AgentOccam/llms/claude.py
Normal file
213
AgentOccam/llms/claude.py
Normal file
|
@ -0,0 +1,213 @@
|
|||
import boto3
|
||||
import json
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = '''You are an AI assistant. Your goal is to provide informative and substantive responses to queries.'''
|
||||
|
||||
def call_claude(prompt, model_id="anthropic.claude-3-sonnet-20240229-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Send a single-turn text prompt to Claude via Amazon Bedrock.

    Retries indefinitely-bounded: up to 10 attempts, sleeping 10 seconds
    between failures.

    Args:
        prompt: User message text.
        model_id: Bedrock model identifier.
        system_prompt: System prompt for the conversation.

    Returns:
        The model's text response.

    Raises:
        ValueError: After 10 consecutive failed attempts.
    """
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    native_request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1024,
        "temperature": 0.95,
        "system": system_prompt,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ],
    }

    request = json.dumps(native_request)

    num_attempts = 0
    while True:
        if num_attempts >= 10:
            # Bug fix: the original message said "OpenAI request failed."
            # even though this function calls Claude on Bedrock.
            raise ValueError("Claude request failed.")
        try:
            response = client.invoke_model(modelId=model_id, body=request)
            model_response = json.loads(response["body"].read())

            response_text = model_response["content"][0]["text"]
            return response_text

        except Exception as e:
            print(e)
            print("Sleeping for 10s...")
            time.sleep(10)
            num_attempts += 1
||||
|
||||
|
||||
|
||||
def arrange_message_for_claude(item_list):
    """Convert a list of ("text", str) / ("image", path-or-ndarray) tuples
    into Claude's messages format, merging consecutive text items into one
    content block.

    Bug fix: the original indexed ``item_list[-1]`` to decide whether to
    flush trailing text and crashed with IndexError on an empty list; the
    trailing buffer is now flushed unconditionally when non-empty.

    Args:
        item_list: Sequence of (kind, payload) tuples. ``kind`` is "text"
            (payload: str) or "image" (payload: file path str or ndarray).

    Returns:
        A single-element messages list with role "user".
    """
    def image_path_to_bytes(file_path):
        # Read raw image bytes from disk for base64 embedding.
        with open(file_path, "rb") as image_file:
            return image_file.read()

    # First pass: merge adjacent text items; images pass through unchanged.
    combined_item_list = []
    text_buffer = ""
    for item in item_list:
        if item[0] == "image":
            if text_buffer:
                combined_item_list.append(("text", text_buffer))
                text_buffer = ""
            combined_item_list.append(item)
        else:
            text_buffer += item[1]
    if text_buffer:
        combined_item_list.append(("text", text_buffer))

    # Second pass: render each merged item as a Claude content block.
    content = []
    for item_type, payload in combined_item_list:
        if item_type == "text":
            content.append({
                "type": "text",
                "text": payload
            })
        elif item_type == "image":
            if isinstance(payload, str):
                media_type = "image/png"  # "image/jpeg"
                image_bytes = image_path_to_bytes(payload)
                image_data = base64.b64encode(image_bytes).decode("utf-8")
            elif isinstance(payload, np.ndarray):
                media_type = "image/jpeg"
                image = Image.fromarray(payload).convert("RGB")
                width, height = image.size
                # Halve the resolution to keep the request payload small.
                image = image.resize((int(0.5*width), int(0.5*height)), Image.LANCZOS)
                buffer = io.BytesIO()
                image.save(buffer, format='JPEG')
                image_data = base64.b64encode(buffer.getvalue()).decode("utf-8")
            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": image_data,
                },
            })
    return [
        {
            "role": "user",
            "content": content
        }
    ]
||||
|
||||
def call_claude_with_messages(messages, model_id="anthropic.claude-3-sonnet-20240229-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Send a pre-built messages list (e.g. from arrange_message_for_claude)
    to Claude via Amazon Bedrock.

    Retries up to 10 attempts, sleeping 10 seconds between failures.

    Args:
        messages: Messages list in Anthropic's format.
        model_id: Bedrock model identifier.
        system_prompt: System prompt for the conversation.

    Returns:
        The model's text response.

    Raises:
        ValueError: After 10 consecutive failed attempts.
    """
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    native_request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1024,
        "temperature": 0.95,
        "system": system_prompt,
        "messages": messages,
    }

    request = json.dumps(native_request)

    num_attempts = 0
    while True:
        if num_attempts >= 10:
            # Bug fix: the original message said "OpenAI request failed."
            # even though this function calls Claude on Bedrock.
            raise ValueError("Claude request failed.")
        try:
            response = client.invoke_model(modelId=model_id, body=request)
            model_response = json.loads(response["body"].read())

            response_text = model_response["content"][0]["text"]
            return response_text

        except Exception as e:
            print(e)
            print("Sleeping for 10s...")
            time.sleep(10)
            num_attempts += 1
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(call_claude('''CURRENT OBSERVATION:
|
||||
RootWebArea [2634] 'My Account'
|
||||
link [3987] 'My Account'
|
||||
link [3985] 'My Wish List'
|
||||
link [3989] 'Sign Out'
|
||||
text 'Welcome to One Stop Market'
|
||||
link [3800] 'Skip to Content'
|
||||
link [3809] 'store logo'
|
||||
link [3996] 'My Cart'
|
||||
combobox [4190] 'Search' [required: False]
|
||||
link [4914] 'Advanced Search'
|
||||
button [4193] 'Search' [disabled: True]
|
||||
tablist [3699]
|
||||
tabpanel
|
||||
menu "[3394] 'Beauty & Personal Care'; [3459] 'Sports & Outdoors'; [3469] 'Clothing, Shoes & Jewelry'; [3483] 'Home & Kitchen'; [3520] 'Office Products'; [3528] 'Tools & Home Improvement'; [3533] 'Health & Household'; [3539] 'Patio, Lawn & Garden'; [3544] 'Electronics'; [3605] 'Cell Phones & Accessories'; [3620] 'Video Games'; [3633] 'Grocery & Gourmet Food'"
|
||||
main
|
||||
heading 'My Account'
|
||||
text 'Contact Information'
|
||||
text 'Emma Lopez'
|
||||
text 'emma.lopezgmail.com'
|
||||
link [3863] 'Change Password'
|
||||
text 'Newsletters'
|
||||
text "You aren't subscribed to our newsletter."
|
||||
link [3877] 'Manage Addresses'
|
||||
text 'Default Billing Address'
|
||||
group [3885]
|
||||
text 'Emma Lopez'
|
||||
text '101 S San Mateo Dr'
|
||||
text 'San Mateo, California, 94010'
|
||||
text 'United States'
|
||||
text 'T:'
|
||||
link [3895] '6505551212'
|
||||
text 'Default Shipping Address'
|
||||
group [3902]
|
||||
text 'Emma Lopez'
|
||||
text '101 S San Mateo Dr'
|
||||
text 'San Mateo, California, 94010'
|
||||
text 'United States'
|
||||
text 'T:'
|
||||
link [3912] '6505551212'
|
||||
link [3918] 'View All'
|
||||
table 'Recent Orders'
|
||||
row '| Order | Date | Ship To | Order Total | Status | Action |'
|
||||
row '| --- | --- | --- | --- | --- | --- |'
|
||||
row "| 000000170 | 5/17/23 | Emma Lopez | 365.42 | Canceled | View OrderReorder\tlink [4110] 'View Order'\tlink [4111] 'Reorder' |"
|
||||
row "| 000000189 | 5/2/23 | Emma Lopez | 754.99 | Pending | View OrderReorder\tlink [4122] 'View Order'\tlink [4123] 'Reorder' |"
|
||||
row "| 000000188 | 5/2/23 | Emma Lopez | 2,004.99 | Pending | View OrderReorder\tlink [4134] 'View Order'\tlink [4135] 'Reorder' |"
|
||||
row "| 000000187 | 5/2/23 | Emma Lopez | 1,004.99 | Pending | View OrderReorder\tlink [4146] 'View Order'\tlink [4147] 'Reorder' |"
|
||||
row "| 000000180 | 3/11/23 | Emma Lopez | 65.32 | Complete | View OrderReorder\tlink [4158] 'View Order'\tlink [4159] 'Reorder' |"
|
||||
link [4165] 'My Orders'
|
||||
link [4166] 'My Downloadable Products'
|
||||
link [4167] 'My Wish List'
|
||||
link [4169] 'Address Book'
|
||||
link [4170] 'Account Information'
|
||||
link [4171] 'Stored Payment Methods'
|
||||
link [4173] 'My Product Reviews'
|
||||
link [4174] 'Newsletter Subscriptions'
|
||||
heading 'Compare Products'
|
||||
text 'You have no items to compare.'
|
||||
heading 'My Wish List'
|
||||
text 'You have no items in your wish list.'
|
||||
contentinfo
|
||||
textbox [4177] 'Sign Up for Our Newsletter:' [required: False]
|
||||
button [4072] 'Subscribe'
|
||||
link [4073] 'Privacy and Cookie Policy'
|
||||
link [4074] 'Search Terms'
|
||||
link [4075] 'Advanced Search'
|
||||
link [4076] 'Contact Us'
|
||||
text 'Copyright 2013-present Magento, Inc. All rights reserved.'
|
||||
text 'Help Us Keep Magento Healthy'
|
||||
link [3984] 'Report All Bugs'
|
||||
Today is 6/12/2023. Base on the webpage, tell me how many fulfilled orders I have over the past month, and the total amount of money I spent over the past month.'''))
|
42
AgentOccam/llms/cohere.py
Normal file
42
AgentOccam/llms/cohere.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import boto3
|
||||
import json
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = '''You are an AI assistant. Your goal is to provide informative and substantive responses to queries.'''
|
||||
|
||||
def call_cohere(prompt, model_id="cohere.command-r-plus-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Invoke a Cohere model on AWS Bedrock and return its generated text.

    Args:
        prompt: User prompt text.
        model_id: Bedrock model identifier.
        system_prompt: Instructions prepended to the prompt.

    Returns:
        The generated text from the model response.

    Raises:
        ClientError (or other Exception): re-raised when the Bedrock
        invocation fails.
    """
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    formatted_prompt = f"{system_prompt}\n{prompt}"

    # Cohere Command R+ native request schema.
    native_request = {
        "message": formatted_prompt,
        "max_tokens": 512,
        "temperature": 0.5,
    }

    request = json.dumps(native_request)
    try:
        response = client.invoke_model(modelId=model_id, body=request)
    except (ClientError, Exception) as e:
        # Bug fix: previously this only printed and fell through, which then
        # raised a confusing UnboundLocalError on `response` below. Re-raise
        # so the caller sees the real failure.
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        raise

    model_response = json.loads(response["body"].read())

    response_text = model_response["text"]
    return response_text
|
||||
|
||||
def arrange_message_for_cohere(item_list):
    """Flatten (kind, text) pairs into a single prompt string.

    Image items are rejected: this client is text-only.
    """
    pieces = []
    for item in item_list:
        if item[0] == "image":
            raise NotImplementedError()
        pieces.append(item[1])
    return "".join(pieces)
|
||||
|
||||
def call_cohere_with_messages(messages, model_id="cohere.command-r-plus-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    # The Cohere path takes one flat prompt string, so `messages` (already
    # flattened by arrange_message_for_cohere) is forwarded as the prompt.
    return call_cohere(prompt=messages, model_id=model_id, system_prompt=system_prompt)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test; requires AWS Bedrock credentials in the environment.
    print(call_cohere('''Hi'''))
|
||||
|
107
AgentOccam/llms/gemini.py
Normal file
107
AgentOccam/llms/gemini.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
import google.generativeai as genai
|
||||
import os
|
||||
import time
|
||||
|
||||
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
|
||||
genai.configure(api_key=GEMINI_API_KEY)
|
||||
|
||||
|
||||
def call_gemini(prompt, model_id="gemini-1.5-flash", system_prompt=None):
    """Invoke a Google Gemini model and return its generated text.

    Retries failed attempts with a 30s sleep, up to 10 attempts.

    Args:
        prompt: User prompt text.
        model_id: Gemini model name.
        system_prompt: Optional instructions prepended to the prompt.

    Returns:
        The generated text.

    Raises:
        ValueError: after 10 failed attempts.
    """
    model = genai.GenerativeModel(model_id)

    # Bug fix: with the default system_prompt=None, the original expression
    # `system_prompt + "\n" + prompt` raised TypeError on every attempt.
    full_prompt = prompt if system_prompt is None else system_prompt + "\n" + prompt

    num_attempts = 0
    while True:
        if num_attempts >= 10:
            raise ValueError("Gemini request failed.")
        try:
            response = model.generate_content(full_prompt)
            response_text = response.text
            return response_text
        except Exception as e:
            print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
            time.sleep(30)
            # Bug fix: the counter was never incremented, so the 10-attempt
            # cap above could never trigger and failures retried forever.
            num_attempts += 1
|
||||
|
||||
|
||||
def arrange_message_for_gemini(item_list):
    """Concatenate (kind, text) items into one prompt string; images are unsupported."""
    if any(item[0] == "image" for item in item_list):
        raise NotImplementedError()
    return "".join(item[1] for item in item_list)
|
||||
|
||||
def call_gemini_with_messages(messages, model_id="gemini-1.5-flash", system_prompt=None):
    # Gemini takes one flat prompt string, so `messages` (already flattened
    # by arrange_message_for_gemini) is forwarded as the prompt.
    return call_gemini(prompt=messages, model_id=model_id, system_prompt=system_prompt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(call_gemini('''CURRENT OBSERVATION:
|
||||
RootWebArea [2634] 'My Account'
|
||||
link [3987] 'My Account'
|
||||
link [3985] 'My Wish List'
|
||||
link [3989] 'Sign Out'
|
||||
text 'Welcome to One Stop Market'
|
||||
link [3800] 'Skip to Content'
|
||||
link [3809] 'store logo'
|
||||
link [3996] 'My Cart'
|
||||
combobox [4190] 'Search' [required: False]
|
||||
link [4914] 'Advanced Search'
|
||||
button [4193] 'Search' [disabled: True]
|
||||
tablist [3699]
|
||||
tabpanel
|
||||
menu "[3394] 'Beauty & Personal Care'; [3459] 'Sports & Outdoors'; [3469] 'Clothing, Shoes & Jewelry'; [3483] 'Home & Kitchen'; [3520] 'Office Products'; [3528] 'Tools & Home Improvement'; [3533] 'Health & Household'; [3539] 'Patio, Lawn & Garden'; [3544] 'Electronics'; [3605] 'Cell Phones & Accessories'; [3620] 'Video Games'; [3633] 'Grocery & Gourmet Food'"
|
||||
main
|
||||
heading 'My Account'
|
||||
text 'Contact Information'
|
||||
text 'Emma Lopez'
|
||||
text 'emma.lopezgmail.com'
|
||||
link [3863] 'Change Password'
|
||||
text 'Newsletters'
|
||||
text "You aren't subscribed to our newsletter."
|
||||
link [3877] 'Manage Addresses'
|
||||
text 'Default Billing Address'
|
||||
group [3885]
|
||||
text 'Emma Lopez'
|
||||
text '101 S San Mateo Dr'
|
||||
text 'San Mateo, California, 94010'
|
||||
text 'United States'
|
||||
text 'T:'
|
||||
link [3895] '6505551212'
|
||||
text 'Default Shipping Address'
|
||||
group [3902]
|
||||
text 'Emma Lopez'
|
||||
text '101 S San Mateo Dr'
|
||||
text 'San Mateo, California, 94010'
|
||||
text 'United States'
|
||||
text 'T:'
|
||||
link [3912] '6505551212'
|
||||
link [3918] 'View All'
|
||||
table 'Recent Orders'
|
||||
row '| Order | Date | Ship To | Order Total | Status | Action |'
|
||||
row '| --- | --- | --- | --- | --- | --- |'
|
||||
row "| 000000170 | 5/17/23 | Emma Lopez | 365.42 | Canceled | View OrderReorder\tlink [4110] 'View Order'\tlink [4111] 'Reorder' |"
|
||||
row "| 000000189 | 5/2/23 | Emma Lopez | 754.99 | Pending | View OrderReorder\tlink [4122] 'View Order'\tlink [4123] 'Reorder' |"
|
||||
row "| 000000188 | 5/2/23 | Emma Lopez | 2,004.99 | Pending | View OrderReorder\tlink [4134] 'View Order'\tlink [4135] 'Reorder' |"
|
||||
row "| 000000187 | 5/2/23 | Emma Lopez | 1,004.99 | Pending | View OrderReorder\tlink [4146] 'View Order'\tlink [4147] 'Reorder' |"
|
||||
row "| 000000180 | 3/11/23 | Emma Lopez | 65.32 | Complete | View OrderReorder\tlink [4158] 'View Order'\tlink [4159] 'Reorder' |"
|
||||
link [4165] 'My Orders'
|
||||
link [4166] 'My Downloadable Products'
|
||||
link [4167] 'My Wish List'
|
||||
link [4169] 'Address Book'
|
||||
link [4170] 'Account Information'
|
||||
link [4171] 'Stored Payment Methods'
|
||||
link [4173] 'My Product Reviews'
|
||||
link [4174] 'Newsletter Subscriptions'
|
||||
heading 'Compare Products'
|
||||
text 'You have no items to compare.'
|
||||
heading 'My Wish List'
|
||||
text 'You have no items in your wish list.'
|
||||
contentinfo
|
||||
textbox [4177] 'Sign Up for Our Newsletter:' [required: False]
|
||||
button [4072] 'Subscribe'
|
||||
link [4073] 'Privacy and Cookie Policy'
|
||||
link [4074] 'Search Terms'
|
||||
link [4075] 'Advanced Search'
|
||||
link [4076] 'Contact Us'
|
||||
text 'Copyright 2013-present Magento, Inc. All rights reserved.'
|
||||
text 'Help Us Keep Magento Healthy'
|
||||
link [3984] 'Report All Bugs'
|
||||
Today is 6/12/2023. Base on the aforementioned webpage, tell me how many fulfilled orders I have over the past month, and the total amount of money I spent over the past month.'''))
|
222
AgentOccam/llms/gpt.py
Normal file
222
AgentOccam/llms/gpt.py
Normal file
|
@ -0,0 +1,222 @@
|
|||
import openai
|
||||
from openai import OpenAI, AzureOpenAI
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import base64
|
||||
import io
|
||||
import requests
|
||||
import os
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)
|
||||
AZURE_ENDPOINT = os.environ.get("AZURE_ENDPOINT", None)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {OPENAI_API_KEY}"
|
||||
}
|
||||
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
|
||||
|
||||
def call_gpt(prompt, model_id="gpt-3.5-turbo", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Send a single-turn chat request to OpenAI and return the reply text.

    Retries transient failures with a 10s sleep, up to 10 attempts.

    Args:
        prompt: User message content.
        model_id: OpenAI chat model identifier.
        system_prompt: System message content.

    Returns:
        The stripped assistant reply, or None on an authentication error
        (retrying cannot help there).

    Raises:
        ValueError: after 10 failed attempts.
    """
    num_attempts = 0
    while True:
        if num_attempts >= 10:
            raise ValueError("OpenAI request failed.")
        try:
            response = OpenAI().chat.completions.create(
                model=model_id,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.95,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None
            )
            return response.choices[0].message.content.strip()
        except openai.AuthenticationError as e:
            # Bad credentials: report and bail out instead of retrying.
            print(e)
            return None
        except Exception as e:
            # Consolidated: the original had identical, duplicated handlers
            # for RateLimitError and Exception; both back off the same way.
            print(e)
            print("Sleeping for 10s...")
            time.sleep(10)
            num_attempts += 1
|
||||
|
||||
def arrange_message_for_gpt(item_list):
    """Convert (kind, payload) items into the OpenAI chat ``messages`` format.

    Adjacent text items are merged into a single text part. Image items are
    inlined as base64 ``data:`` URLs; a string payload is treated as a file
    path, while an RGB ndarray is downscaled 2x and JPEG-encoded first.

    Args:
        item_list: sequence of ("text", str) or ("image", path_or_ndarray).

    Returns:
        A one-element list: [{"role": "user", "content": [...]}].
    """
    def image_path_to_bytes(file_path):
        with open(file_path, "rb") as image_file:
            image_bytes = image_file.read()
        return image_bytes

    # Pass 1: coalesce runs of adjacent text items into single buffers.
    combined_item_list = []
    text_buffer = ""
    for item in item_list:
        if item[0] == "image":
            if len(text_buffer) > 0:
                combined_item_list.append(("text", text_buffer))
                text_buffer = ""
            combined_item_list.append(item)
        else:
            text_buffer += item[1]
    # Bug fix: the flush used to check item_list[-1], which raised IndexError
    # for an empty item_list. The buffer is always empty right after an image
    # item, so checking only the buffer is equivalent and safe.
    if len(text_buffer) > 0:
        combined_item_list.append(("text", text_buffer))

    # Pass 2: render each combined item as an OpenAI content part.
    content = []
    for item in combined_item_list:
        item_type = item[0]
        if item_type == "text":
            content.append({
                "type": "text",
                "text": item[1]
            })
        elif item_type == "image":
            if isinstance(item[1], str):
                image_bytes = image_path_to_bytes(item[1])
                image_data = base64.b64encode(image_bytes).decode("utf-8")
            elif isinstance(item[1], np.ndarray):
                image = Image.fromarray(item[1]).convert("RGB")
                width, height = image.size
                # Halve the resolution to keep the request payload small.
                image = image.resize((int(0.5*width), int(0.5*height)), Image.LANCZOS)
                image_bytes = io.BytesIO()
                image.save(image_bytes, format='JPEG')
                image_bytes = image_bytes.getvalue()
                image_data = base64.b64encode(image_bytes).decode("utf-8")
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{image_data}"
                },
            })
    messages = [
        {
            "role": "user",
            "content": content
        }
    ]
    return messages
|
||||
|
||||
def call_gpt_with_messages(messages, model_id="gpt-3.5-turbo", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Send a multi-part chat request (optionally with images) to OpenAI.

    Uses the Azure client when AZURE_ENDPOINT is configured. Requests that
    contain image parts are sent over the raw HTTP API with a fixed
    vision-capable model; text-only requests go through the SDK client.
    Retries transient failures with a 10s sleep, up to 10 attempts.

    Args:
        messages: OpenAI-format message list (see arrange_message_for_gpt).
        model_id: Model for the text-only path.
        system_prompt: Prepended as a system message if one is not present.

    Returns:
        The stripped assistant reply, or None on an authentication error.

    Raises:
        ValueError: after 10 failed attempts.
    """
    client = OpenAI() if not AZURE_ENDPOINT else AzureOpenAI(azure_endpoint = AZURE_ENDPOINT, api_key=OPENAI_API_KEY, api_version="2024-02-15-preview")
    num_attempts = 0
    while True:
        if num_attempts >= 10:
            raise ValueError("OpenAI request failed.")
        try:
            if any("image" in c["type"] for m in messages for c in m["content"]):
                # Vision path: raw HTTP call with a vision-capable model.
                payload = {
                    "model": "gpt-4-turbo",
                    "messages": messages,
                }
                response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
                return response.json()["choices"][0]["message"].get("content", "").strip()
            else:
                response = client.chat.completions.create(
                    model=model_id,
                    messages=messages if messages[0]["role"] == "system" else [{"role": "system", "content": system_prompt}] + messages,
                    temperature=0.5,
                    top_p=0.95,
                    frequency_penalty=0,
                    presence_penalty=0,
                    stop=None
                )
                return response.choices[0].message.content.strip()
        except openai.AuthenticationError as e:
            # Bad credentials: report and bail out instead of retrying.
            print(e)
            return None
        except Exception as e:
            # Consolidated: the original had identical, duplicated handlers
            # for RateLimitError and Exception; both back off the same way.
            print(e)
            print("Sleeping for 10s...")
            time.sleep(10)
            num_attempts += 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
prompt = '''CURRENT OBSERVATION:
|
||||
RootWebArea [2634] 'My Account'
|
||||
link [3987] 'My Account'
|
||||
link [3985] 'My Wish List'
|
||||
link [3989] 'Sign Out'
|
||||
text 'Welcome to One Stop Market'
|
||||
link [3800] 'Skip to Content'
|
||||
link [3809] 'store logo'
|
||||
link [3996] 'My Cart'
|
||||
combobox [4190] 'Search' [required: False]
|
||||
link [4914] 'Advanced Search'
|
||||
button [4193] 'Search' [disabled: True]
|
||||
tablist [3699]
|
||||
tabpanel
|
||||
menu "[3394] 'Beauty & Personal Care'; [3459] 'Sports & Outdoors'; [3469] 'Clothing, Shoes & Jewelry'; [3483] 'Home & Kitchen'; [3520] 'Office Products'; [3528] 'Tools & Home Improvement'; [3533] 'Health & Household'; [3539] 'Patio, Lawn & Garden'; [3544] 'Electronics'; [3605] 'Cell Phones & Accessories'; [3620] 'Video Games'; [3633] 'Grocery & Gourmet Food'"
|
||||
main
|
||||
heading 'My Account'
|
||||
text 'Contact Information'
|
||||
text 'Emma Lopez'
|
||||
text 'emma.lopezgmail.com'
|
||||
link [3863] 'Change Password'
|
||||
text 'Newsletters'
|
||||
text "You aren't subscribed to our newsletter."
|
||||
link [3877] 'Manage Addresses'
|
||||
text 'Default Billing Address'
|
||||
group [3885]
|
||||
text 'Emma Lopez'
|
||||
text '101 S San Mateo Dr'
|
||||
text 'San Mateo, California, 94010'
|
||||
text 'United States'
|
||||
text 'T:'
|
||||
link [3895] '6505551212'
|
||||
text 'Default Shipping Address'
|
||||
group [3902]
|
||||
text 'Emma Lopez'
|
||||
text '101 S San Mateo Dr'
|
||||
text 'San Mateo, California, 94010'
|
||||
text 'United States'
|
||||
text 'T:'
|
||||
link [3912] '6505551212'
|
||||
link [3918] 'View All'
|
||||
table 'Recent Orders'
|
||||
row '| Order | Date | Ship To | Order Total | Status | Action |'
|
||||
row '| --- | --- | --- | --- | --- | --- |'
|
||||
row "| 000000170 | 5/17/23 | Emma Lopez | 365.42 | Canceled | View OrderReorder\tlink [4110] 'View Order'\tlink [4111] 'Reorder' |"
|
||||
row "| 000000189 | 5/2/23 | Emma Lopez | 754.99 | Pending | View OrderReorder\tlink [4122] 'View Order'\tlink [4123] 'Reorder' |"
|
||||
row "| 000000188 | 5/2/23 | Emma Lopez | 2,004.99 | Pending | View OrderReorder\tlink [4134] 'View Order'\tlink [4135] 'Reorder' |"
|
||||
row "| 000000187 | 5/2/23 | Emma Lopez | 1,004.99 | Pending | View OrderReorder\tlink [4146] 'View Order'\tlink [4147] 'Reorder' |"
|
||||
row "| 000000180 | 3/11/23 | Emma Lopez | 65.32 | Complete | View OrderReorder\tlink [4158] 'View Order'\tlink [4159] 'Reorder' |"
|
||||
link [4165] 'My Orders'
|
||||
link [4166] 'My Downloadable Products'
|
||||
link [4167] 'My Wish List'
|
||||
link [4169] 'Address Book'
|
||||
link [4170] 'Account Information'
|
||||
link [4171] 'Stored Payment Methods'
|
||||
link [4173] 'My Product Reviews'
|
||||
link [4174] 'Newsletter Subscriptions'
|
||||
heading 'Compare Products'
|
||||
text 'You have no items to compare.'
|
||||
heading 'My Wish List'
|
||||
text 'You have no items in your wish list.'
|
||||
contentinfo
|
||||
textbox [4177] 'Sign Up for Our Newsletter:' [required: False]
|
||||
button [4072] 'Subscribe'
|
||||
link [4073] 'Privacy and Cookie Policy'
|
||||
link [4074] 'Search Terms'
|
||||
link [4075] 'Advanced Search'
|
||||
link [4076] 'Contact Us'
|
||||
text 'Copyright 2013-present Magento, Inc. All rights reserved.'
|
||||
text 'Help Us Keep Magento Healthy'
|
||||
link [3984] 'Report All Bugs'
|
||||
Today is 6/12/2023. Base on the aforementioned webpage, tell me how many fulfilled orders I have over the past month, and the total amount of money I spent over the past month.'''
|
||||
print(call_gpt(prompt=prompt, model_id="gpt-4-turbo"))
|
41
AgentOccam/llms/llama.py
Normal file
41
AgentOccam/llms/llama.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
import boto3
|
||||
import json
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = '''You are an AI assistant. Your goal is to provide informative and substantive responses to queries.'''
|
||||
|
||||
def call_llama(prompt, model_id = "meta.llama3-8b-instruct-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Invoke a Llama 3 instruct model on AWS Bedrock and return its text.

    Args:
        prompt: User prompt text.
        model_id: Bedrock model identifier.
        system_prompt: Instructions prepended to the prompt.

    Returns:
        The generated text.

    Raises:
        KeyError: if the Bedrock invocation fails. NOTE(review): KeyError is
        an odd type for an invocation failure; kept for backward
        compatibility with existing callers.
    """
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    # Llama 3 instruct chat template markers.
    formatted_prompt = f'''\n<|begin_of_text|>\n<|start_header_id|>user<|end_header_id|>\n{system_prompt}\n{prompt}\n<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n'''

    native_request = {
        "prompt": formatted_prompt,
        "max_gen_len": 512,
        "temperature": 0.5,
    }

    request = json.dumps(native_request)

    try:
        response = client.invoke_model(modelId=model_id, body=request)
    except Exception as e:
        # Bug fix: chain the original exception (`from e`) so the root cause
        # is preserved in the traceback instead of being flattened to text.
        raise KeyError(f"ERROR: Can't invoke '{model_id}'. Reason: {e}") from e

    model_response = json.loads(response["body"].read())

    response_text = model_response["generation"]
    return response_text
|
||||
|
||||
def arrange_message_for_llama(item_list):
    """Join the text pieces of (kind, text) items into one prompt string.

    Raises NotImplementedError if any item is an image.
    """
    parts = []
    for entry in item_list:
        if entry[0] == "image":
            raise NotImplementedError()
        parts.append(entry[1])
    return "".join(parts)
|
||||
|
||||
def call_llama_with_messages(messages, model_id="meta.llama3-8b-instruct-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    # Llama takes one flat prompt string, so `messages` (already flattened by
    # arrange_message_for_llama) is forwarded as the prompt.
    return call_llama(prompt=messages, model_id=model_id, system_prompt=system_prompt)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test; requires AWS Bedrock credentials in the environment.
    print(call_llama('''Hi'''))
|
42
AgentOccam/llms/mistral.py
Normal file
42
AgentOccam/llms/mistral.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import boto3
|
||||
import json
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = '''You are an AI assistant. Your goal is to provide informative and substantive responses to queries.'''
|
||||
|
||||
def call_mistral(prompt, model_id="mistral.mistral-large-2402-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Invoke a Mistral model on AWS Bedrock and return its generated text.

    Args:
        prompt: User prompt text.
        model_id: Bedrock model identifier.
        system_prompt: Instructions prepended to the prompt.

    Returns:
        The generated text.

    Raises:
        ClientError (or other Exception): re-raised when the Bedrock
        invocation fails.
    """
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    # Mistral instruct template.
    formatted_prompt = f"<s>[INST] {system_prompt}\n{prompt} [/INST]"

    native_request = {
        "prompt": formatted_prompt,
        "max_tokens": 512,
        "temperature": 0.5,
    }

    request = json.dumps(native_request)
    try:
        response = client.invoke_model(modelId=model_id, body=request)
    except (ClientError, Exception) as e:
        # Bug fix: previously this only printed and fell through, which then
        # raised a confusing UnboundLocalError on `response` below. Re-raise
        # so the caller sees the real failure.
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        raise

    model_response = json.loads(response["body"].read())

    response_text = model_response["outputs"][0]["text"]
    return response_text
|
||||
|
||||
def arrange_message_for_mistral(item_list):
    """Concatenate (kind, text) items into one prompt string; images are unsupported."""
    if any(item[0] == "image" for item in item_list):
        raise NotImplementedError()
    return "".join(item[1] for item in item_list)
|
||||
|
||||
def call_mistral_with_messages(messages, model_id="mistral.mistral-large-2402-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    # Mistral takes one flat prompt string, so `messages` (already flattened
    # by arrange_message_for_mistral) is forwarded as the prompt.
    return call_mistral(prompt=messages, model_id=model_id, system_prompt=system_prompt)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test; requires AWS Bedrock credentials in the environment.
    print(call_mistral('''Hi'''))
|
||||
|
44
AgentOccam/llms/titan.py
Normal file
44
AgentOccam/llms/titan.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
import boto3
|
||||
import json
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = '''You are an AI assistant. Your goal is to provide informative and substantive responses to queries.'''
|
||||
|
||||
def call_titan(prompt, model_id="amazon.titan-text-premier-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Invoke an Amazon Titan text model on AWS Bedrock and return its text.

    Args:
        prompt: User prompt text.
        model_id: Bedrock model identifier.
        system_prompt: Instructions prepended to the prompt.

    Returns:
        The generated text.

    Raises:
        ClientError (or other Exception): re-raised when the Bedrock
        invocation fails.
    """
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    formatted_prompt = f"{system_prompt}\n{prompt}"

    # Titan native request schema.
    native_request = {
        "inputText": formatted_prompt,
        "textGenerationConfig": {
            "maxTokenCount": 512,
            "temperature": 0.5,
        },
    }

    request = json.dumps(native_request)
    try:
        response = client.invoke_model(modelId=model_id, body=request)
    except (ClientError, Exception) as e:
        # Bug fix: previously this only printed and fell through, which then
        # raised a confusing UnboundLocalError on `response` below. Re-raise
        # so the caller sees the real failure.
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        raise

    model_response = json.loads(response["body"].read())

    response_text = model_response["results"][0]["outputText"]
    return response_text
|
||||
|
||||
def arrange_message_for_titan(item_list):
    """Join the text pieces of (kind, text) items into one prompt string.

    Raises NotImplementedError if any item is an image.
    """
    chunks = []
    for entry in item_list:
        if entry[0] == "image":
            raise NotImplementedError()
        chunks.append(entry[1])
    return "".join(chunks)
|
||||
|
||||
def call_titan_with_messages(messages, model_id="amazon.titan-text-premier-v1:0", system_prompt=DEFAULT_SYSTEM_PROMPT):
    # Titan takes one flat prompt string, so `messages` (already flattened by
    # arrange_message_for_titan) is forwarded as the prompt.
    return call_titan(prompt=messages, model_id=model_id, system_prompt=system_prompt)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test; requires AWS Bedrock credentials in the environment.
    print(call_titan('''Hi'''))
|
||||
|
410
AgentOccam/obs_opt.py
Normal file
410
AgentOccam/obs_opt.py
Normal file
|
@ -0,0 +1,410 @@
|
|||
import re
|
||||
from browser_env.processors import TreeNode
|
||||
from functools import partial
|
||||
|
||||
RETAINED_PROPERTIES = ["required", "disabled", "checked", "valuemin", "valuemax", "valuetext", "selected", "page_dialog_message"]
|
||||
UNWANTED_PROPERTIES = ["focused", "autocomplete", "hasPopup", "expanded", "multiselectable", "orientation", "controls"]
|
||||
UNINTERACTIVE_ROLES = ["StaticText", "LabelText", "main", "heading", "LayoutTable", "tabpanel", "LayoutTableRow", "LayoutTableCell", "time", "list", "contentinfo", "table", "row", "rowheader", "columnheader", "gridcell", "caption", "DescriptionList", "DescriptionListTerm", "DescriptionListDetail", "RootWebArea", "rowgroup", "alert"]
|
||||
ROLE_REPLACEMENT_DICT = {
|
||||
"StaticText": "text",
|
||||
"LabelText": "text",
|
||||
# "caption": "text",
|
||||
# "generic": "text"
|
||||
}
|
||||
|
||||
def parse_text_to_tree(text):
    """Parse a tab-indented accessibility-tree dump into a TreeNode tree.

    Each non-blank line has the form ``[id] role name...`` with one leading
    tab per depth level. A line at depth d becomes a child of the most
    recently seen line at depth d-1.

    Args:
        text: The multi-line tree dump.

    Returns:
        The root TreeNode (the last unindented line seen), or None if the
        text has no non-blank lines.
    """
    lines = text.split('\n')

    root = None
    # Maps depth -> most recent node at depth-1, i.e. the parent for that depth.
    parent_stack = {}

    for line in lines:
        if line.strip() == "":
            continue
        line_strip = line.strip()
        line_parts = line_strip.split(' ')
        # Renamed from `id`/`type` to avoid shadowing the Python builtins.
        node_id = line_parts[0][1:-1]  # strip the surrounding brackets
        node_role = line_parts[1]
        node_name = ' '.join(line_parts[2:])
        level = 0
        for char in line:
            if char == '\t':
                level += 1
            else:
                break

        node = TreeNode(node_id, node_role, node_name, level)

        if line.startswith('\t'):
            parent_stack[level].add_child(node)
        else:
            root = node

        parent_stack[level+1] = node

    return root
|
||||
|
||||
def remove_unwanted_characters(text):
    """Strip unusual symbols and collapse runs of whitespace.

    Keeps word characters, whitespace, and common punctuation; non-breaking
    spaces become regular spaces first.
    """
    normalized = text.replace('\xa0', ' ')
    without_symbols = re.sub(r'[^\w\s,.!?;:\-\'\"()&/\u2019@]+', '', normalized, flags=re.UNICODE)
    collapsed = re.sub(r'\s+', ' ', without_symbols)
    return collapsed.strip()
|
||||
|
||||
def search_node_by_id(node, target_id):
    """Return the first node (preorder) whose node_id equals target_id, else None."""
    pending = [node]
    while pending:
        current = pending.pop()
        if current.node_id == target_id:
            return current
        # Push children reversed so they are visited in original order.
        pending.extend(reversed(current.children))
    return None
|
||||
|
||||
def action_replace_node_role(node:TreeNode, role_replacement_dict:dict):
    """Rewrite the node's role in place using the replacement table, if it has an entry."""
    if node.role in role_replacement_dict:
        node.role = role_replacement_dict[node.role]
|
||||
|
||||
def action_remove_unwanted_characters(node:TreeNode):
    # Normalize the node's accessible name in place (strip odd symbols,
    # collapse whitespace) via remove_unwanted_characters.
    node.name = remove_unwanted_characters(node.name)
|
||||
|
||||
def action_remove_unwanted_properties(node:TreeNode):
    """Prune noisy ARIA properties from the node, in place.

    Drops every property listed in UNWANTED_PROPERTIES, removes a falsy
    "required" flag from children of table rows, and collapses an empty
    property dict to None.
    """
    if node.has_properties():
        node.properties = {p: node.properties[p] for p in node.properties.keys() if p not in UNWANTED_PROPERTIES}
        # Bug fix: "required" may be absent; the original indexed it
        # unconditionally here and raised KeyError for row children.
        if node.parent and node.parent.role == "row" and not node.properties.get("required"):
            node.properties.pop("required", None)
        if len(node.properties) == 0:
            node.properties = None
|
||||
|
||||
def action_remove_redundant_statictext_node(node:TreeNode):
    """Hide leaf text nodes whose name is empty or duplicated by the parent or a sibling."""
    if not node.visible:
        return
    if node.role not in ["StaticText", "LabelText", "caption"] or not node.all_children_invisible():
        return
    if not node.name:
        node.visible = False
        return
    if node.parent and node.name in node.parent.name:
        node.visible = False
        return
    if node.parent and any(node.name in sibling.name for sibling in node.siblings()):
        node.visible = False
|
||||
|
||||
def action_merge_statictext_to_parent(node:TreeNode):
    """Fold a lone text leaf into its unnamed single-child parent, hiding the leaf."""
    if not node.visible:
        return
    if node.role not in ["StaticText", "LabelText", "caption"] or not node.all_children_invisible():
        return
    parent = node.parent
    if parent and not parent.name and len(parent.children) == 1:
        parent.name = node.name
        node.visible = False
|
||||
|
||||
def action_merge_menuitem_and_option(node:TreeNode):
    # Collapse a container whose visible children are all menuitems (or all
    # options) into a single node: the children's rendered strings are joined
    # with "; " into this node's name and the children are then hidden.
    if not node.visible:
        return
    # Only act when every visible child shares the same role (menuitem/option).
    if not ((node.visible_children() and all(c.role=="menuitem" for c in node.visible_children())) or (node.visible_children() and all(c.role=="option" for c in node.visible_children()))):
        return
    if node.visible_children()[0].role == "menuitem":
        # Slicing drops the leading 'menuitem ' role tag from each rendering.
        if not node.name.strip():
            node.name = "; ".join([action_return_visible_node(c).strip()[len("menuitem "):] for c in node.visible_children()])
        else:
            node.name += ": " + "; ".join([action_return_visible_node(c).strip()[len("menuitem "):] for c in node.visible_children()])
    elif node.visible_children()[0].role == "option":
        # Same as above, but dropping the leading 'option ' role tag.
        if not node.name.strip():
            node.name = "; ".join([action_return_visible_node(c).strip()[len("option "):] for c in node.visible_children()])
        else:
            node.name += ": " + "; ".join([action_return_visible_node(c).strip()[len("option "):] for c in node.visible_children()])
    for c in node.visible_children():
        c.visible = False
|
||||
|
||||
def action_merge_description_list(node:TreeNode):
    # Flatten a DescriptionList in place: each term node absorbs the detail
    # nodes that follow it ("term: detail; detail") and absorbed nodes are
    # hidden.
    if not node.visible:
        return
    def reformat_sublist(current_list_term_buffer):
        # Buffer layout: [term, detail, detail, ...]; merge the details'
        # names into the term node and hide them.
        if len(current_list_term_buffer) > 1:
            list_term_node_appended_name = []
            for n in current_list_term_buffer[1:]:
                list_term_node_appended_name.append(n.name)
                n.visible = False
            current_list_term_buffer[0].name += ": " + "; ".join(list_term_node_appended_name)

    if not node.role == "DescriptionList":
        return
    # First pass: inline a detail's single visible child into the detail itself.
    for child in node.visible_children():
        if child.role == "DescriptionListDetail" and not child.name and len(child.visible_children()) == 1:
            child.name = action_return_visible_node(child.visible_children()[0]).strip()
            child.visible_children()[0].visible = False
    # Second pass: group each term with its trailing leaf details.
    list_term_buffer = []
    for child in node.visible_children():
        if child.role == "DescriptionListTerm" and child.all_children_invisible():
            # A new term starts: flush the previous term/detail run.
            reformat_sublist(current_list_term_buffer=list_term_buffer)
            list_term_buffer = [child]
        elif child.role == "DescriptionListDetail" and child.all_children_invisible() and list_term_buffer:
            list_term_buffer.append(child)
        elif child.role == "DescriptionListDetail" and not child.all_children_invisible():
            # A detail with visible structure is left as-is; reset the run.
            list_term_buffer = []
        else:
            reformat_sublist(current_list_term_buffer=list_term_buffer)
            list_term_buffer = []
    # Flush the final pending run.
    reformat_sublist(current_list_term_buffer=list_term_buffer)
|
||||
|
||||
def action_remove_image(node:TreeNode):
    """Hide leaf nodes that only represent an image."""
    if not node.visible:
        return
    if node.all_children_invisible() and (node.role == "img" or node.name == "Image"):
        node.visible = False
|
||||
|
||||
def action_set_invisible(node:TreeNode):
    # Mark the node as hidden for observation rendering.
    node.visible = False
|
||||
|
||||
def action_set_visible(node:TreeNode):
    # Mark the node as visible for observation rendering.
    node.visible = True
|
||||
|
||||
def action_set_visible_if_with_name(node:TreeNode):
    # Show the node only when it carries a non-empty accessible name.
    if node.name:
        node.visible = True
|
||||
|
||||
def action_reformat_table(node:TreeNode):
    # Rewrite a "table" subtree into markdown-style pipe rows, or flatten a
    # "LayoutTable" subtree by merging adjacent text cells. Mutates the tree
    # in place; any error is caught and printed so observation building
    # continues.
    if not node.visible:
        return
    def merge_gridcell(gridcell_node:TreeNode):
        # Collapse a cell's visible descendants into the cell's own name
        # (tab-separated), then hide the descendants.
        if gridcell_node.role not in ["gridcell", "columnheader", "rowheader", "LayoutTableCell"] or not gridcell_node.visible:
            return
        gridcell_buffer = []
        parse_node_descendants(gridcell_node, action_return_visible_node, gridcell_buffer)
        if len(gridcell_buffer) == 1:
            return
        gridcell_buffer = [s.strip() for s in gridcell_buffer]
        if gridcell_node.name:
            gridcell_node.name += "\t" + "\t".join(gridcell_buffer[1:])
        else:
            gridcell_node.name = "\t".join(gridcell_buffer[1:])
        parse_node_descendants(gridcell_node, action_set_invisible)
        gridcell_node.visible = True

    try:
        if node.role == "table":

            def reformat_subtable(row_list, current_table_children):
                # Render one run of rows as markdown rows appended to a copy
                # of current_table_children. Dispatches on the first row's
                # cell role: columnheader / rowheader / gridcell.
                import copy
                new_table_children = copy.deepcopy(current_table_children)
                if row_list:
                    # if row_list[0].children[0].role == "columnheader":
                    if any(row_0_child.role == "columnheader" for row_0_child in row_list[0].children):
                        # Header table: header row, separator, then data rows.
                        if new_table_children and any(n.visible for n in new_table_children):
                            new_table_children.append(TreeNode(node_id=row_list[0].node_id, role="row", name="", depth=row_list[0].depth))
                        for i, row in enumerate(row_list):
                            new_role_name = []
                            for row_element in row.children:
                                new_role_name.append(row_element.name)
                            new_table_children.append(TreeNode(node_id=row.node_id, role="row", name="| "+" | ".join(new_role_name)+" |", depth=row.depth))
                            if i == 0 and len(row_list) > 1:
                                new_table_children.append(TreeNode(node_id=row.node_id, role="row", name="| "+" | ".join(["---"]*len(new_role_name))+" |", depth=row.depth))
                    elif row_list[0].children[0].role == "rowheader":
                        # Transposed table: rowheaders become the header row.
                        if new_table_children and any(n.visible for n in new_table_children):
                            new_table_children.append(TreeNode(node_id=row_list[0].node_id, role="row", name="", depth=row_list[0].depth))
                        titles = [r.children[0].name for r in row_list]
                        values = [r.children[1].name for r in row_list]
                        new_table_children.append(TreeNode(node_id=row_list[0].node_id, role="row", name="| "+" | ".join(titles)+" |", depth=row_list[0].depth))
                        new_table_children.append(TreeNode(node_id=row_list[0].node_id, role="row", name="| "+" | ".join(["---"]*len(titles))+" |", depth=row_list[0].depth))
                        new_table_children.append(TreeNode(node_id=row_list[0].node_id, role="row", name="| "+" | ".join(values)+" |", depth=row_list[0].depth))
                    elif row_list[0].children[0].role == "gridcell":
                        # Headerless grid: one markdown row per source row.
                        if new_table_children and any(n.visible for n in new_table_children):
                            new_table_children.append(TreeNode(node_id=row_list[0].node_id, role="row", name="", depth=row_list[0].depth))
                        for row in row_list:
                            new_table_children.append(TreeNode(node_id=row.node_id, role="row", name="| "+" | ".join([row_element.name for row_element in row.children])+" |", depth=row.depth))
                    else:
                        raise NotImplementedError("Unrecognized table format.")
                return new_table_children

            # Scan the table's children, batching consecutive rows of the
            # same kind into row_list and flushing via reformat_subtable.
            new_table_children = []
            row_list = []
            row_mode = False  # True while accumulating a rowheader-style run
            for child in node.children:
                if child.role == "row":
                    for row_element in child.visible_children(): # TODO: Visible?
                        merge_gridcell(row_element)

                # if child.role == "row" and child.children[0].role == "columnheader":
                if child.role == "row" and any(row_child.role == "columnheader" for row_child in child.children):
                    row_list = [child]
                    row_mode = False
                elif child.role == "row" and child.children[0].role == "rowheader":
                    if row_mode:
                        row_list.append(child)
                    else:
                        new_table_children = reformat_subtable(row_list=row_list, current_table_children=new_table_children)
                        row_list = [child]
                    row_mode = True
                elif child.role == "row" and child.children[0].role == "gridcell":
                    row_list.append(child)
                    row_mode = False
                elif child.role != "row":
                    # Non-row child: flush pending rows, recurse into a
                    # rowgroup if needed, then keep the child itself.
                    new_table_children = reformat_subtable(row_list=row_list, current_table_children=new_table_children)
                    if child.role == "rowgroup":
                        for grandchild in child.visible_children(): # grandchild: row
                            for row_element in grandchild.visible_children(): # TODO: Visible?
                                merge_gridcell(row_element)
                        child.children = reformat_subtable(row_list=child.children, current_table_children=[])
                    new_table_children.append(child)
                    row_list = []
                else:
                    raise NotImplementedError()
            new_table_children = reformat_subtable(row_list=row_list, current_table_children=new_table_children)
            node.children = new_table_children
        elif node.role == "LayoutTable":
            def merge_adjacent_text_nodes(nodes):
                # Merge sibling runs of visible text-like nodes into one node.
                if not nodes:
                    return []

                merged_nodes = []
                current_node = nodes[0]

                for i in range(1, len(nodes)):
                    if current_node.visible and current_node.role in ["LayoutTableCell", "StaticText", "generic"]+list(set(ROLE_REPLACEMENT_DICT.values())) and nodes[i].visible and nodes[i].role in ["LayoutTableCell", "StaticText", "generic"]+list(set(ROLE_REPLACEMENT_DICT.values())):
                        current_node.role = ROLE_REPLACEMENT_DICT["StaticText"]
                        current_node.name += " " + nodes[i].name # Merge text values
                        nodes[i].visible = False
                    else:
                        merged_nodes.append(current_node)
                        current_node = nodes[i]

                merged_nodes.append(current_node)

                return merged_nodes
            def dfs_merge_text(n:TreeNode):
                # Bottom-up: merge text runs, hoist single text children, and
                # finally render each LayoutTableRow as one pipe-joined line.
                if not n.children:
                    return
                for c in n.children:
                    dfs_merge_text(c)
                n.children = merge_adjacent_text_nodes(n.children)
                if len(n.visible_children()) == 1 and n.visible_children()[0].role in ["LayoutTableCell", "StaticText", "generic"]+list(set(ROLE_REPLACEMENT_DICT.values())) and n.role in ["LayoutTableCell", "StaticText", "generic"]+list(set(ROLE_REPLACEMENT_DICT.values())):
                    n.name += "\t" + n.visible_children()[0].name
                    n.visible_children()[0].visible = False
                if n.role == "LayoutTableRow":
                    for row_element in n.children:
                        if row_element.visible and row_element.children:
                            for sub_element in row_element.children:
                                if sub_element.visible:
                                    node_str = action_return_visible_node(sub_element).strip()
                                    row_element.name += f"\t{node_str}"
                            row_element.children = []
                    n.name = "| " + " | ".join([c.name for c in n.children if c.visible]) + " |" # TODO: Visible?
                    for row_element in n.children:
                        row_element.visible = False
            dfs_merge_text(node)
    except Exception as e:
        # Best-effort: a malformed table should not abort observation building.
        print("Table reformatting error:", e)
|
||||
|
||||
def action_merge_duplicated_headings(node:TreeNode):
    """Collapse a heading node whose name duplicates its parent's name.

    Two mirror cases are handled for a node that is effectively a leaf
    (all children invisible) and is the only visible child of its parent:
    * the node is a heading and its parent has an interactive role with the
      same name -> hide the redundant child heading;
    * the parent is a heading and the node has an interactive role with the
      same name -> fold the child's identity (id, role, properties, children)
      up into the parent, then hide the child, so the interactive role wins.
    """
    # Only act on visible, childless-looking nodes that have a parent and
    # no visible siblings; anything else is left untouched.
    if not node.visible or not node.all_children_invisible() or not node.parent or node.visible_siblings():
        return
    if node.role=="heading" and node.parent.role not in UNINTERACTIVE_ROLES and node.name == node.parent.name:
        node.visible = False
    if node.parent.role=="heading" and node.role not in UNINTERACTIVE_ROLES and node.name == node.parent.name:
        # Promote the interactive child into the heading parent's slot.
        node.parent.node_id = node.node_id
        node.parent.role = node.role
        node.parent.properties = node.properties
        node.parent.children = node.children
        node.visible = False
|
||||
|
||||
def action_print_tree(node:TreeNode):
    """Debug helper: print one line for *node* showing visibility, depth, id, role and name, indented by depth."""
    indent = "\t" * node.depth
    summary = f"{node.visible} {node.depth} [{node.node_id}] {node.role}: {node.name}"
    print(indent + summary)
|
||||
|
||||
def action_return_visible_node(node:TreeNode, intent_bias=0, mode="concise", **kwargs):
    """Render a single visible node as one indented text line.

    Args:
        node: the node to render; invisible nodes yield ``None``.
        intent_bias: depth offset subtracted from ``node.depth`` so a subtree
            can be printed flush-left relative to its own root.
        mode: one of
            * "concise" - role, with the numeric id only for interactive roles;
            * "verbose" - always role plus id;
            * "name_only" - role only, never an id;
            * "name_retained_id_only" - id only when it appears in
              ``kwargs["retained_ids"]``.
        **kwargs: may carry "hidden_roles" (extra roles whose ids are hidden
            in concise mode) and "retained_ids" (see above).

    Returns:
        The rendered line as a string, or ``None`` for invisible nodes.

    NOTE(review): an unrecognized *mode* leaves ``node_str`` unbound and
    raises UnboundLocalError - callers are expected to pass a valid mode.
    """
    if not node.visible:
        return None
    if mode == "concise":
        node_str = node.role
        # Structural / static roles carry no id: they are not actionable.
        hidden_roles = UNINTERACTIVE_ROLES+list(set(ROLE_REPLACEMENT_DICT.values()))
        # Extra caller-supplied roles are only honored when the name already
        # embeds bracketed content - presumably pre-rendered ids; confirm.
        if "[" in node.name and "hidden_roles" in kwargs.keys():
            hidden_roles += kwargs["hidden_roles"]
        if node.role not in hidden_roles:
            node_str += f" [{node.node_id}]"
    elif mode == "verbose":
        node_str = f"{node.role} [{node.node_id}]"
    elif mode == "name_only":
        node_str = node.role
    elif mode == "name_retained_id_only":
        node_str = node.role
        retained_ids = kwargs.get("retained_ids", [])
        if node.node_id in retained_ids:
            node_str += f" [{node.node_id}]"
    # The name is repr()'d so embedded quotes/newlines stay on one line.
    if node.name:
        node_str += f" {repr(node.name)}"
    if node.has_properties():
        for p in node.properties:
            p_value = node.properties[p]
            node_str += f" [{p}: {p_value}]"
    return "\t" * (node.depth-intent_bias) + node_str
|
||||
|
||||
def parse_node_siblings(node:TreeNode, action=action_print_tree, tree_buffer=None):
    """Apply *action* to every sibling of *node*, collecting truthy results.

    Args:
        node: node whose siblings are visited.
        action: callable applied to each sibling; truthy return values are
            appended to *tree_buffer*.
        tree_buffer: output list; a fresh list is created when omitted.

    Fix: the original used a mutable default argument (``tree_buffer=[]``),
    so all default-argument calls silently shared - and leaked results into -
    a single hidden list. Callers that pass an explicit buffer are unaffected.
    """
    if tree_buffer is None:
        tree_buffer = []
    for sibling in node.siblings():
        res_action = action(sibling)
        if res_action:
            tree_buffer.append(res_action)
|
||||
|
||||
def parse_node_ancestors(node:TreeNode, action=action_print_tree, tree_buffer=None):
    """Apply *action* to *node* and every ancestor up to the root, collecting truthy results.

    Args:
        node: starting node; the walk follows ``node.parent`` links upward.
        action: callable applied to each visited node; truthy return values
            are appended to *tree_buffer*.
        tree_buffer: output list; a fresh list is created when omitted.

    Fix: the original used a mutable default argument (``tree_buffer=[]``),
    so all default-argument calls silently shared one hidden list. The
    recursive call always passes the buffer explicitly, so behavior for
    explicit callers is unchanged.
    """
    if tree_buffer is None:
        tree_buffer = []
    res_action = action(node)
    if res_action:
        tree_buffer.append(res_action)
    if node.parent:
        parse_node_ancestors(node=node.parent, action=action, tree_buffer=tree_buffer)
|
||||
|
||||
def parse_node_descendants(node:TreeNode, action=action_print_tree, tree_buffer=None):
    """Pre-order walk: apply *action* to *node* and all descendants, collecting truthy results.

    Args:
        node: subtree root.
        action: callable applied to each visited node; truthy return values
            are appended to *tree_buffer* in visit order.
        tree_buffer: output list; a fresh list is created when omitted.

    Fix: the original used a mutable default argument (``tree_buffer=[]``),
    so every default-argument call (e.g. the side-effect-only calls in
    ``prune_tree``) appended into one process-wide hidden list - a slow leak.
    The recursion passes the buffer explicitly, so collected output for
    explicit callers is identical.
    """
    if tree_buffer is None:
        tree_buffer = []
    res_action = action(node)
    if res_action:
        tree_buffer.append(res_action)
    for child in node.children:
        parse_node_descendants(node=child, action=action, tree_buffer=tree_buffer)
|
||||
|
||||
def prune_tree_fuzzy_node(node:TreeNode): # TODO: Bugs!!!
    """Bottom-up pass hiding "fuzzy" nodes.

    A node is fuzzy when all of its children are invisible and it is not
    differentiable from its siblings under the strict criterion
    (``is_differentiable(strict=True)``). Such nodes carry no distinguishing
    information for the agent and are hidden rather than deleted.

    NOTE(review): the author's "TODO: Bugs!!!" marker is preserved - the
    exact intended semantics have not been confirmed.
    """
    if not node.children:
        return

    # Iterate over the children in reverse order to safely remove nodes
    fuzzy_children = []
    for child in reversed(node.children):
        # Recurse first so a child's own fuzziness is decided after its
        # subtree has been pruned.
        prune_tree_fuzzy_node(child)
        if child.all_children_invisible() and not child.is_differentiable(strict=True):
            fuzzy_children.append(child)
    # Hide (not remove) the collected nodes after the scan completes.
    for child in fuzzy_children:
        child.visible = False
|
||||
|
||||
def translate_node_to_str(node: TreeNode, mode="concise", **kwargs):
    """Serialize the visible subtree rooted at *node* to a newline-joined string.

    Each visible descendant is rendered by ``action_return_visible_node``;
    ``intent_bias=node.depth`` makes the subtree print flush-left relative to
    *node*. Extra kwargs (e.g. "retained_ids", "hidden_roles") are forwarded
    to the renderer.
    """
    tree_buffer = []
    parse_node_descendants(node, partial(action_return_visible_node, intent_bias=node.depth, mode=mode, **kwargs), tree_buffer=tree_buffer)
    # Hard cap of 1000 lines keeps the serialized observation bounded.
    return "\n".join(tree_buffer[:1000])
|
||||
|
||||
def construct_new_DOM_with_visible_nodes(DOM_root:TreeNode):
    """Build and return a deep copy of the DOM containing only visible nodes.

    Invisible nodes (and their subtrees) are dropped; every kept node is a
    ``copy()`` of the original, re-linked via ``add_child``.
    """
    def clone_visible(current:TreeNode):
        # Invisible nodes contribute nothing to the copied tree.
        if not current.visible:
            return None
        replica = current.copy()
        for child in current.visible_children():
            child_replica = clone_visible(child)
            if child_replica:
                replica.add_child(child_replica)
        return replica
    return clone_visible(DOM_root)
|
||||
|
||||
def prune_tree(objective, root_node, mode="str"):
    """Run the full observation-pruning pipeline on a copy of *root_node*.

    Args:
        objective: currently unused here - presumably kept for interface
            symmetry with other pruning entry points; confirm before removal.
        root_node: original DOM root; it is never mutated - all passes run on
            a visible-only deep copy.
        mode: "str" returns the serialized concise text; "node" returns a
            fresh visible-only tree.

    NOTE(review): any other *mode* value leaves ``browser_content`` unbound
    and raises UnboundLocalError.
    """
    root_node_copy = construct_new_DOM_with_visible_nodes(root_node)
    # The pass order below is deliberate: cleanup passes run before the
    # structural merges, and image/statictext removal is repeated because
    # earlier merges can expose new removable nodes.
    parse_node_descendants(root_node_copy, action_remove_unwanted_characters)
    parse_node_descendants(root_node_copy, action_remove_unwanted_properties)
    parse_node_descendants(root_node_copy, action_remove_redundant_statictext_node)
    parse_node_descendants(root_node_copy, action_remove_image)
    prune_tree_fuzzy_node(root_node_copy)
    parse_node_descendants(root_node_copy, action_remove_image)
    parse_node_descendants(root_node_copy, action_merge_statictext_to_parent)
    parse_node_descendants(root_node_copy, action_remove_redundant_statictext_node)
    parse_node_descendants(root_node_copy, partial(action_replace_node_role, role_replacement_dict=ROLE_REPLACEMENT_DICT))
    parse_node_descendants(root_node_copy, action_merge_menuitem_and_option)
    parse_node_descendants(root_node_copy, action_merge_description_list)
    parse_node_descendants(root_node_copy, action_reformat_table)
    parse_node_descendants(root_node_copy, action_merge_duplicated_headings)

    if mode == "str":
        browser_content = translate_node_to_str(node=root_node_copy, mode="concise")
    elif mode == "node":
        # Re-copy so hidden-by-passes nodes are physically dropped.
        browser_content = construct_new_DOM_with_visible_nodes(root_node_copy)
    return browser_content
|
||||
|
||||
def contains_keyword(title, keyword):
    """Return True when *keyword* occurs in the lower-cased *title*.

    Note: only the title is lower-cased; pass an already-lower-case keyword
    for a case-insensitive match.
    """
    return title.lower().find(keyword) != -1
|
291
AgentOccam/plot.py
Normal file
291
AgentOccam/plot.py
Normal file
|
@ -0,0 +1,291 @@
|
|||
import os
|
||||
import csv
|
||||
import json
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
import numpy as np
|
||||
|
||||
from AgentOccam.utils import COLOR_DICT, TASK_ID_DICT, MERGED_SITE_TASK_ID_DICT, EVELUATOR_RECTIFICATIONS, RUN_NAME_DICT, TASK_LABELS_MULTISITE, TRAJECTORY_DIR_DICT, OUTPUT_DIR, TOTAL_TASK_NUM_DICT
|
||||
|
||||
|
||||
def random_color_generator():
    """Yield an endless, deterministic stream of '#RRGGBB' hex color strings.

    The fixed seed makes plot colors reproducible across runs; the seed is
    (re)applied each time a new generator object is started.
    """
    import random
    random.seed(65)
    while True:
        r = random.randint(0, 255)
        g = random.randint(0, 255)
        b = random.randint(0, 255)
        yield f'#{r:02X}{g:02X}{b:02X}'


def generate_random_colors(color_num):
    """Return a list of *color_num* deterministic hex color strings.

    Fix: the original evaluated ``next(random_color_generator)`` - calling
    ``next()`` on the generator *function*, which raises
    ``TypeError: 'function' object is not an iterator``. A generator object
    must be created first.
    """
    color_stream = random_color_generator()
    return [next(color_stream) for _ in range(color_num)]
|
||||
|
||||
def get_colors(trajectory_key_list):
    """Map run keys to their preset ``COLOR_DICT`` colors, falling back to deterministic generated colors.

    Fix: the original evaluated ``next(random_color_generator)`` - calling
    ``next()`` on the generator *function* (TypeError). One generator object
    is created and consumed across all fallback keys so fallback colors stay
    distinct.
    """
    fallback_colors = random_color_generator()
    return [COLOR_DICT[k] if k in COLOR_DICT else next(fallback_colors) for k in trajectory_key_list]
|
||||
|
||||
def parse_summary_csv_files(root_dir, site_list, mode="single_site"):
    """Aggregate rewards from every ``summary.csv`` found under *root_dir*.

    Args:
        root_dir: directory tree to walk for ``summary.csv`` files.
        site_list: site names used to select the eligible task ids.
        mode: "multiple_site" selects ids from ``TASK_ID_DICT``;
            "single_site" from ``MERGED_SITE_TASK_ID_DICT``.

    Returns:
        ``(total_reward, net_total_reward, total_tasks)`` where net reward
        counts only fully successful tasks (reward exactly 1.0). Returns
        ``(0.0, 0.0, 0.0)`` when no matching task rows were found.
    """
    total_reward = 0
    total_tasks = 0
    net_total_reward = 0

    # Build the flat list of task ids eligible for this site selection.
    id_list = []
    for site in site_list:
        if mode == "multiple_site":
            id_list += TASK_ID_DICT[site]
        elif mode == "single_site":
            id_list += MERGED_SITE_TASK_ID_DICT[site]

    for subdir, _, files in os.walk(root_dir):
        for file in files:
            if file == 'summary.csv':
                filepath = os.path.join(subdir, file)
                with open(filepath, 'r') as csv_file:
                    csv_reader = csv.DictReader(csv_file)
                    for row in csv_reader:
                        task_id = int(row['task_id'])
                        if task_id in id_list:
                            total_tasks += 1
                            total_reward += float(row['reward'])
                            # "Net" success is strict: partial reward counts 0.
                            net_total_reward += 1 if float(row['reward']) == 1. else 0

    if total_tasks > 0:
        return total_reward, net_total_reward, total_tasks
    else:
        return 0.0, 0.0, 0.0
|
||||
|
||||
def parse_json_files(root_dir, site_list, evaluator="after", mode="single_site"):
    """Aggregate rewards from per-task trajectory JSON files in *root_dir*.

    Args:
        root_dir: flat directory of ``*.json`` trajectory files.
        site_list: site names used to select eligible task ids.
        evaluator: "after" counts every task; "before" skips ids listed in
            ``EVELUATOR_RECTIFICATIONS`` (tasks whose evaluator was fixed).
        mode: "multiple_site" selects ids from ``TASK_ID_DICT``;
            "single_site" from ``MERGED_SITE_TASK_ID_DICT``.

    Returns:
        ``(total_reward, net_total_reward, total_tasks)``; net reward counts
        only reward == 1.0. Returns ``(0.0, 0.0, 0.0)`` when nothing matched.
    """
    total_reward = 0
    total_tasks = 0
    net_total_reward = 0

    id_list = []
    for site in site_list:
        if mode == "multiple_site":
            id_list += TASK_ID_DICT[site]
        elif mode == "single_site":
            id_list += MERGED_SITE_TASK_ID_DICT[site]

    for filename in os.listdir(root_dir):
        if filename.endswith(".json"):
            try:
                trajectory_obj = json.load(open(os.path.join(root_dir, filename), "r"))
                if trajectory_obj["id"] in id_list:
                    if (evaluator=="before" and trajectory_obj["id"] not in EVELUATOR_RECTIFICATIONS) or evaluator=="after":
                        # Two on-disk formats: a "trajectory" step list whose
                        # last step carries reward/success, or a flat "score".
                        if "trajectory" in trajectory_obj.keys():
                            last_step = trajectory_obj["trajectory"][-1]
                            reward = float(last_step['reward']) if "reward" in last_step.keys() else last_step['success']
                        else:
                            reward = trajectory_obj["score"]
                        total_tasks += 1
                        total_reward += reward
                        net_total_reward += 1 if reward == 1. else 0
            except Exception as e:
                # Malformed files are reported and skipped; aggregation continues.
                print(os.path.join(root_dir, filename))
                print(e)

    if total_tasks > 0:
        return total_reward, net_total_reward, total_tasks
    else:
        return 0.0, 0.0, 0.0
|
||||
|
||||
def find_summary_csv_files(directories):
    """Recursively collect the path of every 'summary.csv' under the given directories."""
    return [
        os.path.join(walk_root, entry)
        for directory in directories
        for walk_root, _, entries in os.walk(directory)
        for entry in entries
        if entry == 'summary.csv'
    ]
|
||||
|
||||
def read_rewards_with_dir_names(summary_files):
    """Map each summary file's parent-directory name to its 'reward' column values.

    Files lacking a 'reward' column are skipped; when two files share a
    parent-directory name, the later one wins.
    """
    def run_name(path):
        return os.path.basename(os.path.dirname(path))
    loaded = ((run_name(path), pd.read_csv(path)) for path in summary_files)
    return {name: frame['reward'].tolist() for name, frame in loaded if 'reward' in frame.columns}
|
||||
|
||||
def write_rewards_to_csv(rewards, output_file):
    """Write *rewards* to *output_file* as a one-column CSV with header 'reward'."""
    lines = ['reward\n']
    lines.extend(f'{reward}\n' for reward in rewards)
    with open(output_file, 'w') as f:
        f.writelines(lines)
|
||||
|
||||
def load_reward(root_dir, evaluator="after"):
    """Load per-task rewards from trajectory JSONs into dense, index-aligned lists.

    Args:
        root_dir: flat directory of ``*.json`` trajectory files.
        evaluator: "after" keeps every task; "before" skips ids listed in
            ``EVELUATOR_RECTIFICATIONS``.

    Returns:
        ``(reward_list, net_reward_list)`` - both of length 812 (presumably
        the total benchmark task count; confirm against the task set), indexed
        by task id. Missing tasks score 0 and their ids are printed.
    """
    reward_dict = {}
    net_reward_dict = {}
    for filename in os.listdir(root_dir):
        if filename.endswith(".json"):
            trajectory_obj = json.load(open(os.path.join(root_dir, filename), "r"))
            trajectory_id = trajectory_obj["id"]
            if (evaluator=="before" and trajectory_obj["id"] not in EVELUATOR_RECTIFICATIONS) or evaluator=="after":
                # Two on-disk formats: a step list with reward/success on the
                # last step, or a flat "score" field.
                if "trajectory" in trajectory_obj.keys():
                    last_step = trajectory_obj["trajectory"][-1]
                    reward_dict[trajectory_id] = float(last_step['reward']) if "reward" in last_step.keys() else last_step['success']
                else:
                    reward_dict[trajectory_id] = float(trajectory_obj["score"])
                # Strict success: only an exact 1.0 reward counts as a win.
                net_reward_dict[trajectory_id] = 1. if reward_dict[trajectory_id] == 1. else 0.
    reward_list = []
    net_reward_list = []
    # Log which task ids are missing from this run directory.
    print("\n"+root_dir)
    for i in range(812):
        if i in reward_dict.keys():
            reward_list.append(reward_dict[i])
        else:
            print(f"{i},", end="")
            # reward_list.append(-1)
            reward_list.append(0)
        if i in net_reward_dict.keys():
            net_reward_list.append(net_reward_dict[i])
        else:
            # net_reward_list.append(-1)
            net_reward_list.append(0)
    return reward_list, net_reward_list
|
||||
|
||||
def compare_rewards(trajectory_key_list=None, evaluator="after"):
    """Write ``compare.csv`` holding per-task net rewards of several runs, grouped by site.

    Args:
        trajectory_key_list: keys into ``TRAJECTORY_DIR_DICT``/``RUN_NAME_DICT``
            selecting the runs to compare. NOTE(review): the default ``None``
            would raise on iteration - callers are expected to pass a list.
        evaluator: forwarded to ``load_reward`` ("before"/"after").

    Fix: the output CSV file handle was opened but never closed, so the tail
    of ``compare.csv`` could be lost to unflushed buffers; it is now managed
    by a ``with`` block.
    """
    import pandas as pd
    import matplotlib.pyplot as plt

    basenames = [RUN_NAME_DICT[k] for k in trajectory_key_list]

    # 812 tasks, labeled by site; net (strict success) rewards per run.
    tasks = list(range(812))
    labels = TASK_LABELS_MULTISITE
    rewards = [load_reward(TRAJECTORY_DIR_DICT[k], evaluator=evaluator)[1] for k in trajectory_key_list]

    # Stable-group task indices by site label (first-appearance order).
    label_list = []
    label_index_dict = {}
    for i, label in enumerate(labels):
        if label not in label_list:
            label_list.append(label)
            label_index_dict[label] = []
        label_index_dict[label].append(i)
    sorted_index_list = []
    for label in label_list:
        sorted_index_list += label_index_dict[label]
    tasks = [tasks[i] for i in sorted_index_list]
    labels = [labels[i] for i in sorted_index_list]
    for i in range(len(rewards)):
        rewards[i] = [int(rewards[i][j]) for j in sorted_index_list]

    data = {
        'Task': tasks,
        'Site': labels,
        **{basename: reward for basename, reward in zip(basenames, rewards)}
    }

    df = pd.DataFrame(data)

    # One row per task: task id, site, then one 0/1 column per run.
    with open(os.path.join(OUTPUT_DIR, "compare.csv"), "w") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(["task", "site"]+basenames)
        for i, reward in enumerate(zip(*tuple(rewards))):
            csv_writer.writerow([df['Task'][i], df['Site'][i]]+list(reward))
|
||||
|
||||
def plot_comparative_heatmap():
    """Render one per-site success/failure heatmap PNG from ``compare.csv``.

    Reads the CSV written by ``compare_rewards`` (columns: task, site, one 0/1
    column per run) and saves ``figures/{site}_{run_count}.png`` under
    ``OUTPUT_DIR`` for each known site.
    """
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt

    file_path = os.path.join(OUTPUT_DIR, 'compare.csv')
    data = pd.read_csv(file_path)

    for site in ["shopping_admin", "shopping", "reddit", "gitlab", "map", "multisite"]:
        site_data = data[data['site'] == site]
        # Every non-key column is a run/approach result column.
        approach_keys = [k for k in site_data.keys() if k not in ["task", "site"]]

        heatmap_data = pd.DataFrame({
            k: site_data[k] for k in approach_keys
        })

        heatmap_values = heatmap_data.values

        # Two-color map: light grey = failure (0), green = success (1).
        colors = ['#EFEFEF', '#2A786C']
        cmap = mcolors.LinearSegmentedColormap.from_list("CustomCmap", colors)
        plt.figure(figsize=(10, 20))
        plt.imshow(heatmap_values, cmap=cmap, aspect='auto')

        # NOTE(review): `labels=[]*len(...)` is always the empty list, so the
        # x axis is intentionally unlabeled - confirm this is the intent.
        plt.xticks(ticks=[0.5 + k for k in list(range(len(approach_keys)))], labels=[]*len(approach_keys))
        plt.yticks([])

        ax = plt.gca()

        ax.set_yticks([])

        # Twin axes annotate task ids on both sides (reversed row order).
        ax_left = plt.gca().twinx()
        ax_left.set_yticks(np.arange(len(site_data))+1)
        ax_left.set_yticklabels(site_data.iloc[::-1]["task"], fontsize=3)

        ax_right = plt.gca().twinx()
        ax_right.set_yticks(np.arange(len(site_data))+1)
        ax_right.set_yticklabels(site_data.iloc[::-1]["task"], fontsize=3)
        ax_right.yaxis.set_label_position("right")

        plt.grid(color='white', linestyle='-', linewidth=5)

        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, f"figures/{site}_{len(approach_keys)}.png"), dpi=256)
|
||||
|
||||
def plot_comparative_bar_chart(categories, data_list, labels, colors, title="Comparative Bar Chart", ylabel="Values", figure_name="bar"):
    """Draw a grouped bar chart and save it as ``figures/{figure_name}.pdf``.

    Args:
        categories: x-axis category names (underscores become spaces).
        data_list: one numeric sequence per series, aligned with *categories*.
        labels: legend label per series.
        colors: bar color per series.
        title: chart title; an empty string suppresses it.
        ylabel: y-axis label.
        figure_name: output file stem under ``OUTPUT_DIR/figures``.
    """
    os.makedirs(os.path.join(OUTPUT_DIR, "figures"), exist_ok=True)

    # Leave one bar-width gap between category groups.
    bar_width = 1/(len(labels)+1)
    x = np.arange(len(categories))

    plt.rc('font', family='serif')
    plt.figure(figsize=(9, 2))

    for i, (data, label, color) in enumerate(zip(data_list, labels, colors)):
        plt.bar(x + i * bar_width, data, width=bar_width, label=label, color=color)

    # Annotate each bar with its value (one decimal for floats).
    for i, (data, label) in enumerate(zip(data_list, labels)):
        for j, value in enumerate(data):
            plt.text(x[j] + i * bar_width, value, f"{value:.1f}" if isinstance(value, float) else f"{value}", ha='center', va='bottom', fontsize=5)

    if title:
        plt.title(title)
    plt.ylabel(ylabel, fontsize=11)
    # Center category tick labels under each bar group.
    plt.xticks(x + bar_width * (len(labels) - 1) / 2, [c.replace("_", " ").capitalize() for c in categories], fontsize=11)
    plt.legend(loc='lower center', fontsize=11, bbox_to_anchor=(0.5, 1.05), ncol=3)
    plt.grid(axis='y')

    # Fixed y range - values are percentages expected to stay below 65.
    plt.ylim(0, 65)

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f"figures/{figure_name}.pdf"), dpi=256)
    plt.close()
|
||||
|
||||
def compute_success_rate(trajectory_key_list=None, evaluator="after"):
    """Compute per-site success rates for several runs; write ``result.csv`` and two bar charts.

    Args:
        trajectory_key_list: keys into ``TRAJECTORY_DIR_DICT``/``RUN_NAME_DICT``
            selecting runs; defaults to all known runs, sorted.
        evaluator: forwarded to ``parse_json_files`` ("before"/"after").
    """
    site_lists = ["ALL", "SHOPPING", "SHOPPING_ADMIN", "GITLAB", "MAP", "REDDIT", "MULTISITE"]
    csvfile = open(os.path.join(OUTPUT_DIR, "result.csv"), "w")
    csv_writer = csv.writer(csvfile)
    csv_writer.writerow(["basename", "site", "total_reward", "net_total_reward", "total_tasks"])

    categories = site_lists

    trajectory_key_list = trajectory_key_list if trajectory_key_list else [k for k in sorted(list(TRAJECTORY_DIR_DICT.keys()), reverse=False)]
    labels = [RUN_NAME_DICT[i] for i in trajectory_key_list]

    colors = get_colors(trajectory_key_list)

    reward_percentage_list = {l:[] for l in labels}
    net_reward_percentage_list = {l:[] for l in labels}

    for i, key in enumerate(trajectory_key_list):
        root_directory = TRAJECTORY_DIR_DICT[key]
        basename = labels[i]
        for site_list in site_lists:
            total_reward, net_total_reward, total_tasks = parse_json_files(root_directory, [site_list], evaluator=evaluator, mode="multiple_site")
            # Denominator is the site's full task count, not just the tasks
            # found on disk, so missing trajectories count as failures.
            total_tasks = TOTAL_TASK_NUM_DICT[site_list]
            reward_percentage_list[basename].append(total_reward/total_tasks*100)
            net_reward_percentage_list[basename].append(net_total_reward/total_tasks*100)
            csv_writer.writerow([basename, site_list, total_reward, net_total_reward, total_tasks])
    csvfile.close()
    plot_comparative_bar_chart(categories=categories, data_list=[reward_percentage_list[l] for l in labels], labels=labels, colors=colors, title="Reward Percentage", figure_name="reward_percentage")
    plot_comparative_bar_chart(categories=categories, data_list=[net_reward_percentage_list[l] for l in labels], labels=labels, colors=colors, title="", ylabel="Success Rate", figure_name="net_reward_percentage")
|
||||
|
||||
if __name__ == "__main__":
    # Run keys into TRAJECTORY_DIR_DICT selecting which runs to compare
    # (presumably the ablation-study configurations; confirm against utils).
    ablation_study_key_list = [7, 3, 4, 5, 6, 0]
    compute_success_rate(ablation_study_key_list)
|
92
AgentOccam/prompts/AgentOccam_prompt.py
Normal file
92
AgentOccam/prompts/AgentOccam_prompt.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
# Prompt templates for the actor LLM. Top-level keys:
#   "instruction_template" - system prompts with/without planning commands;
#   "input_template"       - pass-through wrapper for the composed input;
#   "QA"/"planning"/"reflection" - specialized sub-agent prompt pairs.
# Placeholders like {output_specifications} are filled from the prompt
# specification text files at runtime.
# NOTE(review): the QA template contains typos ("correcly", "fully explain")
# that are runtime prompt text - fixing them would change agent behavior.
actor = {
    "instruction_template": {
        "with_planning": '''You are an AI assistant performing tasks on a web browser. You will be provided with task objective, current step, web page observations, previous plans, and interaction history. You need to issue an action for this step.

Generate the response in the following format:
{output_specifications}

You are ONLY allowed to use the following action commands. Strictly adheres to the given format. Only issue one single action.
If you think you should refine the plan, use the following actions:
{planning_specifications}
Otherwise, use the following actions:
{navigation_specifications}''',

        "without_planning": '''You are an AI assistant performing tasks on a web browser. You will be provided with task objective, current step, web page observations, and other relevant information. You need to issue an action for this step.

Generate the response in the following format:
{output_specifications}

You are ONLY allowed to use the following action commands. Strictly adheres to the given format. Only issue one single action.
{navigation_specifications}'''
    },

    "input_template":'''{input}''',

    "QA": {
        "instruction_template": '''You are a proficient assistant good at answering web page related questions. Given the web page textual description, you are required to answer the question.

Generate the response in the following format:
RESPONSE:
Your response here.

Adhere to the following response requirements:
* If you are not fully sure that you can answer the question correcly with the information given, only take note of crucial relevant information.
* Otherwise, if you are confident about the answer, return your full answer. Ensure that your response is correct and comprehensive that fully explain your conclusion.''',
        "input_template": '''WEB PAGE CONTENT:
{current_observation}

QUESTION:
{objective}'''
    },

    "planning": {
        "instruction_template": '''You are an AI assistant performing tasks on a web browser. You will be provided with task objective, current step, url, web page observations, previous plans, and actions. You need to issue a plan for this step.

Generate the response in the following format:
{output_specifications}

You are ONLY allowed to use the following planning commands. Strictly adheres to the given format. Only issue one single planning command.
{planning_specifications}''',
        "input_template": ''''''
    },

    "reflection": {
        "instruction_template": '''You are an AI assistant performing tasks on a web browser. You will be provided with task objective, current step, url, web page observations, previous plans, and actions. You need to reflect on past mistakes, take corrective action, and maximize future rewards.

Generate the response in the following format:
{output_specifications}

You are ONLY allowed to use the following action commands. Strictly adheres to the given format. Only issue one single action.
If you think you should refine the plan, use the following actions:
{planning_specifications}
Otherwise, use the following actions:
{navigation_specifications}''',
        "input_template": ''''''
    },
}
|
||||
# Prompt templates for the critic LLM, in two tones:
#   "harsh"  - demands skeptical failure analysis of an unsuccessful run;
#   "normal" - neutral performance assessment.
critic = {

    "harsh": {"instruction_template": '''Below are the objective (high-level goal) and corresponding web observations and actions I took to navigate the web and achieve the goal, which has proven to be **unsuccessful**. As the objective is fully achievable within the current environment, I am expecting skeptical feedback on why I failed based on my interaction history and the current state.

Adhere to the following output format:
{output_specifications}''',

    "input_template": '''The following is all my interaction history and current state:
{input}'''},

    "normal": {
        "instruction_template": '''You are a seasoned web navigator. You now assess the performance of another web navigation agent based on the objective, their previous interaction history and the web's current state.\nAdhere to the following output format:\n{output_specifications}''',
        "input_template": '''The following is all my interaction history and current state:\n{input}''',
    }

}
|
||||
# Prompt templates for the judge LLM that scores candidate actions and picks
# the best one.
# NOTE(review): "serveral" is a typo in the runtime prompt text; left as-is
# because changing it alters the prompt sent to the model.
judge = {
    "instruction_template": '''You are a seasoned web navigator. You now assess the value and risk of serveral web navigation actions based on the objective, the previous interaction history and the web's current state. Then, you select the action with the most value and least risk with which you would earn the maximum objective fulfillment reward in the future.

Adhere to the following output format:
{output_specifications}

Note that `branch` and `prune` are planning actions that will modify the PREVIOUS PLAN section and won't interact with the web environment.''',
    "input_template": '''The following is the interaction history, current state, and action choices.\n{input}'''
}
|
1
AgentOccam/prompts/navigation_specifications/click.txt
Normal file
1
AgentOccam/prompts/navigation_specifications/click.txt
Normal file
|
@ -0,0 +1 @@
|
|||
click [id]: To click on an element with its numerical ID on the webpage. E.g., `click [7]` If clicking on a specific element doesn't trigger the transition to your desired web state, this is due to the element's lack of interactivity or GUI visibility. In such cases, move on to interact with OTHER similar or relevant elements INSTEAD.
|
1
AgentOccam/prompts/navigation_specifications/go_back.txt
Normal file
1
AgentOccam/prompts/navigation_specifications/go_back.txt
Normal file
|
@ -0,0 +1 @@
|
|||
go_back: To return to the previously viewed page.
|
1
AgentOccam/prompts/navigation_specifications/go_home.txt
Normal file
1
AgentOccam/prompts/navigation_specifications/go_home.txt
Normal file
|
@ -0,0 +1 @@
|
|||
go_home: To return to the homepage where you can find other websites.
|
1
AgentOccam/prompts/navigation_specifications/note.txt
Normal file
1
AgentOccam/prompts/navigation_specifications/note.txt
Normal file
|
@ -0,0 +1 @@
|
|||
note [content]: To take note of all important info w.r.t. completing the task to enable reviewing it later. E.g., `note [Spent $10 on 4/1/2024]`
|
1
AgentOccam/prompts/navigation_specifications/scroll.txt
Normal file
1
AgentOccam/prompts/navigation_specifications/scroll.txt
Normal file
|
@ -0,0 +1 @@
|
|||
scroll [down/up] [reason]: To navigate the webpage content. E.g., `scroll [up] [Previous observations contain a link that might be useful.]`
|
1
AgentOccam/prompts/navigation_specifications/stop.txt
Normal file
1
AgentOccam/prompts/navigation_specifications/stop.txt
Normal file
|
@ -0,0 +1 @@
|
|||
stop [answer]: To stop interaction and return response. Present your answer within the brackets. If the task doesn't require a textual answer or appears insurmountable, indicate "N/A" and additional reasons and all relevant information you gather as the answer. E.g., `stop [5h 47min]`
|
1
AgentOccam/prompts/navigation_specifications/type.txt
Normal file
1
AgentOccam/prompts/navigation_specifications/type.txt
Normal file
|
@ -0,0 +1 @@
|
|||
type [id] [content] [press_enter_after=0|1]: To type content into a field with a specific ID. By default, the "Enter" key is pressed after typing unless `press_enter_after` is set to 0. E.g., `type [15] [Carnegie Mellon University] [1]` If you can't find what you're looking for on your first attempt, consider refining your search keywords by breaking them down or trying related terms.
|
1
AgentOccam/prompts/output_specifications/action.txt
Normal file
1
AgentOccam/prompts/output_specifications/action.txt
Normal file
|
@ -0,0 +1 @@
|
|||
Select your action here.
|
|
@ -0,0 +1 @@
|
|||
Assess the value and risk of each action. Consider both the best-case and worst-case outcomes resulting from its implementation. Itemize the assessment using this format: `- action [action_id]: [action value, including but not limited to what outcomes you can expect by executing the action, or whether the note is of the most correct and comprehensive content] [action risk, including but not limited to whether the note/stop content is correct, and whether you can gather more information by continuing playing rather than ending the trial] [{best_case}] [{worst_case}]`.
|
|
@ -0,0 +1 @@
|
|||
Propose ALL potential actions at this step. Itemize the actions using this format: `- reason: [{reason_for_proposing_the_following_action0}]\n- action: [{action0_command}]\n\n- reason: [{reason_for_proposing_the_following_action1}]\n- action: [{action1_command}]\n\n...`.
|
|
@ -0,0 +1 @@
|
|||
List the numerical id of your selected action here. You can only choose one action. E.g., `1`.
|
|
@ -0,0 +1 @@
|
|||
Emphasize all important details in the INTERACTION HISTORY section.
|
26
AgentOccam/prompts/output_specifications/mistakes.txt
Normal file
26
AgentOccam/prompts/output_specifications/mistakes.txt
Normal file
|
@ -0,0 +1,26 @@
|
|||
Point out the major mistakes of previous steps by ONLY using the following templates:
|
||||
- You have made a reasoning mistake by "{quote}". The correct reasoning should be "{correction}".
|
||||
- You should check the "{link_name}" link first.
|
||||
- You should know that the recent order table doesn't include all previous orders. Don't rush to a conclusion.
|
||||
- You have missed important details on this page: {details}.
|
||||
- I don't think your answer follows the task requirements. That's a fault I wouldn't expect. Reconsider seriously.
|
||||
- You have employed different approaches/the same approach many times to do the task but failed. The task assigner might just want to challenge you to answer no and there might be no answer for this brain teaser question.
|
||||
- If the task asks for the most extreme case (e.g., with the highest price), I suggest you sort them by that key first.
|
||||
- If there are multiple requirements for an item, break down the requirements and search them one by one.
|
||||
- The active plan is a complex task. Don't rush. Further break down the task by using the planning commands.
|
||||
- There might be multiple relevant orders to check before reach the conclusion. First, view ALL previous orders to finalize the order checklist and take notes of orders to be checked with `note [note_content]` command while viewing. Second, view the order details one by one and take notes of all crucial information. Finally, view all notes and think step by step before concluding the answer.
|
||||
- You have reasoned too much in one step which leads to errors. Break down the task with planning.
|
||||
- You should change the "selected" state of the items in the combobox.
|
||||
- From my observation and consideration, I suggest you conclude the task as there's no answer even though you have tried multiple times with different approaches.
|
||||
- When the task mentions "category", it implies you can navigate to that category by selecting menus step by step. Select the most relevant first and the subcategories would appear. Select the appropriate subcategory then.
|
||||
- You have not gone over all the reviews, {review_page_num} pages in total.
|
||||
- You have not gone over all the items, {item_page_num} pages in total.
|
||||
- Don't take the same notes multiple times.
|
||||
- You should select and click the radio (required field) first.
|
||||
- You should go over all relevant items and take notes of all crucial information with `note [note_content]`. Then finalize your choice by carefully consider based on your notes.
|
||||
- Don't submit yet. Just show the form completion page. Retry.
|
||||
- You missed a required field before submission, which leads to the failure of your last attempt. Retry.
|
||||
- Canceled Orders and pending orders are not fulfilled orders.
|
||||
- There are {order_num} relevant orders on this page, which is/are {order_ids}. You have viewed {order_ids} and taken notes, and {order_ids} still requires reviewing and taking notes.
|
||||
- You have gone over all review/item/order pages.
|
||||
- Except when keywords "category", "subcategories", etc are specifically mentioned in the objective, the fastest way to find items is to use the `search` feature.
|
|
@ -0,0 +1 @@
|
|||
Describe information in the CURRENT OBSERVATION section. Emphasize elements and features that are relevant or potentially helpful for fulfilling the objective in detail.
|
|
@ -0,0 +1 @@
|
|||
List the numerical ids of elements on the current webpage based on which you would issue your action. Also include elements on the current webpage you would attend to if you fail in the future and have to restore to this step. Don't include elements from the previous pages. Select elements at a higher hierarchical level if most of their children nodes are considered crucial. Sort by relevance and potential values from high to low, and separate the ids with commas. E.g., `1321, 52, 756, 838`.
|
|
@ -0,0 +1 @@
|
|||
Review critically why the plans have not been fulfilled or the objective achieved. Justify your assessment with detailed evidence drawn from the objective, observations, and actions taken. Itemize the assessment using this format: `- plan [{plan_id}]\n\t[{step_ids_taken_for_this_milestone}] [{concrete_proof_from_observation}] [{why_milestone_a_not_successful}]\n\t[{step_ids_taken_for_this_milestone}] [{concrete_proof_from_observation}] [{why_milestone_b_not_successful}]\n\t...`.
|
1
AgentOccam/prompts/output_specifications/reason.txt
Normal file
1
AgentOccam/prompts/output_specifications/reason.txt
Normal file
|
@ -0,0 +1 @@
|
|||
Provide your rationale for proposing the subsequent action commands here.
|
1
AgentOccam/prompts/planning_specifications/branch.txt
Normal file
1
AgentOccam/prompts/planning_specifications/branch.txt
Normal file
|
@ -0,0 +1 @@
|
|||
branch [parent_plan_id] [new_subplan_intent]: To create a new subplan based on PREVIOUS PLANS. Ensure the new subplan is connected to the appropriate parent plan by using its ID. E.g., `branch [12] [Navigate to the "Issue" page to check all the issues.]`
|
1
AgentOccam/prompts/planning_specifications/prune.txt
Normal file
1
AgentOccam/prompts/planning_specifications/prune.txt
Normal file
|
@ -0,0 +1 @@
|
|||
prune [resume_plan_id] [reason]: To return to a previous plan state when the current plan is deemed impractical. Enter the ID of the plan state you want to resume. E.g., `prune [5] [The current page lacks items "black speaker," prompting a return to the initial page to restart the item search.]`
|
401
AgentOccam/utils.py
Normal file
401
AgentOccam/utils.py
Normal file
File diff suppressed because one or more lines are too long
26
Agent_E/ae/config.py
Normal file
26
Agent_E/ae/config.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# config.py at the project source code root
import os

# Root of the package source tree (directory containing this file).
PROJECT_SOURCE_ROOT = os.path.dirname(os.path.abspath(__file__))
SOURCE_LOG_FOLDER_PATH = os.path.join(PROJECT_SOURCE_ROOT, 'log_files')

# Repository root is one level above the source root.
PROJECT_ROOT = os.path.dirname(PROJECT_SOURCE_ROOT)

PROJECT_TEMP_PATH = os.path.join(PROJECT_ROOT, 'temp')

USER_PREFERENCES_PATH = os.path.join(PROJECT_SOURCE_ROOT, 'user_preferences')
PROJECT_TEST_ROOT = os.path.join(PROJECT_ROOT, 'test')

# Create each required folder on import if it does not already exist,
# announcing any folder that had to be created.
for _folder_path, _folder_label in (
    (SOURCE_LOG_FOLDER_PATH, "log folder"),
    (USER_PREFERENCES_PATH, "user preferences folder"),
    (PROJECT_TEMP_PATH, "temp folder"),
):
    if not os.path.exists(_folder_path):
        os.makedirs(_folder_path)
        print(f"Created {_folder_label} at: {_folder_path}")
|
9
Agent_E/ae/core/__init__.py
Normal file
9
Agent_E/ae/core/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
from Agent_E.ae.core import agents
|
||||
from Agent_E.ae.core import memory
|
||||
from Agent_E.ae.core import skills
|
||||
from Agent_E.ae.core.autogen_wrapper import AutogenWrapper
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.core.post_process_responses import final_reply_callback_user_proxy
|
||||
from Agent_E.ae.core.prompts import LLM_PROMPTS
|
||||
from Agent_E.ae.core.system_orchestrator import SystemOrchestrator
|
||||
from Agent_E.ae.core.ui_manager import UIManager
|
1
Agent_E/ae/core/agents/__init__.py
Normal file
1
Agent_E/ae/core/agents/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from Agent_E.ae.core.agents.browser_nav_agent import BrowserNavAgent
|
164
Agent_E/ae/core/agents/browser_nav_agent.py
Normal file
164
Agent_E/ae/core/agents/browser_nav_agent.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
import importlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
from string import Template
|
||||
from typing import Any
|
||||
|
||||
import autogen # type: ignore
|
||||
|
||||
from Agent_E.ae.core.memory.static_ltm import get_user_ltm
|
||||
from Agent_E.ae.core.prompts import LLM_PROMPTS
|
||||
from Agent_E.ae.core.skills.click_using_selector import click as click_element
|
||||
|
||||
# from Agent_E.ae.core.skills.enter_text_and_click import enter_text_and_click
|
||||
from Agent_E.ae.core.skills.enter_text_using_selector import bulk_enter_text
|
||||
from Agent_E.ae.core.skills.enter_text_using_selector import entertext
|
||||
from Agent_E.ae.core.skills.get_dom_with_content_type import get_dom_with_content_type
|
||||
from Agent_E.ae.core.skills.get_url import geturl
|
||||
from Agent_E.ae.core.skills.open_url import openurl
|
||||
from Agent_E.ae.core.skills.pdf_text_extractor import extract_text_from_pdf
|
||||
|
||||
#from Agent_E.ae.core.skills.pdf_text_extractor import extract_text_from_pdf
|
||||
from Agent_E.ae.core.skills.press_key_combination import press_key_combination
|
||||
from Agent_E.ae.core.skills.skill_registry import skill_registry
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
class BrowserNavAgent:
    """Autogen conversable agent that navigates the browser by proposing registered skills.

    The agent itself only *proposes* skill calls; the paired ``browser_nav_executor``
    (a UserProxyAgent) actually executes them.
    """

    def __init__(self, model_config_list, llm_config_params: dict[str, Any], system_prompt: str | None, browser_nav_executor: autogen.UserProxyAgent):  # type: ignore
        """
        Initialize the BrowserNavAgent and store the ConversableAgent instance
        as an instance attribute for external access.

        Parameters:
        - model_config_list: A list of configuration parameters for the agent's LLM.
        - llm_config_params: A dictionary of configuration parameters for the LLM.
        - system_prompt: The system prompt (str, or list of str joined with newlines)
          to be used for this agent; the built-in BROWSER_AGENT_PROMPT is used if not provided.
        - browser_nav_executor: The UserProxyAgent instance that executes the registered skills.
        """
        self.browser_nav_executor = browser_nav_executor
        user_ltm = self.__get_ltm()

        system_message = LLM_PROMPTS["BROWSER_AGENT_PROMPT"]
        if system_prompt and len(system_prompt) > 0:
            # A list of prompt fragments is joined into a single message.
            if isinstance(system_prompt, list):
                system_message = "\n".join(system_prompt)
            else:
                system_message = system_prompt
            logger.info(f"Using custom system prompt for BrowserNavAgent: {system_message}")

        system_message = system_message + "\n" + f"Today's date is {datetime.now().strftime('%d %B %Y')}"
        if user_ltm:  # add the user LTM to the system prompt if it exists
            user_ltm = "\n" + user_ltm
            system_message = Template(system_message).substitute(basic_user_information=user_ltm)
        logger.info(f"Browser nav agent using model: {model_config_list[0]['model']}")
        self.agent = autogen.ConversableAgent(
            name="browser_navigation_agent",
            system_message=system_message,
            llm_config={
                "config_list": model_config_list,
                **llm_config_params  # unpack all the name value pairs in llm_config_params as is
            },
        )
        self.__register_skills()

    def __get_ltm(self):
        """Return the long term memory of the user, or None if not found."""
        return get_user_ltm()

    def __register_skills(self):
        """
        Register all the built-in skills that the agent can perform.

        Each skill is registered twice: with the LLM agent (so the model can
        propose the call) and with the executor (so the call can be run).
        """
        # (description, callable) pairs; order matches the original registration order.
        builtin_skills = [
            (LLM_PROMPTS["OPEN_URL_PROMPT"], openurl),
            (LLM_PROMPTS["GET_DOM_WITH_CONTENT_TYPE_PROMPT"], get_dom_with_content_type),
            (LLM_PROMPTS["CLICK_PROMPT"], click_element),
            (LLM_PROMPTS["GET_URL_PROMPT"], geturl),
            (LLM_PROMPTS["BULK_ENTER_TEXT_PROMPT"], bulk_enter_text),
            (LLM_PROMPTS["ENTER_TEXT_PROMPT"], entertext),
            (LLM_PROMPTS["PRESS_KEY_COMBINATION_PROMPT"], press_key_combination),
            (LLM_PROMPTS["EXTRACT_TEXT_FROM_PDF_PROMPT"], extract_text_from_pdf),
        ]
        for description, skill_func in builtin_skills:
            self.agent.register_for_llm(description=description)(skill_func)
            self.browser_nav_executor.register_for_execution()(skill_func)

        self.__load_additional_skills()

    def __load_additional_skills(self):
        """
        Dynamically load additional skills from directories or specific Python files
        listed (comma-separated) in the ADDITIONAL_SKILL_DIRS environment variable,
        then register everything the loaded modules added to the global skill_registry.
        """
        additional_skill_dirs: str = os.getenv('ADDITIONAL_SKILL_DIRS', "")
        if len(additional_skill_dirs) == 0:
            logger.debug("No additional skill directories or files specified.")
            return

        for skill_path in additional_skill_dirs.split(','):
            skill_path = skill_path.strip()  # tolerate whitespace around commas

            if os.path.isdir(skill_path):
                # A directory: import every .py file in it.
                for filename in os.listdir(skill_path):
                    if filename.endswith(".py"):
                        module_name = filename[:-3]  # strip .py extension
                        # NOTE(review): assumes skill_path is a package path relative
                        # to the working directory (slashes map to dots) — confirm.
                        module_path = f"{skill_path.replace('/', '.')}.{module_name}"
                        importlib.import_module(module_path)

            elif skill_path.endswith(".py") and os.path.isfile(skill_path):
                # A specific .py file: import it directly.
                module_name = os.path.basename(skill_path)[:-3]  # strip .py extension
                directory_path = os.path.dirname(skill_path).replace('/', '.')
                importlib.import_module(f"{directory_path}.{module_name}")

            else:
                logger.warning(f"Invalid skill path specified: {skill_path}")

        # Register the skills that were dynamically discovered with both agents.
        for skill in skill_registry:
            self.agent.register_for_llm(description=skill['description'])(skill['func'])
            self.browser_nav_executor.register_for_execution()(skill['func'])
            logger.debug(f"Registered additional skill: {skill['name']}")
|
||||
|
77
Agent_E/ae/core/agents/high_level_planner_agent.py
Normal file
77
Agent_E/ae/core/agents/high_level_planner_agent.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import os
|
||||
from datetime import datetime
|
||||
from string import Template
|
||||
from typing import Any
|
||||
|
||||
import autogen # type: ignore
|
||||
from autogen import ConversableAgent # type: ignore
|
||||
|
||||
from Agent_E.ae.core.memory.static_ltm import get_user_ltm
|
||||
from Agent_E.ae.core.post_process_responses import final_reply_callback_planner_agent as print_message_as_planner # type: ignore
|
||||
from Agent_E.ae.core.prompts import LLM_PROMPTS
|
||||
from Agent_E.ae.core.skills.get_user_input import get_user_input
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
class PlannerAgent:
    """Autogen assistant agent that produces the high-level plan and next steps."""

    def __init__(self, model_config_list, llm_config_params: dict[str, Any], system_prompt: str | None, user_proxy_agent: ConversableAgent):  # type: ignore
        """
        Initialize the PlannerAgent and store the AssistantAgent instance
        as an instance attribute for external access.

        Parameters:
        - model_config_list: A list of configuration parameters required for AssistantAgent.
        - llm_config_params: A dictionary of configuration parameters for the LLM.
        - system_prompt: The system prompt (str, or list of str joined with newlines)
          to be used for this agent; the built-in PLANNER_AGENT_PROMPT is used if not provided.
        - user_proxy_agent: An instance of the ConversableAgent class that executes skills.
        """
        # The optional user-input skill is gated by an environment flag.
        enable_user_input = os.getenv("PLANNER_USER_INPUT_SKILL_ENABLED", "false").lower() == "true"

        user_ltm = self.__get_ltm()
        system_message = LLM_PROMPTS["PLANNER_AGENT_PROMPT"]

        if system_prompt and len(system_prompt) > 0:
            # A list of prompt fragments is joined into a single message.
            if isinstance(system_prompt, list):
                system_message = "\n".join(system_prompt)
            else:
                system_message = system_prompt
            logger.info(f"Using custom system prompt for PlannerAgent: {system_message}")

        if user_ltm:  # add the user LTM to the system prompt if it exists
            user_ltm = "\n" + user_ltm
            system_message = Template(system_message).substitute(basic_user_information=user_ltm)
        system_message = system_message + "\n" + f"Today's date is {datetime.now().strftime('%d %B %Y')}"
        logger.info(f"Planner agent using model: {model_config_list[0]['model']}")

        self.agent = autogen.AssistantAgent(
            name="planner_agent",
            system_message=system_message,
            llm_config={
                "config_list": model_config_list,
                **llm_config_params  # unpack all the name value pairs in llm_config_params as is
            },
        )

        if enable_user_input:
            # Register get_user_input skill for proposal by the LLM agent...
            self.agent.register_for_llm(description=LLM_PROMPTS["GET_USER_INPUT_PROMPT"])(get_user_input)
            # ...and for execution by user_proxy_agent.
            user_proxy_agent.register_for_execution()(get_user_input)
        else:
            logger.debug("User input skill is disabled for PlannerAgent")

        # Echo planner messages through the post-processing callback.
        self.agent.register_reply(  # type: ignore
            [autogen.AssistantAgent, None],
            reply_func=print_message_as_planner,
            config={"callback": None},
            ignore_async_in_sync_chat=True
        )

    def __get_ltm(self):
        """Return the long term memory of the user, or None if not found."""
        return get_user_ltm()
|
||||
|
197
Agent_E/ae/core/agents_llm_config.py
Normal file
197
Agent_E/ae/core/agents_llm_config.py
Normal file
|
@ -0,0 +1,197 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
class AgentsLLMConfig:
|
||||
# Mapping from environment keys to model config keys
|
||||
KEY_MAPPING_ENV_MODEL: dict[str, str] = {
|
||||
"AUTOGEN_MODEL_NAME": "model",
|
||||
"AUTOGEN_MODEL_API_KEY": "api_key",
|
||||
"AUTOGEN_MODEL_BASE_URL": "base_url",
|
||||
"AUTOGEN_MODEL_API_TYPE": "api_type",
|
||||
"AUTOGEN_MODEL_API_VERSION": "api_version",
|
||||
}
|
||||
|
||||
# Mapping from environment keys to LLM config keys
|
||||
KEY_MAPPING_ENV_LLM: dict[str, str] = {
|
||||
"AUTOGEN_LLM_TEMPERATURE": "temperature",
|
||||
"AUTOGEN_LLM_TOP_P": "top_p",
|
||||
}
|
||||
|
||||
# Mapping from file keys to model config keys
|
||||
KEY_MAPPING_FILE: dict[str, str] = {
|
||||
"model_name": "model",
|
||||
"model_api_key": "api_key",
|
||||
"model_base_url": "base_url",
|
||||
"model_api_type": "api_type",
|
||||
}
|
||||
|
||||
def __init__(self, env_file_path: str = ".env", llm_config: dict[str,Any] | None = None) -> None:
|
||||
load_dotenv(env_file_path, verbose=True, override=True)
|
||||
if llm_config:
|
||||
self.config: dict[str, Any] = self.load_config_from_api(llm_config)
|
||||
else:
|
||||
self.config: dict[str, Any] = self._load_config()
|
||||
|
||||
|
||||
def _load_config(self) -> dict[str, Any]:
|
||||
config_file = os.getenv("AGENTS_LLM_CONFIG_FILE")
|
||||
config_file_ref_key = os.getenv("AGENTS_LLM_CONFIG_FILE_REF_KEY")
|
||||
|
||||
if config_file:
|
||||
try:
|
||||
with open(config_file, 'r') as file: # noqa: UP015
|
||||
file_config = json.load(file)
|
||||
|
||||
if config_file_ref_key:
|
||||
if config_file_ref_key in file_config:
|
||||
logger.info(f"Loading configuration from: {config_file} with key: {config_file_ref_key}")
|
||||
raw_config = file_config[config_file_ref_key]
|
||||
|
||||
# Process configurations for both planner_agent and browser_nav_agent
|
||||
planner_config = self._normalize_config(raw_config.get("planner_agent", {}))
|
||||
browser_nav_config = self._normalize_config(raw_config.get("browser_nav_agent", {}))
|
||||
|
||||
config = {
|
||||
"planner_agent": planner_config,
|
||||
"browser_nav_agent": browser_nav_config,
|
||||
"other_settings": {k: v for k, v in raw_config.items() if k not in ["planner_agent", "browser_nav_agent"]},
|
||||
}
|
||||
logger.info(f"Using configuration key '{config_file_ref_key}' from the config file.")
|
||||
else:
|
||||
logger.error(f"Key '{config_file_ref_key}' not found in the configuration file.")
|
||||
raise KeyError(f"Key '{config_file_ref_key}' not found in the configuration file.")
|
||||
else:
|
||||
logger.error("AGENTS_LLM_CONFIG_FILE_REF_KEY is not provided.")
|
||||
raise ValueError("AGENTS_LLM_CONFIG_FILE_REF_KEY must be provided if AGENTS_LLM_CONFIG_FILE is set.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading configuration file: {e}")
|
||||
raise e
|
||||
else:
|
||||
logger.info("Loading configuration from environment variables")
|
||||
# Load configurations from environment variables
|
||||
normalized_config = self._normalize_config_from_env()
|
||||
|
||||
config = {
|
||||
"planner_agent": normalized_config,
|
||||
"browser_nav_agent": normalized_config
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
def load_config_from_api(self, llm_config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Load configuration from a JSON provided during execution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_string : dict[str,Any]
|
||||
A JSON representing the configuration.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, Any]
|
||||
The loaded and normalized configuration.
|
||||
"""
|
||||
try:
|
||||
|
||||
logger.info("Loading LLM configuration provided via API.")
|
||||
|
||||
# Process configurations for both planner_agent and browser_nav_agent
|
||||
planner_config = self._normalize_config(llm_config.get("planner_agent", {}))
|
||||
browser_nav_config = self._normalize_config(llm_config.get("browser_nav_agent", {}))
|
||||
|
||||
config = {
|
||||
"planner_agent": planner_config,
|
||||
"browser_nav_agent": browser_nav_config,
|
||||
"other_settings": {k: v for k, v in llm_config.items() if k not in ["planner_agent", "browser_nav_agent"]},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Error decoding JSON string: {e}")
|
||||
raise e
|
||||
|
||||
def _normalize_config(self, agent_config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalize agent-specific config from a file, grouping keys into model_config_params, llm_config_params, and other_settings."""
|
||||
model_config = {}
|
||||
llm_config_params = {}
|
||||
other_settings = {}
|
||||
|
||||
for k, v in agent_config.items():
|
||||
if k in self.KEY_MAPPING_FILE:
|
||||
model_config[self.KEY_MAPPING_FILE[k]] = v
|
||||
elif k == "llm_config_params":
|
||||
llm_config_params = v # Keep llm_config_params as is
|
||||
else:
|
||||
other_settings[k] = v
|
||||
|
||||
return {
|
||||
"model_config_params": model_config,
|
||||
"llm_config_params": llm_config_params,
|
||||
"other_settings": other_settings,
|
||||
}
|
||||
|
||||
def _normalize_config_from_env(self) -> dict[str, Any]:
|
||||
"""Normalize config from environment variables, adding defaults for 'temperature', 'top_p', and 'seed' based on model name."""
|
||||
model_config = {}
|
||||
llm_config_params = {}
|
||||
other_settings = {}
|
||||
|
||||
# Populate model_config_params
|
||||
for original_key, mapped_key in self.KEY_MAPPING_ENV_MODEL.items():
|
||||
value = os.getenv(original_key)
|
||||
if value is not None:
|
||||
model_config[mapped_key] = value
|
||||
|
||||
# Populate llm_config_params
|
||||
for original_key, mapped_key in self.KEY_MAPPING_ENV_LLM.items():
|
||||
value = os.getenv(original_key)
|
||||
if value is not None:
|
||||
llm_config_params[mapped_key] = value
|
||||
|
||||
# Capture other settings that start with 'AUTOGEN_MODEL'
|
||||
for original_key in os.environ:
|
||||
if original_key.startswith("AUTOGEN_MODEL") and original_key not in self.KEY_MAPPING_ENV_MODEL:
|
||||
other_settings[original_key] = os.getenv(original_key)
|
||||
|
||||
# Apply defaults for 'temperature', 'top_p', 'seed' if not present
|
||||
model_name:str = model_config.get("model", "").lower() # type: ignore
|
||||
|
||||
if model_name.startswith("gpt"): # type: ignore
|
||||
llm_config_params.setdefault("temperature", 0.0) # type: ignore
|
||||
llm_config_params.setdefault("top_p", 0.001) # type: ignore
|
||||
llm_config_params.setdefault("seed", 12345) # type: ignore
|
||||
else:
|
||||
llm_config_params.setdefault("temperature", 0.1) # type: ignore
|
||||
llm_config_params.setdefault("top_p", 0.1) # type: ignore
|
||||
|
||||
return {
|
||||
"model_config_params": model_config,
|
||||
"llm_config_params": llm_config_params,
|
||||
"other_settings": other_settings,
|
||||
}
|
||||
|
||||
def get_planner_agent_config(self) -> dict[str, Any]:
|
||||
return self.config["planner_agent"]
|
||||
|
||||
def get_browser_nav_agent_config(self) -> dict[str, Any]:
|
||||
return self.config["browser_nav_agent"]
|
||||
|
||||
def get_full_config(self) -> dict[str, Any]:
|
||||
return self.config
|
||||
|
||||
# Example usage: load the configuration and fetch the per-agent sections.
if __name__ == "__main__":
    agents_config = AgentsLLMConfig()

    planner_section = agents_config.get_planner_agent_config()
    browser_nav_section = agents_config.get_browser_nav_agent_config()
|
384
Agent_E/ae/core/autogen_wrapper.py
Normal file
384
Agent_E/ae/core/autogen_wrapper.py
Normal file
|
@ -0,0 +1,384 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
from string import Template
|
||||
from time import time_ns
|
||||
from typing import Any
|
||||
|
||||
import autogen # type: ignore
|
||||
import nest_asyncio # type: ignore
|
||||
import openai
|
||||
|
||||
#from autogen import Cache
|
||||
from Agent_E.ae.config import SOURCE_LOG_FOLDER_PATH
|
||||
from Agent_E.ae.core.agents.browser_nav_agent import BrowserNavAgent
|
||||
from Agent_E.ae.core.agents.high_level_planner_agent import PlannerAgent
|
||||
from Agent_E.ae.core.post_process_responses import final_reply_callback_planner_agent as notify_planner_messages # type: ignore
|
||||
from Agent_E.ae.core.prompts import LLM_PROMPTS
|
||||
from Agent_E.ae.core.skills.get_url import geturl
|
||||
from Agent_E.ae.utils.autogen_sequential_function_call import UserProxyAgent_SequentialFunctionExecution
|
||||
from Agent_E.ae.utils.detect_llm_loops import is_agent_stuck_in_loop
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.response_parser import parse_response
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
nest_asyncio.apply() # type: ignore
|
||||
|
||||
class AutogenWrapper:
|
||||
"""
|
||||
A wrapper class for interacting with the Autogen library.
|
||||
|
||||
Args:
|
||||
planner_max_chat_round (int): The maximum number of chat rounds for the planner agent.
|
||||
browser_nav_max_chat_round (int): The maximum number of chat rounds for the browser navigation agent.
|
||||
|
||||
Attributes:
|
||||
number_of_rounds (int): The maximum number of chat rounds.
|
||||
agents_map (dict): A dictionary of the agents that are instantiated in this autogen instance.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, save_chat_logs_to_files: bool = True, planner_max_chat_round: int = 50, browser_nav_max_chat_round: int = 10):
|
||||
self.planner_number_of_rounds = planner_max_chat_round
|
||||
self.browser_number_of_rounds = browser_nav_max_chat_round
|
||||
|
||||
self.agents_map: dict[str, UserProxyAgent_SequentialFunctionExecution | autogen.AssistantAgent | autogen.ConversableAgent ] | None = None
|
||||
|
||||
self.planner_agent_model_config : list[dict[str, str]] | None = None
|
||||
self.browser_nav_agent_model_config : list[dict[str, str]] | None = None
|
||||
|
||||
self.planner_agent_config: dict[str, Any] | None = None
|
||||
self.browser_nav_agent_config: dict[str, Any] | None = None
|
||||
|
||||
self.chat_logs_dir: str = SOURCE_LOG_FOLDER_PATH
|
||||
self.save_chat_logs_to_files = save_chat_logs_to_files
|
||||
|
||||
@classmethod
|
||||
async def create(cls, planner_agent_config: dict[str, Any], browser_nav_agent_config: dict[str, Any], agents_needed: list[str] | None = None,
|
||||
save_chat_logs_to_files: bool = True, planner_max_chat_round: int = 50, browser_nav_max_chat_round: int = 10):
|
||||
"""
|
||||
Create an instance of AutogenWrapper.
|
||||
|
||||
Args:
|
||||
planner_agent_config: dict[str, Any]: A dictionary containing the configuration parameters for the planner agent. For example:
|
||||
{
|
||||
"model_name": "gpt-4o",
|
||||
"model_api_key": "",
|
||||
"model_base_url": null,
|
||||
"system_prompt": ["optional prompt unless you want to use the built in"],
|
||||
"llm_config_params": { #all name value pairs here will go to the llm config of autogen verbatim
|
||||
"cache_seed": null,
|
||||
"temperature": 0.001,
|
||||
"top_p": 0.001
|
||||
}
|
||||
}
|
||||
browser_nav_agent_config: dict[str, Any]: A dictionary containing the configuration parameters for the browser navigation agent. Same format as planner_agent_config.
|
||||
agents_needed (list[str], optional): The list of agents needed. If None, then ["user", "browser_nav_executor", "planner_agent", "browser_nav_agent"] will be used.
|
||||
save_chat_logs_to_files (bool, optional): Whether to save chat logs to files. Defaults to True.
|
||||
planner_max_chat_rounds (int, optional): The maximum number of chat rounds for the planner. Defaults to 50.
|
||||
browser_nav_max_chat_round (int, optional): The maximum number of chat rounds for the browser navigation agent. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
AutogenWrapper: An instance of AutogenWrapper.
|
||||
|
||||
"""
|
||||
print(f">>> Creating AutogenWrapper with {agents_needed}, Planner max chat rounds: {planner_max_chat_round}, browser nav max chat rounds: {browser_nav_max_chat_round}. Save chat logs to files: {save_chat_logs_to_files}")
|
||||
if agents_needed is None:
|
||||
agents_needed = ["user", "browser_nav_executor", "planner_agent", "browser_nav_agent"]
|
||||
# Create an instance of cls
|
||||
self = cls(save_chat_logs_to_files=save_chat_logs_to_files, planner_max_chat_round=planner_max_chat_round, browser_nav_max_chat_round=browser_nav_max_chat_round)
|
||||
|
||||
os.environ["AUTOGEN_USE_DOCKER"] = "False"
|
||||
|
||||
self.planner_agent_config = planner_agent_config
|
||||
self.browser_nav_agent_config = browser_nav_agent_config
|
||||
|
||||
self.planner_agent_model_config = self.convert_model_config_to_autogen_format(self.planner_agent_config["model_config_params"])
|
||||
self.browser_nav_agent_model_config = self.convert_model_config_to_autogen_format(self.browser_nav_agent_config["model_config_params"])
|
||||
|
||||
self.agents_map = await self.__initialize_agents(agents_needed)
|
||||
|
||||
def trigger_nested_chat(manager: autogen.ConversableAgent):
|
||||
content:str=manager.last_message()["content"] # type: ignore
|
||||
content_json = parse_response(content) # type: ignore
|
||||
next_step = content_json.get('next_step', None)
|
||||
plan = content_json.get('plan', None)
|
||||
if plan is not None:
|
||||
notify_planner_messages(plan, message_type=MessageType.PLAN)
|
||||
|
||||
if next_step is None:
|
||||
notify_planner_messages("Received no response, terminating..", message_type=MessageType.INFO) # type: ignore
|
||||
return False
|
||||
else:
|
||||
notify_planner_messages(next_step, message_type=MessageType.STEP) # type: ignore
|
||||
return True
|
||||
|
||||
def get_url() -> str:
|
||||
return asyncio.run(geturl())
|
||||
|
||||
def my_custom_summary_method(sender: autogen.ConversableAgent,recipient: autogen.ConversableAgent, summary_args: dict ) : # type: ignore
|
||||
messages_str_keys = {str(key): value for key, value in sender.chat_messages.items()} # type: ignore
|
||||
self.__save_chat_log(list(messages_str_keys.values())[0]) # type: ignore
|
||||
last_message=recipient.last_message(sender)["content"] # type: ignore
|
||||
if not last_message or last_message.strip() == "": # type: ignore
|
||||
# print(f">>> Last message from browser nav was empty. Max turns: {self.browser_number_of_rounds*2}, number of messages: {len(list(sender.chat_messages.items())[0][1])}")
|
||||
# print(">>> Sender messages:", json.dumps( list(sender.chat_messages.items())[0][1], indent=2))
|
||||
return "I received an empty message. This is not an error and is recoverable. Try to reformulate the task..."
|
||||
elif "##TERMINATE TASK##" in last_message:
|
||||
last_message=last_message.replace("##TERMINATE TASK##", "") # type: ignore
|
||||
last_message=last_message+" "+ get_url() # type: ignore
|
||||
notify_planner_messages(last_message, message_type=MessageType.ACTION) # type: ignore
|
||||
return last_message # type: ignore
|
||||
return recipient.last_message(sender)["content"] # type: ignore
|
||||
|
||||
def reflection_message(recipient, messages, sender, config): # type: ignore
|
||||
last_message=messages[-1]["content"] # type: ignore
|
||||
content_json = parse_response(last_message) # type: ignore
|
||||
next_step = content_json.get('next_step', None)
|
||||
|
||||
if next_step is None:
|
||||
print ("Message to nested chat returned None")
|
||||
return None
|
||||
else:
|
||||
next_step = next_step.strip() +" " + get_url() # type: ignore
|
||||
return next_step # type: ignore
|
||||
|
||||
# print(f">>> Registering nested chat. Available agents: {self.agents_map}")
|
||||
self.agents_map["user"].register_nested_chats( # type: ignore
|
||||
[
|
||||
{
|
||||
"sender": self.agents_map["browser_nav_executor"],
|
||||
"recipient": self.agents_map["browser_nav_agent"],
|
||||
"message":reflection_message,
|
||||
"max_turns": self.browser_number_of_rounds,
|
||||
"summary_method": my_custom_summary_method,
|
||||
}
|
||||
],
|
||||
trigger=trigger_nested_chat, # type: ignore
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def convert_model_config_to_autogen_format(self, model_config: dict[str, str]) -> list[dict[str, Any]]:
    """
    Convert a single model-config dict into autogen's config-list format.

    autogen's ``config_list_from_json`` only accepts a file or env var, so the
    dict is serialized to a temporary JSON file which is removed once parsed.

    Args:
        model_config (dict[str, str]): LLM configuration (model, API key, etc.).

    Returns:
        list[dict[str, Any]]: The configuration list produced by autogen.
    """
    env_var: list[dict[str, str]] = [model_config]
    with tempfile.NamedTemporaryFile(delete=False, mode='w') as temp:
        json.dump(env_var, temp)
        temp_file_path = temp.name

    try:
        return autogen.config_list_from_json(env_or_file=temp_file_path)
    finally:
        # Fix: the file was created with delete=False and previously leaked
        # one temp file per call; remove it once autogen has read it.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
|
||||
|
||||
def get_chat_logs_dir(self) -> str|None:
|
||||
"""
|
||||
Get the directory for saving chat logs.
|
||||
|
||||
Returns:
|
||||
str|None: The directory path or None if there is not one
|
||||
|
||||
"""
|
||||
return self.chat_logs_dir
|
||||
|
||||
def set_chat_logs_dir(self, chat_logs_dir: str):
    """
    Configure the directory used for saving chat logs.

    Args:
        chat_logs_dir (str): The directory path.
    """
    self.chat_logs_dir = chat_logs_dir
|
||||
|
||||
|
||||
def __save_chat_log(self, chat_log: list[dict[str, Any]]):
    """
    Persist one nested-chat transcript.

    Depending on configuration, either emit it through the structured logger
    or write it as a timestamped JSON file under the chat-logs directory.
    """
    if self.save_chat_logs_to_files:
        log_name = f"nested_chat_log_{str(time_ns())}.json"
        log_path = os.path.join(self.get_chat_logs_dir() or "", log_name)
        # Save the chat log to a file
        with open(log_path, "w") as out:
            json.dump(chat_log, out, indent=4)
    else:
        logger.info("Nested chat logs", extra={"nested_chat_log": chat_log})
|
||||
|
||||
|
||||
async def __initialize_agents(self, agents_needed: list[str]):
    """
    Instantiate all agents with their appropriate prompts/skills.

    Args:
        agents_needed (list[str]): The list of agents needed; must contain
            "user" and "browser_nav_executor" or a ValueError is raised by
            the remove() calls below.

    Returns:
        dict: A dictionary of agent instances keyed by agent name.

    Raises:
        ValueError: If an unknown agent name is requested.
    """
    # Fix: work on a copy so the caller's list is not mutated by remove().
    agents_needed = list(agents_needed)
    agents_map: dict[str, UserProxyAgent_SequentialFunctionExecution | autogen.ConversableAgent]= {}

    user_delegate_agent = await self.__create_user_delegate_agent()
    agents_map["user"] = user_delegate_agent
    agents_needed.remove("user")

    browser_nav_executor = self.__create_browser_nav_executor_agent()
    agents_map["browser_nav_executor"] = browser_nav_executor
    agents_needed.remove("browser_nav_executor")

    # Remaining agents can be created in any order.
    for agent_needed in agents_needed:
        if agent_needed == "browser_nav_agent":
            browser_nav_agent: autogen.ConversableAgent = self.__create_browser_nav_agent(agents_map["browser_nav_executor"] )
            agents_map["browser_nav_agent"] = browser_nav_agent
        elif agent_needed == "planner_agent":
            planner_agent = self.__create_planner_agent(user_delegate_agent)
            agents_map["planner_agent"] = planner_agent
        else:
            raise ValueError(f"Unknown agent type: {agent_needed}")
    return agents_map
|
||||
|
||||
|
||||
async def __create_user_delegate_agent(self) -> autogen.ConversableAgent:
    """
    Create the "user" delegate agent that converses with the planner.

    Returns:
        autogen.ConversableAgent: An instance of ConversableAgent.
    """
    def is_planner_termination_message(x: dict[str, str])->bool: # type: ignore
        # Never terminate while a function call is still pending.
        should_terminate = False
        function: Any = x.get("function", None)
        if function is not None:
            return False

        content:Any = x.get("content", "")
        if content is None:
            # An empty reply ends the conversation.
            content = ""
            should_terminate = True
        else:
            try:
                content_json = parse_response(content)
                _terminate = content_json.get('terminate', "no")
                final_response = content_json.get('final_response', None)
                if(_terminate == "yes"):
                    should_terminate = True
                    if final_response:
                        notify_planner_messages(final_response, message_type=MessageType.ANSWER)
            except json.JSONDecodeError:
                # Fix: this log line was a plain string literal, so {content}
                # was printed verbatim instead of being interpolated.
                logger.error(f"Error decoding JSON response:\n{content}.\nTerminating..")
                should_terminate = True

        return should_terminate # type: ignore

    task_delegate_agent = UserProxyAgent_SequentialFunctionExecution(
        name="user",
        llm_config=False,
        system_message=LLM_PROMPTS["USER_AGENT_PROMPT"],
        is_termination_msg=is_planner_termination_message, # type: ignore
        human_input_mode="NEVER",
        max_consecutive_auto_reply=self.planner_number_of_rounds,
    )
    return task_delegate_agent
|
||||
|
||||
def __create_browser_nav_executor_agent(self):
    """
    Create a UserProxyAgent instance for executing browser control.

    Returns:
        autogen.UserProxyAgent: An instance of UserProxyAgent.
    """
    def is_browser_executor_termination_message(x: dict[str, str])->bool: # type: ignore

        tools_call:Any = x.get("tool_calls", "")
        if tools_call :
            chat_messages=self.agents_map["browser_nav_executor"].chat_messages #type: ignore
            # Get the only key from the dictionary
            agent_key = next(iter(chat_messages)) # type: ignore
            # Get the chat messages corresponding to the only key
            messages = chat_messages[agent_key] # type: ignore
            return is_agent_stuck_in_loop(messages) # type: ignore
        else:
            # Consistency: use the module logger instead of bare print,
            # matching the rest of this file.
            logger.debug("Terminating browser executor")
            return True

    browser_nav_executor_agent = UserProxyAgent_SequentialFunctionExecution(
        name="browser_nav_executor",
        is_termination_msg=is_browser_executor_termination_message,
        human_input_mode="NEVER",
        llm_config=None,
        max_consecutive_auto_reply=self.browser_number_of_rounds,
        code_execution_config={
            "last_n_messages": 1,
            "work_dir": "tasks",
            "use_docker": False,
        },
    )
    logger.debug(f"Created browser_nav_executor_agent: {browser_nav_executor_agent}")
    return browser_nav_executor_agent
|
||||
|
||||
def __create_browser_nav_agent(self, user_proxy_agent: UserProxyAgent_SequentialFunctionExecution) -> autogen.ConversableAgent:
    """
    Create a BrowserNavAgent instance bound to the given executor proxy.

    Args:
        user_proxy_agent (autogen.UserProxyAgent): The executor proxy that will run this agent's tool calls.

    Returns:
        autogen.AssistantAgent: The underlying autogen agent of the BrowserNavAgent.
    """
    system_prompt = self.browser_nav_agent_config["other_settings"].get("system_prompt", None)  # type: ignore
    nav_agent = BrowserNavAgent(
        self.browser_nav_agent_model_config,
        self.browser_nav_agent_config["llm_config_params"],  # type: ignore
        system_prompt,
        user_proxy_agent,
    )
    #print(">>> browser agent tools:", json.dumps(browser_nav_agent.agent.llm_config.get("tools"), indent=2))
    return nav_agent.agent
|
||||
|
||||
def __create_planner_agent(self, assistant_agent: autogen.ConversableAgent):
    """
    Create a Planner Agent instance. This is mainly used for exploration at this point

    Returns:
        autogen.AssistantAgent: An instance of PlannerAgent.
    """
    system_prompt = self.planner_agent_config["other_settings"].get("system_prompt", None)  # type: ignore
    planner = PlannerAgent(
        self.planner_agent_model_config,
        self.planner_agent_config["llm_config_params"],  # type: ignore
        system_prompt,
        assistant_agent,
    )
    return planner.agent
|
||||
|
||||
async def process_command(self, command: str, current_url: str | None = None) -> autogen.ChatResult | None:
    """
    Process a command by sending it to one or more agents.

    Args:
        command (str): The command to be processed.
        current_url (str, optional): The current URL of the browser. Defaults to None.

    Returns:
        autogen.ChatResult | None: The result of the command processing, or None if an error occurred. Contains chat log, cost(tokens/price)
    """
    current_url_prompt_segment = ""
    if current_url:
        current_url_prompt_segment = f"Current Page: {current_url}"

    prompt = Template(LLM_PROMPTS["COMMAND_EXECUTION_PROMPT"]).substitute(command=command, current_url_prompt_segment=current_url_prompt_segment)
    logger.info(f"Prompt for command: {prompt}")
    #with Cache.disk() as cache:
    try:
        if self.agents_map is None:
            raise ValueError("Agents map is not initialized.")

        # Kick off the planner conversation; nested browser-nav chats are
        # triggered from inside it via the registered nested-chat config.
        result=await self.agents_map["user"].a_initiate_chat( # type: ignore
            self.agents_map["planner_agent"], # self.manager # type: ignore
            max_turns=self.planner_number_of_rounds,
            #clear_history=True,
            message=prompt,
            silent=False,
            cache=None,
        )
        # reset usage summary for all agents after each command
        for agent in self.agents_map.values():
            if hasattr(agent, "client") and agent.client is not None:
                agent.client.clear_usage_summary() # type: ignore
        return result
    except openai.BadRequestError as bre:
        # Best-effort: log and fall through, implicitly returning None.
        logger.error(f"Unable to process command: \"{command}\". {bre}")
        traceback.print_exc()
|
||||
|
22
Agent_E/ae/core/memory/static_ltm.py
Normal file
22
Agent_E/ae/core/memory/static_ltm.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import os
|
||||
|
||||
from Agent_E.ae.config import USER_PREFERENCES_PATH
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
def get_user_ltm():
    """
    Get the user preferences stored in the user_preferences.txt file.
    returns: str | None - The user preferences stored in the user_preferences.txt file or None if not found.
    """
    user_preferences_file_name = 'user_preferences.txt'
    pref_path = os.path.join(USER_PREFERENCES_PATH, user_preferences_file_name)
    try:
        with open(pref_path) as pref_file:
            user_pref = pref_file.read()
            logger.info(f"User preferences loaded from: {pref_path}")
            return user_pref
    except FileNotFoundError:
        logger.warning(f"""User preferences file \"{user_preferences_file_name}\" not found.
To add your preferences for this agent to use, create a file called "{user_preferences_file_name}" in directory "{USER_PREFERENCES_PATH}".\n""")
    return None
|
53
Agent_E/ae/core/notification_manager.py
Normal file
53
Agent_E/ae/core/notification_manager.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
from collections.abc import Callable
|
||||
|
||||
|
||||
class NotificationManager:
    """
    Dispatches notification messages to registered listener callbacks.

    Attributes:
        listeners (list[Callable[[dict[str, str]], None]]): A list of listener callbacks to notify.
    """

    def __init__(self):
        """Start with an empty listener registry."""
        self.listeners: list[Callable[[dict[str, str]], None]] = []

    def notify(self, message: str, message_type: str) -> None:
        """
        Deliver a message of the given type to every registered listener.

        Args:
            message (str): The message to notify.
            message_type (str): The type of the message.
        """
        notification = {
            "message": message,
            "type": message_type,
        }

        if not self.listeners:
            print(f"No listeners available, discarding message: {notification}")
            return
        for callback in self.listeners:
            callback(notification)

    def register_listener(self, listener: Callable[[dict[str, str]], None]) -> None:
        """
        Add a listener callback to the registry.

        Args:
            listener (Callable[[dict[str, str]], None]): The listener callback to register.
        """
        self.listeners.append(listener)

    def unregister_listener(self, listener: Callable[[dict[str, str]], None]) -> None:
        """
        Remove a previously registered listener callback.

        Args:
            listener (Callable[[dict[str, str]], None]): The listener callback to unregister.
        """
        self.listeners.remove(listener)
|
452
Agent_E/ae/core/playwright_manager.py
Normal file
452
Agent_E/ae/core/playwright_manager.py
Normal file
|
@ -0,0 +1,452 @@
|
|||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from playwright.async_api import async_playwright as playwright
|
||||
from playwright.async_api import BrowserContext
|
||||
from playwright.async_api import Page
|
||||
from playwright.async_api import Playwright
|
||||
|
||||
from Agent_E.ae.core.notification_manager import NotificationManager
|
||||
from Agent_E.ae.core.ui_manager import UIManager
|
||||
from Agent_E.ae.utils.dom_mutation_observer import dom_mutation_change_detected
|
||||
from Agent_E.ae.utils.dom_mutation_observer import handle_navigation_for_mutation_observer
|
||||
from Agent_E.ae.utils.js_helper import beautify_plan_message
|
||||
from Agent_E.ae.utils.js_helper import escape_js_message
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
# Ensures that playwright does not wait for font loading when taking screenshots. Reference: https://github.com/microsoft/playwright/issues/28995
|
||||
|
||||
class PlaywrightManager:
    """
    A singleton class to manage Playwright instances and browsers.

    Attributes:
        browser_type (str): The type of browser to use ('chromium', 'firefox', 'webkit').
        isheadless (bool): Flag to launch the browser in headless mode or not.

    The class ensures only one instance of itself, Playwright, and the browser is created during the application lifecycle.
    """
    _homepage = "https://www.google.com"
    _instance = None                     # the singleton instance
    _playwright = None # type: ignore    # shared Playwright driver, class-level
    _browser_context = None              # shared persistent browser context
    __async_initialize_done = False      # guards async_initialize() idempotence
    _take_screenshots = False            # screenshot capture toggle
    _screenshots_dir = None              # directory screenshots are written to
|
||||
|
||||
def __new__(cls, *args, **kwargs): # type: ignore
    """
    Ensures that only one instance of PlaywrightManager is created (singleton pattern).
    """
    if cls._instance is None:
        cls._instance = super().__new__(cls)
        # Flag consumed by __init__ so repeated construction is a no-op.
        cls._instance.__initialized = False
        logger.debug("Playwright instance created..")
    return cls._instance
|
||||
|
||||
|
||||
def __init__(self, browser_type: str = "chromium", headless: bool = False, gui_input_mode: bool = True, screenshots_dir: str = "", take_screenshots: bool = False):
    """
    Initializes the PlaywrightManager with the specified browser type and headless mode.
    Initialization occurs only once due to the singleton pattern.

    Args:
        browser_type (str, optional): The type of browser to use. Defaults to "chromium".
        headless (bool, optional): Flag to launch the browser in headless mode or not. Defaults to False (non-headless).
        gui_input_mode (bool, optional): When True, create the UIManager overlay. Defaults to True.
        screenshots_dir (str, optional): Directory screenshots are written to. Defaults to "".
        take_screenshots (bool, optional): Whether take_screenshots() captures anything. Defaults to False.
    """
    if self.__initialized:
        # Singleton: a second construction must not reconfigure the instance.
        return
    self.browser_type = browser_type
    self.isheadless = headless
    self.__initialized = True
    self.notification_manager = NotificationManager()
    # Event used to block prompt_user() until receive_user_response() fires.
    self.user_response_event = asyncio.Event()
    if gui_input_mode:
        self.ui_manager: UIManager = UIManager()

    self.set_take_screenshots(take_screenshots)
    self.set_screenshots_dir(screenshots_dir)
|
||||
|
||||
|
||||
async def async_initialize(self):
    """
    Asynchronously initialize necessary components and handlers for the browser context.
    """
    if self.__async_initialize_done:
        # Idempotent: subsequent calls are no-ops.
        return

    # Step 1: Ensure Playwright is started and browser context is created
    await self.start_playwright()
    await self.ensure_browser_context()

    # Step 2: Deferred setup of handlers
    await self.setup_handlers()

    # Step 3: Navigate to homepage
    await self.go_to_homepage()

    self.__async_initialize_done = True
|
||||
|
||||
|
||||
async def ensure_browser_context(self):
    """
    Create the shared browser context on first use; no-op when one exists.
    """
    if self._browser_context is not None:
        return
    await self.create_browser_context()
|
||||
|
||||
|
||||
async def setup_handlers(self):
    """
    Wire up the overlay-state, user-response and navigation handlers once
    the browser context has been ensured.
    """
    await self.set_overlay_state_handler()
    await self.set_user_response_handler()
    await self.set_navigation_handler()
|
||||
|
||||
|
||||
async def start_playwright(self):
    """
    Starts the Playwright instance if it hasn't been started yet. This method is idempotent.
    """
    if not PlaywrightManager._playwright:
        # Stored on the class so every (singleton) instance shares one driver.
        PlaywrightManager._playwright: Playwright = await playwright().start()
|
||||
|
||||
|
||||
async def stop_playwright(self):
    """
    Stops the Playwright instance and resets it to None. This method should be called to clean up resources.
    """
    # Close the browser context if it's initialized
    if PlaywrightManager._browser_context is not None:
        await PlaywrightManager._browser_context.close()
        PlaywrightManager._browser_context = None

    # Stop the Playwright instance if it's initialized
    if PlaywrightManager._playwright is not None: # type: ignore
        await PlaywrightManager._playwright.stop()
        PlaywrightManager._playwright = None # type: ignore
|
||||
|
||||
|
||||
async def create_browser_context(self):
    """
    Launch a persistent Chromium context using BROWSER_STORAGE_DIR as the
    profile directory, falling back to a fresh temporary profile when the
    stored one is unusable.

    Raises:
        ValueError: If the browser type is unsupported or Chrome is missing.
    """
    user_dir:str = os.environ.get('BROWSER_STORAGE_DIR', '')
    if self.browser_type != "chromium":
        raise ValueError(f"Unsupported browser type: {self.browser_type}")
    logger.info(f"User dir: {user_dir}")
    try:
        PlaywrightManager._browser_context = await self.__launch_persistent_chromium(user_dir)
    except Exception as e:
        if "Target page, context or browser has been closed" in str(e):
            # The stored profile is locked or corrupt; retry with a throwaway profile.
            new_user_dir = tempfile.mkdtemp()
            logger.error(f"Failed to launch persistent context with user dir {user_dir}: {e} Trying to launch with a new user dir {new_user_dir}")
            PlaywrightManager._browser_context = await self.__launch_persistent_chromium(new_user_dir)
        elif "Chromium distribution 'chrome' is not found " in str(e):
            raise ValueError("Chrome is not installed on this device. Install Google Chrome or install playwright using 'playwright install chrome'. Refer to the readme for more information.") from None
        else:
            raise e from None

async def __launch_persistent_chromium(self, user_dir: str):
    # Single launch site (previously duplicated verbatim in both the primary
    # attempt and the temp-dir fallback).
    return await PlaywrightManager._playwright.chromium.launch_persistent_context(user_dir,
        channel= "chrome", headless=self.isheadless,
        args=["--disable-blink-features=AutomationControlled",
              "--disable-session-crashed-bubble",  # disable the restore session bubble
              "--disable-infobars",  # disable informational popups,
              ],
        no_viewport=True
        )
|
||||
|
||||
|
||||
async def get_browser_context(self):
    """
    Return the shared browser context, creating it first when necessary.
    """
    await self.ensure_browser_context()
    return self._browser_context
|
||||
|
||||
|
||||
async def get_current_url(self) -> str | None:
|
||||
"""
|
||||
Get the current URL of current page
|
||||
|
||||
Returns:
|
||||
str | None: The current URL if any.
|
||||
"""
|
||||
try:
|
||||
current_page: Page =await self.get_current_page()
|
||||
return current_page.url
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def get_current_page(self) -> Page :
    """
    Get the current page of the browser

    Returns:
        Page: The current page if any.
    """
    try:
        browser: BrowserContext = await self.get_browser_context() # type: ignore
        # Filter out closed pages
        pages: list[Page] = [page for page in browser.pages if not page.is_closed()]
        # The most recently opened page is treated as "current".
        page: Page | None = pages[-1] if pages else None
        logger.debug(f"Current page: {page.url if page else None}")
        if page is not None:
            return page
        else:
            # No open pages left; open a fresh tab in the existing context.
            page:Page = await browser.new_page() # type: ignore
            return page
    except Exception:
        # The context died underneath us: drop it, rebuild, and retry once.
        logger.warn("Browser context was closed. Creating a new one.")
        PlaywrightManager._browser_context = None
        _browser:BrowserContext= await self.get_browser_context() # type: ignore
        page: Page | None = await self.get_current_page()
        return page
|
||||
|
||||
|
||||
async def close_all_tabs(self, keep_first_tab: bool = True):
    """
    Closes all tabs in the browser context, except for the first tab if `keep_first_tab` is set to True.

    Args:
        keep_first_tab (bool, optional): Whether to keep the first tab open. Defaults to True.
    """
    context = await self.get_browser_context()
    open_pages = context.pages  # type: ignore
    doomed = open_pages[1:] if keep_first_tab else open_pages  # type: ignore
    for tab in doomed:  # type: ignore
        await tab.close()  # type: ignore
|
||||
|
||||
|
||||
async def close_except_specified_tab(self, page_to_keep: Page):
    """
    Closes all tabs in the browser context, except for the specified tab.

    Args:
        page_to_keep (Page): The Playwright page object representing the tab that should remain open.
    """
    browser_context = await self.get_browser_context()
    # Fix: snapshot the page list first — the original iterated
    # browser_context.pages while close() was mutating it.
    pages_to_close = [page for page in browser_context.pages if page != page_to_keep]  # type: ignore
    for page in pages_to_close:  # type: ignore
        await page.close()  # type: ignore
|
||||
|
||||
|
||||
async def go_to_homepage(self):
    """Navigate the current tab to the configured homepage URL."""
    current_page: Page = await self.get_current_page()
    await current_page.goto(self._homepage)
|
||||
|
||||
|
||||
async def set_navigation_handler(self):
    # Re-attach overlay injection and the DOM mutation observer each time a
    # page finishes loading its DOM.
    page:Page = await PlaywrightManager.get_current_page(self)
    page.on("domcontentloaded", self.ui_manager.handle_navigation) # type: ignore
    page.on("domcontentloaded", handle_navigation_for_mutation_observer) # type: ignore
    # Expose the Python callback invoked by the injected JS observer.
    await page.expose_function("dom_mutation_change_detected", dom_mutation_change_detected) # type: ignore
|
||||
|
||||
async def set_overlay_state_handler(self):
    # Expose Python callbacks the overlay JS uses to report its UI state.
    logger.debug("Setting overlay state handler")
    context = await self.get_browser_context()
    await context.expose_function('overlay_state_changed', self.overlay_state_handler) # type: ignore
    await context.expose_function('show_steps_state_changed',self.show_steps_state_handler) # type: ignore
|
||||
|
||||
async def overlay_state_handler(self, is_collapsed: bool):
    """Record the overlay collapse state; refresh its chat history when expanded."""
    current = await self.get_current_page()
    self.ui_manager.update_overlay_state(is_collapsed)
    if is_collapsed:
        return
    await self.ui_manager.update_overlay_chat_history(current)
|
||||
|
||||
async def show_steps_state_handler(self, show_details: bool):
    """Propagate the overlay's "show detailed steps" toggle to the UI manager."""
    current = await self.get_current_page()
    await self.ui_manager.update_overlay_show_details(show_details, current)
|
||||
|
||||
async def set_user_response_handler(self):
    """Expose the callback through which the overlay JS delivers user replies."""
    context = await self.get_browser_context()
    await context.expose_function('user_response', self.receive_user_response)  # type: ignore
|
||||
|
||||
|
||||
async def notify_user(self, message: str, message_type: MessageType = MessageType.STEP):
    """
    Notify the user with a message.

    Args:
        message (str): The message to notify the user with.
        message_type (enum, optional): Values can be 'PLAN', 'QUESTION', 'ANSWER', 'INFO', 'STEP'. Defaults to 'STEP'.
        To Do: Convert to Enum.
    """

    # Trim stray punctuation the LLM sometimes leaves at the message edges.
    if message.startswith(":"):
        message = message[1:]

    if message.endswith(","):
        message = message[:-1]

    # Prefix the message according to its category.
    if message_type == MessageType.PLAN:
        message = beautify_plan_message(message)
        message = "Plan:\n" + message
    elif message_type == MessageType.STEP:
        if "confirm" in message.lower():
            message = "Verify: " + message
        else:
            message = "Next step: " + message
    elif message_type == MessageType.QUESTION:
        message = "Question: " + message
    elif message_type == MessageType.ANSWER:
        message = "Response: " + message

    safe_message = escape_js_message(message)
    self.ui_manager.new_system_message(safe_message, message_type)

    # Respect the overlay's detail toggle: with details hidden, only
    # PLAN/QUESTION/ANSWER/INFO reach the page; with details shown, STEP too.
    if self.ui_manager.overlay_show_details == False: # noqa: E712
        if message_type not in (MessageType.PLAN, MessageType.QUESTION, MessageType.ANSWER, MessageType.INFO):
            return

    if self.ui_manager.overlay_show_details == True: # noqa: E712
        if message_type not in (MessageType.PLAN, MessageType.QUESTION , MessageType.ANSWER, MessageType.INFO, MessageType.STEP):
            return

    safe_message_type = escape_js_message(message_type.value)
    try:
        js_code = f"addSystemMessage({safe_message}, is_awaiting_user_response=false, message_type={safe_message_type});"
        page = await self.get_current_page()
        await page.evaluate(js_code)
    except Exception as e:
        logger.error(f"Failed to notify user with message \"{message}\". However, most likey this will work itself out after the page loads: {e}")

    # Also fan the message out to any registered Python listeners.
    self.notification_manager.notify(message, message_type.value)
|
||||
|
||||
async def highlight_element(self, selector: str, add_highlight: bool):
    """
    Toggle a transient pulsating-border highlight on the element matched by
    `selector`, used to visually signal text-entry operations.

    Args:
        selector (str): CSS selector of the element to (un)highlight.
        add_highlight (bool): True to add the highlight, False to remove it.
    """
    try:
        page: Page = await self.get_current_page()
        if add_highlight:
            # Add the 'agente-ui-automation-highlight' class to the element. This class is used to apply the fading border.
            await page.eval_on_selector(selector, '''e => {
                let originalBorderStyle = e.style.border;
                e.classList.add('agente-ui-automation-highlight');
                e.addEventListener('animationend', () => {
                    e.classList.remove('agente-ui-automation-highlight')
                });}''')
            logger.debug(f"Applied pulsating border to element with selector {selector} to indicate text entry operation")
        else:
            # Remove the 'agente-ui-automation-highlight' class from the element.
            await page.eval_on_selector(selector, "e => e.classList.remove('agente-ui-automation-highlight')")
            logger.debug(f"Removed pulsating border from element with selector {selector} after text entry operation")
    except Exception:
        # This is not significant enough to fail the operation
        pass
|
||||
|
||||
async def receive_user_response(self, response: str):
    """Store the user's overlay reply and wake the waiting prompt_user() coroutine."""
    logger.debug(f"Received user response to system prompt: {response}")
    self.user_response = response  # Store the response for later use.
    # Notify event loop that the user's response has been received.
    self.user_response_event.set()
|
||||
|
||||
|
||||
async def prompt_user(self, message: str) -> str:
    """
    Prompt the user with a message and wait for a response.

    Args:
        message (str): The message to prompt the user with.

    Returns:
        str: The user's response.
    """
    logger.debug(f"Prompting user with message: \"{message}\"")
    #self.ui_manager.new_system_message(message)

    page = await self.get_current_page()

    await self.ui_manager.show_overlay(page)
    self.log_system_message(message, MessageType.QUESTION) # add the message to history after the overlay is opened to avoid double adding it. add_system_message below will add it

    safe_message = escape_js_message(message)

    js_code = f"addSystemMessage({safe_message}, is_awaiting_user_response=true, message_type='question');"
    await page.evaluate(js_code)

    # Block until receive_user_response() sets the event from the overlay JS.
    await self.user_response_event.wait()
    result = self.user_response
    logger.info(f"User prompt reponse to \"{message}\": {result}")
    # Reset state so the next prompt starts clean.
    self.user_response_event.clear()
    self.user_response = ""
    self.ui_manager.new_user_message(result)
    return result
|
||||
|
||||
def set_take_screenshots(self, take_screenshots: bool):
    """Enable or disable screenshot capture."""
    self._take_screenshots = take_screenshots
|
||||
|
||||
def get_take_screenshots(self):
    """Return whether screenshot capture is currently enabled."""
    return self._take_screenshots
|
||||
|
||||
def set_screenshots_dir(self, screenshots_dir: str):
    """Set the directory screenshots are written to."""
    self._screenshots_dir = screenshots_dir
|
||||
|
||||
def get_screenshots_dir(self):
    """Return the directory screenshots are written to."""
    return self._screenshots_dir
|
||||
|
||||
async def take_screenshots(self, name: str, page: Page|None, full_page: bool = True, include_timestamp: bool = True,
                           load_state: str = 'domcontentloaded', take_snapshot_timeout: int = 5*1000):
    """
    Capture a screenshot of the given (or current) page when capture is enabled.

    Args:
        name (str): Base name for the screenshot file.
        page (Page|None): Page to capture; the current page when None.
        full_page (bool, optional): Capture the full scrollable page. Defaults to True.
        include_timestamp (bool, optional): Prefix the file name with a nanosecond timestamp. Defaults to True.
        load_state (str, optional): Load state awaited before capturing. Defaults to 'domcontentloaded'.
        take_snapshot_timeout (int, optional): Timeout in ms for waiting and capturing. Defaults to 5000.
    """
    if not self._take_screenshots:
        return
    if page is None:
        page = await self.get_current_page()

    screenshot_name = name

    if include_timestamp:
        screenshot_name = f"{int(time.time_ns())}_{screenshot_name}"
    screenshot_name += ".png"
    # Fix: os.path.join is portable; the old f"{dir}/{name}" form assumed
    # POSIX separators and produced "None/..." when the dir was unset.
    screenshot_path = os.path.join(self.get_screenshots_dir() or "", screenshot_name)
    try:
        await page.wait_for_load_state(state=load_state, timeout=take_snapshot_timeout) # type: ignore
        await page.screenshot(path=screenshot_path, full_page=full_page, timeout=take_snapshot_timeout, caret="initial", scale="device")
        logger.debug(f"Screen shot saved to: {screenshot_path}")
    except Exception as e:
        logger.error(f"Failed to take screenshot and save to \"{screenshot_path}\". Error: {e}")
|
||||
|
||||
|
||||
def log_user_message(self, message: str):
    """
    Append a user message to the overlay chat history.

    Args:
        message (str): The user's message to log.
    """
    self.ui_manager.new_user_message(message)
|
||||
|
||||
|
||||
def log_system_message(self, message: str, type: MessageType = MessageType.STEP):
    """
    Append a system message of the given type to the overlay chat history.

    Args:
        message (str): The system message to log.
        type (MessageType, optional): Message category. Defaults to MessageType.STEP.
    """
    self.ui_manager.new_system_message(message, type)
|
||||
|
||||
async def update_processing_state(self, processing_state: str):
    """
    Update the processing state of the overlay.

    Args:
        processing_state (str): One of "init", "processing", "done".
    """
    current = await self.get_current_page()
    await self.ui_manager.update_processing_state(processing_state, current)
|
||||
|
||||
async def command_completed(self, command: str, elapsed_time: float | None = None):
    """
    Tell the overlay a command has finished so it can refocus its input.

    Args:
        command (str): The command that completed.
        elapsed_time (float | None, optional): Time taken, when measured.
    """
    logger.debug(f"Command \"{command}\" has been completed. Focusing on the overlay input if it is open.")
    current = await self.get_current_page()
    await self.ui_manager.command_completed(current, command, elapsed_time)
|
43
Agent_E/ae/core/post_process_responses.py
Normal file
43
Agent_E/ae/core/post_process_responses.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import autogen # type: ignore
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
def final_reply_callback_user_proxy(recipient: autogen.ConversableAgent, messages: list[dict[str, Any]], sender: autogen.Agent, config: dict[str, Any]):
    """
    Callback invoked each time the user proxy agent receives a message.

    Inspects only the most recent message; when it carries the "##TERMINATE##"
    signal, the final response (signal stripped) is stored in the module-level
    global ``last_agent_response`` and logged.

    Args:
        recipient (autogen.ConversableAgent): The recipient of the message.
        messages (list[dict[str, Any]]): Messages received by the agent; only the last one is inspected.
        sender (autogen.Agent): The sender of the message.
        config (dict[str, Any]): Additional configuration parameters (unused).

    Returns:
        Tuple[bool, None]: (True, None) to stop processing once a non-empty
        final response is extracted, (False, None) otherwise.
    """
    global last_agent_response
    # Defensive guard: autogen should always pass at least one message, but an
    # empty list would otherwise raise IndexError on messages[-1].
    if not messages:
        return False, None
    last_message = messages[-1]
    logger.debug(f"Post Process Message (User Proxy):{last_message}")
    content = last_message.get('content')
    if content and "##TERMINATE##" in content:
        last_agent_response = content.replace("##TERMINATE##", "").strip()
        if last_agent_response:
            logger.debug("*****Final Reply*****")
            logger.debug(f"Final Response: {last_agent_response}")
            logger.debug("*********************")
            return True, None

    return False, None
|
||||
|
||||
def final_reply_callback_planner_agent(message:str, message_type:MessageType = MessageType.STEP): # type: ignore
    """Relay a planner-agent message to the user via the browser overlay.

    NOTE(review): ``asyncio.get_event_loop()`` is deprecated when no loop is
    running on newer Python versions — confirm target runtime before changing.

    Returns:
        tuple[bool, None]: Always (False, None).
    """
    manager = PlaywrightManager(browser_type='chromium', headless=False)
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(manager.notify_user(message, message_type=message_type))
    return False, None  # required to ensure the agent communication flow continues
|
185
Agent_E/ae/core/prompts.py
Normal file
185
Agent_E/ae/core/prompts.py
Normal file
|
@ -0,0 +1,185 @@
|
|||
# Prompt library keyed by symbolic name. Keys are referenced throughout the
# agent code, so they must remain stable — including the historical
# misspelling "VERFICATION_AGENT", kept for backward compatibility.
LLM_PROMPTS = {
    "USER_AGENT_PROMPT": """A proxy for the user for executing the user commands.""",
    "BROWSER_NAV_EXECUTOR_PROMPT": """A proxy for the user for executing the user commands.""",

    "PLANNER_AGENT_PROMPT": """You are a web automation task planner. You will receive tasks from the user and will work with a naive helper to accomplish it.
You will think step by step and break down the tasks into sequence of simple subtasks. Subtasks will be delegated to the helper to execute.

Return Format:
Your reply will strictly be a well-formatted JSON with four attributes.
"plan": This is a string that contains the high-level plan. This is optional and needs to be present only when a task starts and when the plan needs to be revised.
"next_step": This is a string that contains a detailed next step that is consistent with the plan. The next step will be delegated to the helper to execute. This needs to be present for every response except when terminating
"terminate": yes/no. Return yes when the exact task is complete without any compromises or you are absolutely convinced that the task cannot be completed, no otherwise. This is mandatory for every response.
"final_response": This is the final answer string that will be returned to the user. In search tasks, unless explicitly stated, you will provide the single best suited result in the response instead of listing multiple options. This attribute only needs to be present when terminate is true.

Capabilities and limitation of the helper:
1. Helper can navigate to urls, perform simple interactions on a page or answer any question you may have about the current page.
2. Helper cannot perform complex planning, reasoning or analysis. You will not delegate any such tasks to helper, instead you will perform them based on information from the helper.
3. Helper is stateless and treats each step as a new task. Helper will not remember previous pages or actions. So, you will provide all necessary information as part of each step.
4. Very Important: Helper cannot go back to previous pages. If you need the helper to return to a previous page, you must explicitly add the URL of the previous page in the step (e.g. return to the search result page by navigating to the url https://www.google.com/search?q=Finland")

Guidelines:
1. If you know the direct URL, use it directly instead of searching for it (e.g. go to www.espn.com). Optimise the plan to avoid unnecessary steps.
2. Do not assume any capability exists on the webpage. Ask questions to the helper to confirm the presence of features (e.g. is there a sort by price feature available on the page?). This will help you revise the plan as needed and also establish common ground with the helper.
3. Do not combine multiple steps into one. A step should be strictly as simple as interacting with a single element or navigating to a page. If you need to interact with multiple elements or perform multiple actions, you will break it down into multiple steps.
4. Important: You will NOT ask for any URLs of hyperlinks in the page from the helper, instead you will simply ask the helper to click on specific result. URL of the current page will be automatically provided to you with each helper response.
5. Very Important: Add verification as part of the plan, after each step and specifically before terminating to ensure that the task is completed successfully. Ask simple questions to verify the step completion (e.g. Can you confirm that White Nothing Phone 2 with 16GB RAM is present in the cart?). Do not assume the helper has performed the task correctly.
6. If the task requires multiple informations, all of them are equally important and should be gathered before terminating the task. You will strive to meet all the requirements of the task.
7. If one plan fails, you MUST revise the plan and try a different approach. You will NOT terminate a task until you are absolutely convinced that the task is impossible to accomplish.

Complexities of web navigation:
1. Many forms have mandatory fields that need to be filled up before they can be submitted. Ask the helper for what fields look mandatory.
2. In many websites, there are multiple options to filter or sort results. Ask the helper to list any elements on the page which will help the task (e.g. are there any links or interactive elements that may lead me to the support page?).
3. Always keep in mind complexities such as filtering, advanced search, sorting, and other features that may be present on the website. Ask the helper whether these features are available on the page when relevant and use them when the task requires it.
4. Very often list of items such as, search results, list of products, list of reviews, list of people etc. may be divided into multiple pages. If you need complete information, it is critical to explicitly ask the helper to go through all the pages.
5. Sometimes search capabilities available on the page will not yield the optimal results. Revise the search query to either more specific or more generic.
6. When a page refreshes or navigates to a new page, information entered in the previous page may be lost. Check that the information needs to be re-entered (e.g. what are the values in source and destination on the page?).
7. Sometimes some elements may not be visible or be disabled until some other action is performed. Ask the helper to confirm if there are any other fields that may need to be interacted for elements to appear or be enabled.

Example 1:
Task: Find the cheapest premium economy flights from Helsinki to Stockholm on 15 March on Skyscanner. Current page: www.google.com
{"plan":"1. Go to www.skyscanner.com.
2. List the interaction options available on skyscanner page relevant for flight reservation along with their default values.
3. Select the journey option to one-way (if not default).
4. Set number of passengers to 1 (if not default).
5. Set the departure date to 15 March 2025 (since 15 March 2024 is already past).
6. Set ticket type to Economy Premium.
7. Set from airport to "Helsinki".
8. Set destination airport to Stockholm
9. Confirm that current values in the source airport, destination airport and departure date fields are Helsinki, Stockholm and 15 March 2025 respectively.
10. Click on the search button to get the search results.
11. Confirm that you are on the search results page.
12. Extract the price of the cheapest flight from Helsinki to Stockholm from the search results.",
"next_step": "Go to https://www.skyscanner.com",
"terminate":"no"},
After the task is completed and when terminating:
Your reply: {"terminate":"yes", "final_response": "The cheapest premium economy flight from Helsinki to Stockholm on 15 March 2025 is <flight details>."}

Notice above how there is confirmation after each step and how interaction (e.g. setting source and destination) with each element is a separate step. Follow same pattern.
Remember: you are a very very persistent planner who will try every possible strategy to accomplish the task perfectly.
Revise search query if needed, ask for more information if needed, and always verify the results before terminating the task.
Some basic information about the user: $basic_user_information""",

    "BROWSER_AGENT_PROMPT": """You will perform web navigation tasks, which may include logging into websites and interacting with any web content using the functions made available to you.
Use the provided DOM representation for element location or text summarization.
Interact with pages using only the "mmid" attribute in DOM elements.
You must extract mmid value from the fetched DOM, do not conjure it up.
Execute function sequentially to avoid navigation timing issues. Once a task is completed, confirm completion with ##TERMINATE TASK##.
The given actions are NOT parallelizable. They are intended for sequential execution.
If you need to call multiple functions in a task step, call one function at a time. Wait for the function's response before invoking the next function. This is important to avoid collision.
Strictly for search fields, submit the field by pressing Enter key. For other forms, click on the submit button.
Unless otherwise specified, the task must be performed on the current page. Use openurl only when explicitly instructed to navigate to a new page with a url specified. If you do not know the URL ask for it.
You will NOT provide any URLs of links on webpage. If user asks for URLs, you will instead provide the text of the hyperlink on the page and offer to click on it. This is very very important.
When inputting information, remember to follow the format of the input field. For example, if the input field is a date field, you will enter the date in the correct format (e.g. YYYY-MM-DD), you may get clues from the placeholder text in the input field.
If the task is ambiguous or there are multiple options to choose from, you will ask the user for clarification. You will not make any assumptions.
Individual function will reply with action success and if any changes were observed as a consequence. Adjust your approach based on this feedback.
Once the task is completed or cannot be completed, return a short summary of the actions you performed to accomplish the task, and what worked and what did not. This should be followed by ##TERMINATE TASK##. Your reply will not contain any other information.
Additionally, if task requires an answer, you will also provide a short and precise answer followed by ##TERMINATE TASK##.
Ensure that user questions are answered from the DOM and not from memory or assumptions. To answer a question about textual information on the page, prefer to use text_only DOM type. To answer a question about interactive elements, use all_fields DOM type.
Do not provide any mmid values in your response.
Important: If you encounter an issue or are unsure how to proceed, simply ##TERMINATE TASK## and provide a detailed summary of the exact issue encountered.
Do not repeat the same action multiple times if it fails. Instead, if something did not work after a few attempts, terminate the task.""",

    # Key intentionally keeps the original misspelling; callers reference it verbatim.
    "VERFICATION_AGENT": """Given a conversation and a task, your task is to analyse the conversation and tell if the task is completed. If not, you need to tell what is not completed and suggest next steps to complete the task.""",
    "ENTER_TEXT_AND_CLICK_PROMPT": """This skill enters text into a specified element and clicks another element, both identified by their DOM selector queries.
Ideal for seamless actions like submitting search queries, this integrated approach ensures superior performance over separate text entry and click commands.
Successfully completes when both actions are executed without errors, returning True; otherwise, it provides False or an explanatory message of any failure encountered.
Always prefer this dual-action skill for tasks that combine text input and element clicking to leverage its streamlined operation.""",

    "OPEN_URL_PROMPT": """Opens a specified URL in the web browser instance. Returns url of the new page if successful or appropriate error message if the page could not be opened.""",

    "GO_BACK_PROMPT": """Goes back to previous page in the browser history. Useful when correcting an incorrect action that led to a new page or when needing to revisit a previous page for information. Returns the full URL of the page after the back action is performed.""",

    "COMMAND_EXECUTION_PROMPT": """Execute the user task "$command" $current_url_prompt_segment""",

    "GET_USER_INPUT_PROMPT": """Get clarification by asking the user or wait for user to perform an action on webpage. This is useful e.g. when you encounter a login or captcha and requires the user to intervene. This skill will also be useful when task is ambiguous and you need more clarification from the user (e.g. ["which source website to use to accomplish a task"], ["Enter your credentials on your webpage and type done to continue"]). Use this skill very sparingly and only when absolutely needed.""",

    "GET_DOM_WITHOUT_CONTENT_TYPE_PROMPT": """Retrieves the DOM of the current web browser page.
Each DOM element will have an \"mmid\" attribute injected for ease of DOM interaction.
Returns a minified representation of the HTML DOM where each HTML DOM Element has an attribute called \"mmid\" for ease of DOM query selection. When \"mmid\" attribute is available, use it for DOM query selectors.""",

    # This one below had all three content types including input_fields
    "GET_DOM_WITH_CONTENT_TYPE_PROMPT": """Retrieves the DOM of the current web site based on the given content type.
The DOM representation returned contains items ordered in the same way they appear on the page. Keep this in mind when executing user requests that contain ordinals or numbered items.
text_only - returns plain text representing all the text in the web site. Use this for any information retrieval task. This will contain the most complete textual information.
input_fields - returns a JSON string containing a list of objects representing text input html elements with mmid attribute. Use this strictly for interaction purposes with text input fields.
all_fields - returns a JSON string containing a list of objects representing all interactive elements and their attributes with mmid attribute. Use this strictly to identify and interact with any type of elements on page.
If information is not available in one content type, you must try another content_type.""",

    "GET_ACCESSIBILITY_TREE": """Retrieves the accessibility tree of the current web site.
The DOM representation returned contains items ordered in the same way they appear on the page. Keep this in mind when executing user requests that contain ordinals or numbered items.""",

    "CLICK_PROMPT": """Executes a click action on the element matching the given mmid attribute value. It is best to use mmid attribute as the selector.
Returns Success if click was successful or appropriate error message if the element could not be clicked.""",

    "CLICK_PROMPT_ACCESSIBILITY": """Executes a click action on the element a name and role.
Returns Success if click was successful or appropriate error message if the element could not be clicked.""",

    "GET_URL_PROMPT": """Get the full URL of the current web page/site. If the user command seems to imply an action that would be suitable for an already open website in their browser, use this to fetch current website URL.""",

    "ENTER_TEXT_PROMPT": """Single enter given text in the DOM element matching the given mmid attribute value. This will only enter the text and not press enter or anything else.
Returns Success if text entry was successful or appropriate error message if text could not be entered.""",

    "CLICK_BY_TEXT_PROMPT": """Executes a click action on the element matching the text. If multiple text matches are found, it will click on all of them. Use this as last resort when all else fails.""",

    "BULK_ENTER_TEXT_PROMPT": """Bulk enter text in multiple DOM fields. To be used when there are multiple fields to be filled on the same page.
Enters text in the DOM elements matching the given mmid attribute value.
The input will receive a list of objects containing the DOM query selector and the text to enter.
This will only enter the text and not press enter or anything else.
Returns each selector and the result for attempting to enter text.""",

    "PRESS_KEY_COMBINATION_PROMPT": """Presses the given key on the current web page.
This is useful for pressing the enter button to submit a search query, PageDown to scroll, ArrowDown to change selection in a focussed list etc.""",

    "ADD_TO_MEMORY_PROMPT": """Save any information that you may need later in this term memory. This could be useful for saving things to do, saving information for personalisation, or even saving information you may need in future for efficiency purposes E.g. Remember to call John at 5pm, This user likes Tesla company and considered buying shares, The user enrollment form is available in <url> etc.""",

    "HOVER_PROMPT": """Hover on a element with the given mmid attribute value. Hovering on an element can reveal additional information such as a tooltip or trigger a dropdown menu with different navigation options.""",
    "GET_MEMORY_PROMPT": """Retrieve all the information previously stored in the memory""",

    "PRESS_ENTER_KEY_PROMPT": """Presses the enter key in the given html field. This is most useful on text input fields.""",

    "EXTRACT_TEXT_FROM_PDF_PROMPT": """Extracts text from a PDF file hosted at the given URL.""",

    "BROWSER_AGENT_NO_SKILLS_PROMPT": """You are an autonomous agent tasked with performing web navigation on a Playwright instance, including logging into websites and executing other web-based actions.
You will receive user commands, formulate a plan and then write the PYTHON code that is needed for the task to be completed.
It is possible that the code you are writing is for one step at a time in the plan. This will ensure proper execution of the task.
Your operations must be precise and efficient, adhering to the guidelines provided below:
1. **Asynchronous Code Execution**: Your tasks will often be asynchronous in nature, requiring careful handling. Wrap asynchronous operations within an appropriate async structure to ensure smooth execution.
2. **Sequential Task Execution**: To avoid issues related to navigation timing, execute your actions in a sequential order. This method ensures that each step is completed before the next one begins, maintaining the integrity of your workflow. Some steps like navigating to a site will require a small amount of wait time after them to ensure they load correctly.
3. **Error Handling and Debugging**: Implement error handling to manage exceptions gracefully. Should an error occur or if the task doesn't complete as expected, review your code, adjust as necessary, and retry. Use the console or logging for debugging purposes to track the progress and issues.
4. **Using HTML DOM**: Do not assume what a DOM selector (web elements) might be. Rather, fetch the DOM to look for the selectors or fetch DOM inner text to answer a questions. This is crucial for accurate task execution. When you fetch the DOM, reason about its content to determine appropriate selectors or text that should be extracted. To fetch the DOM using playwright you can:
- Fetch entire DOM using page.content() method. In the fetched DOM, consider if appropriate to remove entire sections of the DOM like `script`, `link` elements
- Fetch DOM inner text only text_content = await page.evaluate("() => document.body.innerText || document.documentElement.innerText"). This is useful for information retrieval.
5. **DOM Handling**: Never ever substring the extracted HTML DOM. You can remove entire sections/elements of the DOM like `script`, `link` elements if they are not needed for the task. This is crucial for accurate task execution.
6. **Execution Verification**: After executing the given code, ensure that you verify the completion of the task. If the task is not completed, revise your plan then rewrite the code for that step.
7. **Termination Protocol**: Once a task is verified as complete or if it's determined that further attempts are unlikely to succeed, conclude the operation and respond with `##TERMINATE##`, to indicate the end of the session. This signal should only be used when the task is fully completed or if there's a consensus that continuation is futile.
8. **Code Modification and Retry Strategy**: If your initial code doesn't achieve the desired outcome, revise your approach based on the insights gained during the process. When DOM selectors you are using fail, fetch the DOM and reason about it to discover the right selectors. If there are timeouts, increase the timeouts. Add other error handling mechanisms before retrying as needed.
9. **Code Generation**: Generated code does not need documentation or usage examples. Assume that it is being executed by an autonomous agent acting on behalf of the user. Do not add placeholders in the code.
10. **Browser Handling**: Do not use headless mode with playwright. Do not close the browser after every step or even after task completion. Leave it open.
11. **Response**: Remember that you are communicating with an autonomous agent that does not reason. All it does is execute code. Only respond with code that it can execute unless you are terminating.
12. **Playwright Oddities**: There are certain things that Playwright does not do well:
- page.wait_for_selector: When providing a timeout value, it will almost always timeout. Put that call in a try/except block and catch the timeout. If timeout occurs just move to the next statement in the code and most likely it will work. For example, if next statement is page.fill, just execute it.


By following these guidelines, you will enhance the efficiency, reliability, and user interaction of your web navigation tasks.
Always aim for clear, concise, and well-structured code that aligns with best practices in asynchronous programming and web automation.
""",
}
|
18
Agent_E/ae/core/skills/__init__.py
Normal file
18
Agent_E/ae/core/skills/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from Agent_E.ae.core.skills.click_using_selector import click
|
||||
from Agent_E.ae.core.skills.click_using_selector import do_click
|
||||
from Agent_E.ae.core.skills.click_using_selector import is_element_present
|
||||
from Agent_E.ae.core.skills.click_using_selector import perform_javascript_click
|
||||
from Agent_E.ae.core.skills.click_using_selector import perform_playwright_click
|
||||
|
||||
from Agent_E.ae.core.skills.enter_text_and_click import enter_text_and_click
|
||||
|
||||
from Agent_E.ae.core.skills.enter_text_using_selector import bulk_enter_text
|
||||
from Agent_E.ae.core.skills.enter_text_using_selector import custom_fill_element
|
||||
from Agent_E.ae.core.skills.enter_text_using_selector import do_entertext
|
||||
|
||||
from Agent_E.ae.core.skills.get_dom_with_content_type import get_dom_with_content_type
|
||||
from Agent_E.ae.core.skills.get_url import geturl
|
||||
from Agent_E.ae.core.skills.get_user_input import get_user_input
|
||||
from Agent_E.ae.core.skills.open_url import openurl
|
||||
|
||||
from Agent_E.ae.core.skills.press_key_combination import press_key_combination
|
217
Agent_E/ae/core/skills/click_using_selector.py
Normal file
217
Agent_E/ae/core/skills/click_using_selector.py
Normal file
|
@ -0,0 +1,217 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import traceback
|
||||
from typing import Annotated
|
||||
|
||||
from playwright.async_api import ElementHandle
|
||||
from playwright.async_api import Page
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.dom_helper import get_element_outer_html
|
||||
from Agent_E.ae.utils.dom_mutation_observer import subscribe # type: ignore
|
||||
from Agent_E.ae.utils.dom_mutation_observer import unsubscribe # type: ignore
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
async def click(selector: Annotated[str, "The properly formed query selector string to identify the element for the click action (e.g. [mmid='114']). When \"mmid\" attribute is present, use it for the query selector."],
                wait_before_execution: Annotated[float, "Optional wait time in seconds before executing the click event logic.", float] = 0.0) -> Annotated[str, "A message indicating success or failure of the click."]:
    """
    Executes a click action on the element matching the given query selector string within the currently open web page.
    If there is no page open, it will raise a ValueError. An optional wait time can be specified before executing the click logic. Use this to wait for the page to load especially when the last action caused the DOM/Page to load.

    Takes before/after screenshots, highlights the target element, and watches
    for DOM mutations triggered by the click so callers can tell whether new
    elements (e.g. a dropdown) appeared as a consequence.

    Parameters:
    - selector: The query selector string to identify the element for the click action.
    - wait_before_execution: Optional wait time in seconds before executing the click event logic. Defaults to 0.0 seconds.

    Returns:
    - Success if the click was successful, Appropriate error message otherwise.
    """
    logger.info(f"Executing ClickElement with \"{selector}\" as the selector")

    # Initialize PlaywrightManager and get the active browser page
    # (PlaywrightManager appears to be a singleton-style manager — constructor
    # args here presumably reuse the existing instance; confirm in its impl.)
    browser_manager = PlaywrightManager(browser_type='chromium', headless=False)
    page = await browser_manager.get_current_page()

    if page is None: # type: ignore
        raise ValueError('No active page found. OpenURL command opens a new page.')

    # Used to tag the before/after screenshots with this skill's name.
    function_name = inspect.currentframe().f_code.co_name # type: ignore

    await browser_manager.take_screenshots(f"{function_name}_start", page)

    await browser_manager.highlight_element(selector, True)

    # Capture any DOM mutations caused by the click. The subscribe/click/
    # sleep/unsubscribe ordering is load-bearing: the short sleep gives the
    # mutation observer time to report changes before we stop listening.
    dom_changes_detected=None
    def detect_dom_changes(changes:str): # type: ignore
        nonlocal dom_changes_detected
        dom_changes_detected = changes # type: ignore

    subscribe(detect_dom_changes)
    result = await do_click(page, selector, wait_before_execution)
    await asyncio.sleep(0.1) # sleep for 100ms to allow the mutation observer to detect changes
    unsubscribe(detect_dom_changes)
    await browser_manager.take_screenshots(f"{function_name}_end", page)
    await browser_manager.notify_user(result["summary_message"], message_type=MessageType.ACTION)

    # If the click revealed new elements, tell the LLM caller that further
    # interaction (with a fresh DOM fetch) is needed.
    if dom_changes_detected:
        return f"Success: {result['summary_message']}.\n As a consequence of this action, new elements have appeared in view: {dom_changes_detected}. This means that the action to click {selector} is not yet executed and needs further interaction. Get all_fields DOM to complete the interaction."
    return result["detailed_message"]
|
||||
|
||||
|
||||
async def do_click(page: Page, selector: str, wait_before_execution: float) -> dict[str, str]:
    """
    Executes the click action on the element with the given selector within the provided page.

    Strategy: wait (bounded) for the element to attach, best-effort scroll it
    into view and wait for visibility, special-case <option> elements (which
    must be selected via their parent <select>), then fall back to a
    JavaScript click rather than Playwright's native click (see inline note).

    Parameters:
    - page: The Playwright page instance.
    - selector: The query selector string to identify the element for the click action.
    - wait_before_execution: Optional wait time in seconds before executing the click event logic.

    Returns:
    dict[str,str] - Explanation of the outcome of this operation represented as a dictionary with 'summary_message' and 'detailed_message'.
    """
    logger.info(f"Executing ClickElement with \"{selector}\" as the selector. Wait time before execution: {wait_before_execution} seconds.")

    # Wait before execution if specified
    if wait_before_execution > 0:
        await asyncio.sleep(wait_before_execution)

    # Wait for the selector to be present and ensure it's attached and visible. If timeout, try javascript click
    try:
        logger.info(f"Executing ClickElement with \"{selector}\" as the selector. Waiting for the element to be attached and visible.")

        # Double timeout: Playwright's own 2000ms selector timeout, plus an
        # asyncio-level 2000ms guard around the whole await.
        element = await asyncio.wait_for(
            page.wait_for_selector(selector, state="attached", timeout=2000),
            timeout=2000
        )
        if element is None:
            raise ValueError(f"Element with selector: \"{selector}\" not found")

        logger.info(f"Element with selector: \"{selector}\" is attached. scrolling it into view if needed.")
        try:
            await element.scroll_into_view_if_needed(timeout=200)
            logger.info(f"Element with selector: \"{selector}\" is attached and scrolled into view. Waiting for the element to be visible.")
        except Exception:
            # If scrollIntoView fails, just move on, not a big deal
            pass

        try:
            await element.wait_for_element_state("visible", timeout=200)
            logger.info(f"Executing ClickElement with \"{selector}\" as the selector. Element is attached and visibe. Clicking the element.")
        except Exception:
            # If the element is not visible, try to click it anyway
            pass

        element_tag_name = await element.evaluate("element => element.tagName.toLowerCase()")
        element_outer_html = await get_element_outer_html(element, page, element_tag_name)

        # <option> cannot be clicked directly: select it via the parent
        # <select> element instead.
        if element_tag_name == "option":
            element_value = await element.get_attribute("value") # get the text that is in the value of the option
            parent_element = await element.evaluate_handle("element => element.parentNode")
            # await parent_element.evaluate(f"element => element.select_option(value=\"{element_value}\")")
            await parent_element.select_option(value=element_value) # type: ignore

            logger.info(f'Select menu option "{element_value}" selected')

            return {"summary_message": f'Select menu option "{element_value}" selected',
                    "detailed_message": f'Select menu option "{element_value}" selected. The select element\'s outer HTML is: {element_outer_html}.'}

        #Playwright click seems to fail more often than not, disabling it for now and just going with JS click
        #await perform_playwright_click(element, selector)
        msg = await perform_javascript_click(page, selector)
        return {"summary_message": msg, "detailed_message": f"{msg} The clicked element's outer HTML is: {element_outer_html}."} # type: ignore
    except Exception as e:
        # Any failure (timeout, detached element, JS error) is reported back
        # as a message rather than re-raised, so the LLM caller can recover.
        logger.error(f"Unable to click element with selector: \"{selector}\". Error: {e}")
        traceback.print_exc()
        msg = f"Unable to click element with selector: \"{selector}\" since the selector is invalid. Proceed by retrieving DOM again."
        return {"summary_message": msg, "detailed_message": f"{msg}. Error: {e}"}
|
||||
|
||||
|
||||
async def is_element_present(page: Page, selector: str) -> bool:
    """
    Report whether at least one element matching *selector* exists on the page.

    Parameters:
    - page: The Playwright page instance to query.
    - selector: The query selector string used to look up the element.

    Returns:
    - True if a matching element is present, False otherwise.
    """
    # query_selector returns None when nothing matches, so the truth test is enough.
    return await page.query_selector(selector) is not None
|
||||
|
||||
|
||||
async def perform_playwright_click(element: ElementHandle, selector: str):
    """
    Performs a click action on the element using Playwright's click method.

    Parameters:
    - element: The Playwright ElementHandle instance representing the element to be clicked.
    - selector: The query selector string of the element (used for logging only).

    Returns:
    - None

    Raises:
    - playwright TimeoutError if the element is not clickable within 200ms
      (force=False means Playwright still runs its actionability checks).
    """
    logger.info(f"Performing first Step: Playwright Click on element with selector: {selector}")
    # Short 200ms timeout: callers treat a failure here as a signal to fall back
    # to the JavaScript click path instead of waiting on Playwright retries.
    await element.click(force=False, timeout=200)
|
||||
|
||||
|
||||
async def perform_javascript_click(page: Page, selector: str):
    """
    Performs a click action on the element using JavaScript injected via page.evaluate.

    Handles three cases inside the injected script:
    - element not found: returns a "not found" message string;
    - <option> elements: selects the option on the parent <select> and fires a
      'change' event (a plain click would not update the select's value);
    - everything else: forces links to open in the same tab, clicks, and reports
      whether the click expanded an aria-expanded menu.

    Parameters:
    - page: The Playwright page instance.
    - selector: The query selector string of the element.

    Returns:
    - str: the message produced by the injected JavaScript on success.
    - None (implicitly) if page.evaluate raises; the error is logged and the
      traceback printed.  NOTE(review): callers that interpolate this return
      value into messages will render "None" in that case — consider returning
      an explicit error string; confirm caller expectations before changing.
    """
    js_code = """(selector) => {
        let element = document.querySelector(selector);

        if (!element) {
            console.log(`perform_javascript_click: Element with selector ${selector} not found`);
            return `perform_javascript_click: Element with selector ${selector} not found`;
        }

        if (element.tagName.toLowerCase() === "option") {
            let value = element.text;
            let parent = element.parentElement;

            parent.value = element.value; // Directly set the value if possible
            // Trigger change event if necessary
            let event = new Event('change', { bubbles: true });
            parent.dispatchEvent(event);

            console.log("Select menu option", value, "selected");
            return "Select menu option: "+ value+ " selected";
        }
        else {
            console.log("About to click selector", selector);
            // If the element is a link, make it open in the same tab
            if (element.tagName.toLowerCase() === "a") {
                element.target = "_self";
            }
            let ariaExpandedBeforeClick = element.getAttribute('aria-expanded');
            element.click();
            let ariaExpandedAfterClick = element.getAttribute('aria-expanded');
            if (ariaExpandedBeforeClick === 'false' && ariaExpandedAfterClick === 'true') {
                return "Executed JavaScript Click on element with selector: "+selector +". Very important: As a consequence a menu has appeared where you may need to make further selction. Very important: Get all_fields DOM to complete the action.";
            }
            return "Executed JavaScript Click on element with selector: "+selector;
        }
    }"""
    try:
        logger.info(f"Executing JavaScript click on element with selector: {selector}")
        result:str = await page.evaluate(js_code, selector)
        logger.debug(f"Executed JavaScript Click on element with selector: {selector}")
        return result
    except Exception as e:
        logger.error(f"Error executing JavaScript click on element with selector: {selector}. Error: {e}")
        traceback.print_exc()
|
||||
|
82
Agent_E/ae/core/skills/enter_text_and_click.py
Normal file
82
Agent_E/ae/core/skills/enter_text_and_click.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
from typing import Annotated
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.core.skills.click_using_selector import do_click
|
||||
from Agent_E.ae.core.skills.enter_text_using_selector import do_entertext
|
||||
from Agent_E.ae.core.skills.press_key_combination import do_press_key_combination
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
async def enter_text_and_click(
    text_selector: Annotated[str, "The properly formatted DOM selector query, for example [mmid='1234'], where the text will be entered. Use mmid attribute."],
    text_to_enter: Annotated[str, "The text that will be entered into the element specified by text_selector."],
    click_selector: Annotated[str, "The properly formatted DOM selector query, for example [mmid='1234'], for the element that will be clicked after text entry."],
    wait_before_click_execution: Annotated[float, "Optional wait time in seconds before executing the click.", float] = 0.0
) -> Annotated[str, "A message indicating success or failure of the text entry and click."]:
    """
    Enters text into an element and then clicks on another element.

    If the two selectors are identical, pressing Enter replaces the click
    (clicking the field you just typed into would be a no-op).

    Parameters:
    - text_selector: The selector for the element to enter text into. It should be a properly formatted DOM selector query, for example [mmid='1234']. Use the mmid attribute.
    - text_to_enter: The text to enter into the element specified by text_selector.
    - click_selector: The selector for the element to click. It should be a properly formatted DOM selector query, for example [mmid='1234'].
    - wait_before_click_execution: Optional wait time in seconds before executing the click action. Default is 0.0.

    Returns:
    - A message indicating the success or failure of the text entry and click.

    Raises:
    - ValueError: If no active page is found. The OpenURL command opens a new page.

    Example usage:
    ```
    await enter_text_and_click("[mmid='1234']", "Hello, World!", "[mmid='5678']", wait_before_click_execution=1.5)
    ```
    """
    logger.info(f"Entering text '{text_to_enter}' into element with selector '{text_selector}' and then clicking element with selector '{click_selector}'.")

    # Initialize PlaywrightManager and get the active browser page
    browser_manager = PlaywrightManager(browser_type='chromium', headless=False)
    page = await browser_manager.get_current_page()
    if page is None: # type: ignore
        logger.error("No active page found")
        raise ValueError('No active page found. OpenURL command opens a new page.')

    await browser_manager.highlight_element(text_selector, True)

    function_name = inspect.currentframe().f_code.co_name # type: ignore
    await browser_manager.take_screenshots(f"{function_name}_start", page)

    text_entry_result = await do_entertext(page, text_selector, text_to_enter, use_keyboard_fill=True)

    # do_entertext encodes success in its summary_message prefix; bail out early on failure.
    if not text_entry_result["summary_message"].startswith("Success"):
        await browser_manager.take_screenshots(f"{function_name}_end", page)
        return(f"Failed to enter text '{text_to_enter}' into element with selector '{text_selector}'. Check that the selctor is valid.")

    result = text_entry_result

    # If the text_selector is the same as the click_selector, press the Enter key instead of clicking.
    if text_selector == click_selector:
        do_press_key_combination_result = await do_press_key_combination(browser_manager, page, "Enter")
        if do_press_key_combination_result:
            result["detailed_message"] += f" Instead of click, pressed the Enter key successfully on element: \"{click_selector}\"."
            await browser_manager.notify_user(f"Pressed the Enter key successfully on element: \"{click_selector}\".", message_type=MessageType.ACTION)
        else:
            result["detailed_message"] += f" Clicking the same element after entering text in it, is of no value. Tried pressing the Enter key on element \"{click_selector}\" instead of click and failed."
            # BUGFIX: this message was missing the f-prefix, so the user saw the
            # literal text "{click_selector}" instead of the selector value.
            await browser_manager.notify_user(f"Failed to press the Enter key on element \"{click_selector}\".", message_type=MessageType.ACTION)
    else:
        await browser_manager.highlight_element(click_selector, True)

        do_click_result = await do_click(page, click_selector, wait_before_click_execution)
        result["detailed_message"] += f' {do_click_result["detailed_message"]}'

    await asyncio.sleep(0.1) # sleep for 100ms to allow the mutation observer to detect changes

    await browser_manager.take_screenshots(f"{function_name}_end", page)

    return result["detailed_message"]
|
263
Agent_E/ae/core/skills/enter_text_using_selector.py
Normal file
263
Agent_E/ae/core/skills/enter_text_using_selector.py
Normal file
|
@ -0,0 +1,263 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
from typing import List # noqa: UP035
|
||||
|
||||
from playwright.async_api import Page
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.core.skills.press_key_combination import press_key_combination
|
||||
from Agent_E.ae.utils.dom_helper import get_element_outer_html
|
||||
from Agent_E.ae.utils.dom_mutation_observer import subscribe
|
||||
from Agent_E.ae.utils.dom_mutation_observer import unsubscribe
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
@dataclass
class EnterTextEntry:
    """
    Represents an entry for text input.

    Attributes:
        query_selector (str): A valid DOM selector query. Use the mmid attribute.
        text (str): The text to enter in the element identified by the query_selector.
    """

    query_selector: str
    text: str

    def __getitem__(self, key: str) -> str:
        """Allow dict-style access to the two declared fields only."""
        if key not in ("query_selector", "text"):
            raise KeyError(f"{key} is not a valid key")
        return getattr(self, key)
|
||||
|
||||
|
||||
async def custom_fill_element(page: Page, selector: str, text_to_enter: str):
    """
    Sets the value of a DOM element to a specified text without triggering keyboard input events.

    This function directly sets the 'value' property of a DOM element identified by the given CSS selector,
    effectively changing its current value to the specified text. This approach bypasses the need for
    simulating keyboard typing, providing a more efficient and reliable way to fill in text fields,
    especially in automated testing scenarios where speed and accuracy are paramount.

    Args:
        page (Page): The Playwright Page object representing the browser tab in which the operation will be performed.
        selector (str): The CSS selector string used to locate the target DOM element. The function will apply the
                        text change to the first element that matches this selector.
        text_to_enter (str): The text value to be set in the target element. Existing content will be overwritten.
                             NOTE: the injected script trims leading/trailing whitespace before setting it.

    Raises:
        Exception: re-raises whatever page.evaluate raises, including the injected
        script's "Element not found" error when the selector matches nothing.

    Example:
        await custom_fill_element(page, '#username', 'test_user')

    Note:
        This function does not trigger input-related events (like 'input' or 'change'). If application logic
        relies on these events being fired, additional steps may be needed to simulate them.
    """
    selector = f"{selector}" # Ensures the selector is treated as a string
    try:
        result = await page.evaluate(
            """(inputParams) => {
                const selector = inputParams.selector;
                let text_to_enter = inputParams.text_to_enter;
                text_to_enter = text_to_enter.trim();
                const element = document.querySelector(selector);
                if (!element) {
                    throw new Error(`Element not found: ${selector}`);
                }
                element.value = text_to_enter;
                return `Value set for ${selector}`;
            }""",
            {"selector": selector, "text_to_enter": text_to_enter},
        )
        logger.debug(f"custom_fill_element result: {result}")
    except Exception as e:
        # Log with full context, then propagate so callers can report the failure.
        logger.error(f"Error in custom_fill_element, Selector: {selector}, Text: {text_to_enter}. Error: {str(e)}")
        raise
|
||||
|
||||
async def entertext(entry: Annotated[EnterTextEntry, "An object containing 'query_selector' (DOM selector query using mmid attribute e.g. [mmid='114']) and 'text' (text to enter on the element)."]) -> Annotated[str, "Explanation of the outcome of this operation."]:
    """
    Enters text into a DOM element identified by a CSS selector.

    This function enters the specified text into a DOM element identified by the given CSS selector.
    It uses the Playwright library to interact with the browser and perform the text entry operation.
    The function supports both direct setting of the 'value' property and simulating keyboard typing.

    Args:
        entry (EnterTextEntry): An object containing 'query_selector' (DOM selector query using mmid attribute)
                                and 'text' (text to enter on the element).

    Returns:
        str: Explanation of the outcome of this operation. When DOM mutations were
        observed during the entry, the message also lists the new elements and asks
        the caller to fetch the DOM again.

    Example:
        entry = EnterTextEntry(query_selector='#username', text='test_user')
        result = await entertext(entry)

    Note:
        - The 'query_selector' should be a valid CSS selector that uniquely identifies the target element.
        - The 'text' parameter specifies the text to be entered into the element.
        - The function uses the PlaywrightManager to manage the browser instance.
        - If no active page is found, an error message is returned.
        - The function internally calls the 'do_entertext' function to perform the text entry operation.
        - The 'do_entertext' function applies a pulsating border effect to the target element during the operation.
        - The 'use_keyboard_fill' parameter in 'do_entertext' determines whether to simulate keyboard typing or not.
        - If 'use_keyboard_fill' is set to True, the function uses the 'page.keyboard.type' method to enter the text.
        - If 'use_keyboard_fill' is set to False, the function uses the 'custom_fill_element' method to enter the text.
    """
    logger.info(f"Entering text: {entry}")
    query_selector: str = entry['query_selector']
    text_to_enter: str = entry['text']

    # Create and use the PlaywrightManager
    browser_manager = PlaywrightManager(browser_type='chromium', headless=False)
    page = await browser_manager.get_current_page()
    if page is None: # type: ignore
        return "Error: No active page found. OpenURL command opens a new page."

    function_name = inspect.currentframe().f_code.co_name # type: ignore

    await browser_manager.take_screenshots(f"{function_name}_start", page)

    await browser_manager.highlight_element(query_selector, True)

    # Capture DOM mutations that happen while the text is being entered, so the
    # caller can be told when new elements (e.g. autocomplete lists) appeared.
    dom_changes_detected=None
    def detect_dom_changes(changes:str): # type: ignore
        nonlocal dom_changes_detected
        dom_changes_detected = changes # type: ignore

    subscribe(detect_dom_changes)

    # Clear any existing value in the field before typing the new one.
    await page.evaluate(
        """
        (selector) => {
            const element = document.querySelector(selector);
            if (element) {
                element.value = '';
            } else {
                console.error('Element not found:', selector);
            }
        }
        """,
        query_selector,
    )

    result = await do_entertext(page, query_selector, text_to_enter)
    await asyncio.sleep(0.1) # sleep for 100ms to allow the mutation observer to detect changes
    unsubscribe(detect_dom_changes)

    await browser_manager.take_screenshots(f"{function_name}_end", page)

    await browser_manager.notify_user(result["summary_message"], message_type=MessageType.ACTION)
    if dom_changes_detected:
        return f"{result['detailed_message']}.\n As a consequence of this action, new elements have appeared in view: {dom_changes_detected}. This means that the action of entering text {text_to_enter} is not yet executed and needs further interaction. Get all_fields DOM to complete the interaction."
    return result["detailed_message"]
|
||||
|
||||
|
||||
async def do_entertext(page: Page, selector: str, text_to_enter: str, use_keyboard_fill: bool=True):
    """
    Performs the text entry operation on a DOM element.

    This function performs the text entry operation on a DOM element identified by the given CSS selector.
    It applies a pulsating border effect to the element during the operation for visual feedback.
    The function supports both direct setting of the 'value' property and simulating keyboard typing.

    Args:
        page (Page): The Playwright Page object representing the browser tab in which the operation will be performed.
        selector (str): The CSS selector string used to locate the target DOM element.
        text_to_enter (str): The text value to be set in the target element. Existing content will be overwritten.
        use_keyboard_fill (bool, optional): Determines whether to simulate keyboard typing or not.
                                            Defaults to True.

    Returns:
        dict[str, str]: Explanation of the outcome of this operation represented as a dictionary with 'summary_message' and 'detailed_message'.
        On failure, both messages start with "Error" (callers rely on the
        "Success" prefix to detect success).

    Example:
        result = await do_entertext(page, '#username', 'test_user')

    Note:
        - The 'use_keyboard_fill' parameter determines whether to simulate keyboard typing or not.
        - If 'use_keyboard_fill' is set to True, the function uses the 'page.keyboard.type' method to enter the text.
        - If 'use_keyboard_fill' is set to False, the function uses the 'custom_fill_element' method to enter the text.
    """
    try:

        logger.debug(f"Looking for selector {selector} to enter text: {text_to_enter}")

        elem = await page.query_selector(selector)

        if elem is None:
            error = f"Error: Selector {selector} not found. Unable to continue."
            return {"summary_message": error, "detailed_message": error}

        logger.info(f"Found selector {selector} to enter text")
        element_outer_html = await get_element_outer_html(elem, page)

        if use_keyboard_fill:
            # Select-all + Backspace clears the field before typing the new value.
            # NOTE(review): "Control+A" assumes a non-macOS modifier — confirm
            # behavior on macOS where select-all is Meta+A.
            await elem.focus()
            await asyncio.sleep(0.1)
            await press_key_combination("Control+A")
            await asyncio.sleep(0.1)
            await press_key_combination("Backspace")
            await asyncio.sleep(0.1)
            logger.debug(f"Focused element with selector {selector} to enter text")
            # Type with a 1ms per-keystroke delay so input events fire naturally.
            await page.keyboard.type(text_to_enter, delay=1)
        else:
            # Direct value assignment; does not fire keyboard input events.
            await custom_fill_element(page, selector, text_to_enter)
        await elem.focus()
        logger.info(f"Success. Text \"{text_to_enter}\" set successfully in the element with selector {selector}")
        success_msg = f"Success. Text \"{text_to_enter}\" set successfully in the element with selector {selector}"
        return {"summary_message": success_msg, "detailed_message": f"{success_msg} and outer HTML: {element_outer_html}."}

    except Exception as e:
        traceback.print_exc()
        error = f"Error entering text in selector {selector}."
        return {"summary_message": error, "detailed_message": f"{error} Error: {e}"}
|
||||
|
||||
|
||||
async def bulk_enter_text(
    entries: Annotated[List[dict[str, str]], "List of objects, each containing 'query_selector' and 'text'."] # noqa: UP006
) -> Annotated[List[dict[str, str]], "List of dictionaries, each containing 'query_selector' and the result of the operation."]: # noqa: UP006
    """
    Enters text into multiple DOM elements using a bulk operation.

    Each entry is a dict with a 'query_selector' and a 'text' key; the function
    delegates every single entry to 'entertext' and collects the outcomes in
    input order.

    Args:
        entries: List of objects, each containing 'query_selector' and 'text'.

    Returns:
        List of dictionaries, each containing 'query_selector' and the result of the operation.

    Example:
        entries = [
            {"query_selector": "#username", "text": "test_user"},
            {"query_selector": "#password", "text": "test_password"}
        ]
        results = await bulk_enter_text(entries)

    Note:
        - Each entry in the 'entries' list should be a dictionary with 'query_selector' and 'text' keys.
        - The result is a list of dictionaries, where each dictionary contains the 'query_selector' and the result of the operation.
    """
    logger.info("Executing bulk Enter Text Command")
    outcomes: List[dict[str, str]] = [] # noqa: UP006
    for item in entries:
        selector = item['query_selector']
        value = item['text']
        logger.info(f"Entering text: {value} in element with selector: {selector}")
        outcome = await entertext(EnterTextEntry(query_selector=selector, text=value))
        outcomes.append({"query_selector": selector, "result": outcome})

    return outcomes
|
115
Agent_E/ae/core/skills/get_dom_with_content_type.py
Normal file
115
Agent_E/ae/core/skills/get_dom_with_content_type.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
import os
|
||||
import time
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
|
||||
from playwright.async_api import Page
|
||||
|
||||
from Agent_E.ae.config import SOURCE_LOG_FOLDER_PATH
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.dom_helper import wait_for_non_loading_dom_state
|
||||
from Agent_E.ae.utils.get_detailed_accessibility_tree import do_get_accessibility_info
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
async def get_dom_with_content_type(
    content_type: Annotated[str, "The type of content to extract: 'text_only': Extracts the innerText of the highest element in the document and responds with text, or 'input_fields': Extracts the text input and button elements in the dom."]
) -> Annotated[dict[str, Any] | str | None, "The output based on the specified content type."]:
    """
    Retrieves and processes the DOM of the active page in a browser instance based on the specified content type.

    Parameters
    ----------
    content_type : str
        The type of content to extract. Possible values are:
        - 'text_only': Extracts the innerText of the highest element in the document and responds with text.
        - 'input_fields': Extracts the text input and button elements in the DOM and responds with a JSON object.
        - 'all_fields': Extracts all the fields in the DOM and responds with a JSON object.

    Returns
    -------
    dict[str, Any] | str | None
        The processed content based on the specified content type. This could be:
        - A JSON object for 'input_fields' with just inputs.
        - Plain text for 'text_only'.
        - A minified DOM represented as a JSON object for 'all_fields'.

    Raises
    ------
    ValueError
        If an unsupported content_type is provided, or if no active page is found.
    """

    logger.info(f"Executing Get DOM Command based on content_type: {content_type}")
    start_time = time.time()
    # Create and use the PlaywrightManager
    browser_manager = PlaywrightManager(browser_type='chromium', headless=False)
    page = await browser_manager.get_current_page()
    if page is None: # type: ignore
        raise ValueError('No active page found. OpenURL command opens a new page.')

    extracted_data = None
    await wait_for_non_loading_dom_state(page, 2000) # wait for the DOM to be ready, non loading means external resources do not need to be loaded
    user_success_message = ""
    if content_type == 'all_fields':
        user_success_message = "Fetched all the fields in the DOM"
        extracted_data = await do_get_accessibility_info(page, only_input_fields=False)
    elif content_type == 'input_fields':
        logger.debug('Fetching DOM for input_fields')
        extracted_data = await do_get_accessibility_info(page, only_input_fields=True)
        # Input-only extraction can come back empty; tell the caller to retry broader.
        if extracted_data is None:
            return "Could not fetch input fields. Please consider trying with content_type all_fields."
        user_success_message = "Fetched only input fields in the DOM"
    elif content_type == 'text_only':
        # Extract text from the body or the highest-level element
        logger.debug('Fetching DOM for text_only')
        text_content = await get_filtered_text_content(page)
        # Persist a copy of the extracted text for offline debugging/inspection.
        with open(os.path.join(SOURCE_LOG_FOLDER_PATH, 'text_only_dom.txt'), 'w', encoding='utf-8') as f:
            f.write(text_content)
        extracted_data = text_content
        user_success_message = "Fetched the text content of the DOM"
    else:
        raise ValueError(f"Unsupported content_type: {content_type}")

    elapsed_time = time.time() - start_time
    logger.info(f"Get DOM Command executed in {elapsed_time} seconds")
    await browser_manager.notify_user(user_success_message, message_type=MessageType.ACTION)
    return extracted_data # type: ignore
|
||||
|
||||
|
||||
async def get_filtered_text_content(page: Page) -> str:
    """
    Return the page's visible text plus all image alt texts.

    Elements matching the filter selectors (currently the agent's own overlay,
    '#agente-overlay') are temporarily hidden while the text is read so their
    content is excluded, then restored to their original visibility.
    """
    # The whole extraction runs in one injected script so hide/read/restore
    # happen atomically from the page's point of view.
    return await page.evaluate("""
        () => {
            // Array of query selectors to filter out
            const selectorsToFilter = ['#agente-overlay'];

            // Store the original visibility values to revert later
            const originalStyles = [];

            // Hide the elements matching the query selectors
            selectorsToFilter.forEach(selector => {
                const elements = document.querySelectorAll(selector);
                elements.forEach(element => {
                    originalStyles.push({ element: element, originalStyle: element.style.visibility });
                    element.style.visibility = 'hidden';
                });
            });

            // Get the text content of the page
            let textContent = document?.body?.innerText || document?.documentElement?.innerText || "";

            // Get all the alt text from images on the page
            let altTexts = Array.from(document.querySelectorAll('img')).map(img => img.alt);
            altTexts="Other Alt Texts in the page: " + altTexts.join(' ');

            // Revert the visibility changes
            originalStyles.forEach(entry => {
                entry.element.style.visibility = entry.originalStyle;
            });
            textContent=textContent+" "+altTexts;
            return textContent;
        }
    """)
|
||||
|
40
Agent_E/ae/core/skills/get_url.py
Normal file
40
Agent_E/ae/core/skills/get_url.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from typing import Annotated
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
|
||||
|
||||
async def geturl() -> Annotated[str, "Returns the full URL of the current active web site/page."]:
    """
    Returns the full URL of the current page.

    Waits for 'domcontentloaded' on the active page, then reports the URL
    (truncated past 250 characters) together with the page title. When the
    title cannot be read, the URL alone is reported. Any failure to obtain
    an active page is surfaced as a ValueError.

    Returns:
    - Full URL the browser's active page.
    """
    try:
        # Create and use the PlaywrightManager
        manager = PlaywrightManager(browser_type='chromium', headless=False)
        active_page = await manager.get_current_page()

        if not active_page:
            raise ValueError('No active page found. OpenURL command opens a new page.')

        await active_page.wait_for_load_state("domcontentloaded")

        # Read URL and title; fall back to URL-only if the title lookup fails.
        try:
            page_title = await active_page.title()
            shown_url = active_page.url
            if len(shown_url) > 250:
                shown_url = shown_url[:250] + "..."
            return f"Current Page: {shown_url}, Title: {page_title}" # type: ignore
        except: # noqa: E722
            return f"Current Page: {active_page.url}"

    except Exception as e:
        raise ValueError('No active page found. OpenURL command opens a new page.') from e
|
||||
|
26
Agent_E/ae/core/skills/get_user_input.py
Normal file
26
Agent_E/ae/core/skills/get_user_input.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from typing import Annotated
|
||||
from typing import List # noqa: UP035
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.cli_helper import answer_questions_over_cli
|
||||
|
||||
|
||||
async def get_user_input(questions: Annotated[List[str], "List of questions to ask the user each one represented as a string"] ) -> dict[str, str]: # noqa: UP006
    """
    Asks the user a list of questions and returns the answers in a dictionary.

    When the browser has a UI overlay, each question is prompted in the browser;
    otherwise the questions are asked over the command line.

    Parameters:
    - questions: A list of questions to ask the user ["What is Username?", "What is your password?"].

    Returns:
    - Dictionary mapping each question to the user's answer.
    """
    manager = PlaywrightManager(browser_type='chromium', headless=False)
    # No UI overlay available: fall back to the CLI prompt helper.
    if not manager.ui_manager:
        return await answer_questions_over_cli(questions)

    replies: dict[str, str] = {}
    for question in questions:
        replies[question] = await manager.prompt_user(f"Question: {question}")
    return replies
|
70
Agent_E/ae/core/skills/open_url.py
Normal file
70
Agent_E/ae/core/skills/open_url.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
import inspect
|
||||
from typing import Annotated
|
||||
|
||||
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
async def openurl(url: Annotated[str, "The URL to navigate to. Value must include the protocol (http:// or https://)."],
                  timeout: Annotated[int, "Additional wait time in seconds after initial load."] = 3) -> Annotated[str, "Returns the result of this request in text form"]:
    """
    Opens a specified URL in the active browser instance. Waits for an initial load event, then waits for either
    the 'domcontentloaded' event or a configurable timeout, whichever comes first.

    Parameters:
    - url: The URL to navigate to. A missing protocol is fixed up to https:// via ensure_protocol.
    - timeout: Additional time in seconds to wait after the initial load before considering the navigation successful.

    Returns:
    - "Page loaded: <url>, Title: <title>" (or "Page already loaded: ..." when the
      page is already at the target URL). Navigation timeouts are tolerated and
      the final page state is reported regardless.
    """
    logger.info(f"Opening URL: {url}")
    browser_manager = PlaywrightManager(browser_type='chromium', headless=False)
    await browser_manager.get_browser_context()
    page = await browser_manager.get_current_page()
    try:
        url = ensure_protocol(url)
        if page.url == url:
            logger.info(f"Current page URL is the same as the new URL: {url}. No need to refresh.")
            title = await page.title()
            return f"Page already loaded: {url}, Title: {title}" # type: ignore

        # Navigate to the URL with a short timeout to ensure the initial load starts
        function_name = inspect.currentframe().f_code.co_name # type: ignore

        await browser_manager.take_screenshots(f"{function_name}_start", page)

        await page.goto(url, timeout=timeout*1000) # type: ignore
    except PlaywrightTimeoutError as pte:
        # NOTE(review): logger.warn is a deprecated alias of logger.warning.
        logger.warn(f"Initial navigation to {url} failed: {pte}. Will try to continue anyway.") # happens more often than not, but does not seem to be a problem
    except Exception as e:
        logger.error(f"An error occurred while opening the URL: {url}. Error: {e}")
        import traceback
        traceback.print_exc()

    # NOTE(review): if an exception fired before function_name was assigned,
    # this line would raise NameError — confirm that path is unreachable.
    await browser_manager.take_screenshots(f"{function_name}_end", page)

    await browser_manager.notify_user(f"Opened URL: {url}", message_type=MessageType.ACTION)
    # Get the page title
    title = await page.title()
    url=page.url
    return f"Page loaded: {url}, Title: {title}" # type: ignore
|
||||
|
||||
def ensure_protocol(url: str) -> str:
    """
    Ensures that a URL has a protocol (http:// or https://). If it doesn't have one,
    https:// is added by default.

    Parameters:
    - url: The URL to check and modify if necessary.

    Returns:
    - A URL string with a protocol.
    """
    if not url.startswith(('http://', 'https://')):
        url = 'https://' + url  # Default to https if no protocol is specified
        logger.info(f"Added 'https://' protocol to URL because it was missing. New URL is: {url}")
    return url
|
88
Agent_E/ae/core/skills/pdf_text_extractor.py
Normal file
88
Agent_E/ae/core/skills/pdf_text_extractor.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import os
|
||||
from typing import Annotated
|
||||
|
||||
import httpx
|
||||
import pdfplumber
|
||||
|
||||
from Agent_E.ae.config import PROJECT_TEMP_PATH
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
async def extract_text_from_pdf(pdf_url: Annotated[str, "The URL of the PDF file to extract text from."]) -> Annotated[str, "All the text found in the PDF file."]:
    """
    Extract all text from a PDF file located at a URL.

    The PDF is downloaded to a fixed temp-file path, parsed page-by-page with
    pdfplumber, and the temp file is always removed afterwards.

    pdf_url: str - The URL of the PDF file to extract text from.
    returns: str - All the text found in the PDF, or an error message string on failure.
    """
    file_path = os.path.join(PROJECT_TEMP_PATH, "downloaded_file.pdf")  # fixed file path for downloading the PDF

    try:
        # Create and use the PlaywrightManager (only used here to surface the
        # success notification to the user).
        browser_manager = PlaywrightManager(browser_type='chromium', headless=False)

        # Download the PDF
        download_result = await download_pdf(pdf_url, file_path)
        # NOTE(review): download_pdf raises on failure rather than returning an
        # error string, so this guard is likely dead code — confirm before removing.
        if not os.path.exists(download_result):
            return download_result  # Return error message if download failed

        # Open the PDF using pdfplumber and extract text
        text = ""
        with pdfplumber.open(download_result) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                # extract_text() can return None for pages with no text layer
                if page_text:
                    text += page_text + "\n"
        extracted_text = text.strip()
        word_count = len(extracted_text.split())
        await browser_manager.notify_user(f"Extracted text from the PDF successfully. Found {word_count} words.", message_type=MessageType.ACTION)
        return "Text found in the PDF:\n" + extracted_text
    except httpx.HTTPStatusError as e:
        # Non-2xx response while downloading (propagated from download_pdf).
        logger.error(f"An error occurred while downloading the PDF from {pdf_url}: {str(e)}")
        return f"An error occurred while downloading the PDF: {str(e)}"
    except Exception as e:
        logger.error(f"An error occurred while extracting text from the PDF that was downloaded from {pdf_url}: {str(e)}")
        return f"An error occurred while extracting text: {str(e)}"
    finally:
        # Cleanup: Ensure the downloaded file is removed
        cleanup_temp_files(file_path)
|
||||
|
||||
def cleanup_temp_files(*file_paths: str) -> None:
    """
    Best-effort removal of temporary files.

    Each path in *file_paths* is deleted if it exists. Failures are logged
    instead of raised, so cleanup never interrupts the caller.

    *file_paths: str - One or more file paths to be removed.
    """
    for path in file_paths:
        if not os.path.exists(path):
            logger.debug(f"File not found. Unable to clean it from the filesystem: {path}")
            continue
        try:
            os.remove(path)
        except Exception as err:
            logger.error(f"Failed to remove {path}: {str(err)}")
        else:
            logger.debug(f"Cleaned file from the filesystem: {path}")
|
||||
|
||||
async def download_pdf(pdf_url: str, file_path: str) -> str:
    """
    Download the PDF file from the given URL and save it to the specified path.

    pdf_url: str - The URL of the PDF file to download.
    file_path: str - The local path to save the downloaded PDF.

    returns: str - The file path of the downloaded PDF if successful.
    raises: Exception - Any download or filesystem error is propagated to the caller.
    """
    try:
        logger.info(f"Downloading PDF from: {pdf_url} to: {file_path}")
        async with httpx.AsyncClient() as http_client:
            resp = await http_client.get(pdf_url)
            # Surface non-2xx responses as httpx.HTTPStatusError.
            resp.raise_for_status()
            with open(file_path, 'wb') as out_file:
                out_file.write(resp.content)
            return file_path
    except Exception as download_error:
        raise download_error
|
111
Agent_E/ae/core/skills/press_key_combination.py
Normal file
111
Agent_E/ae/core/skills/press_key_combination.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
from typing import Annotated
|
||||
|
||||
from playwright.async_api import Page # type: ignore
|
||||
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.dom_mutation_observer import subscribe # type: ignore
|
||||
from Agent_E.ae.utils.dom_mutation_observer import unsubscribe # type: ignore
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
async def press_key_combination(key_combination: Annotated[str, "The key to press, e.g., Enter, PageDown etc"]) -> str:
    """
    Presses a key combination on the current active page managed by PlaywrightManager.

    This function simulates the pressing of a key or a combination of keys on the current active web page.
    The `key_combination` should be a string that represents the keys to be pressed, separated by '+' if it's a combination.
    For example, 'Control+C' to copy or 'Alt+F4' to close a window on Windows.

    Parameters:
    - key_combination (Annotated[str, "The key combination to press, e.g., 'Control+C'."]): The key combination to press, represented as a string. For combinations, use '+' as a separator.

    Raises:
    - ValueError: If no active page is found.

    Returns:
    str: status of the operation expressed as a string
    """

    logger.info(f"Executing press_key_combination with key combo: {key_combination}")
    # Create and use the PlaywrightManager
    browser_manager = PlaywrightManager()
    page = await browser_manager.get_current_page()

    if page is None: # type: ignore
        raise ValueError('No active page found. OpenURL command opens a new page.')

    # Split the key combination if it's a combination of keys
    # (a single key like 'Enter' yields a one-element list, so the modifier
    # loops below are skipped and only the final press happens).
    keys = key_combination.split('+')

    # Captured by the mutation-observer callback below; stays None if the page
    # DOM did not change as a result of the key press.
    dom_changes_detected=None
    def detect_dom_changes(changes:str): # type: ignore
        nonlocal dom_changes_detected
        dom_changes_detected = changes # type: ignore

    # Watch for DOM mutations only for the duration of the key press.
    subscribe(detect_dom_changes)
    # If it's a combination, hold down the modifier keys
    for key in keys[:-1]: # All keys except the last one are considered modifier keys
        await page.keyboard.down(key)

    # Press the last key in the combination
    await page.keyboard.press(keys[-1])

    # Release the modifier keys
    for key in keys[:-1]:
        await page.keyboard.up(key)
    await asyncio.sleep(0.1) # sleep for 100ms to allow the mutation observer to detect changes
    unsubscribe(detect_dom_changes)

    if dom_changes_detected:
        return f"Key {key_combination} executed successfully.\n As a consequence of this action, new elements have appeared in view:{dom_changes_detected}. This means that the action is not yet executed and needs further interaction. Get all_fields DOM to complete the interaction."

    await browser_manager.notify_user(f"Key {key_combination} executed successfully", message_type=MessageType.ACTION)
    return f"Key {key_combination} executed successfully"
|
||||
|
||||
|
||||
async def do_press_key_combination(browser_manager: PlaywrightManager, page: Page, key_combination: str) -> bool:
    """
    Presses a key combination on the provided page.

    This function simulates the pressing of a key or a combination of keys on a web page.
    The `key_combination` should be a string that represents the keys to be pressed, separated by '+' if it's a combination.
    For example, 'Control+C' to copy or 'Alt+F4' to close a window on Windows.

    Parameters:
    - browser_manager (PlaywrightManager): The PlaywrightManager instance.
    - page (Page): The Playwright page instance.
    - key_combination (str): The key combination to press, represented as a string. For combinations, use '+' as a separator.

    Returns:
    bool: True if success and False if failed
    """

    logger.info(f"Executing press_key_combination with key combo: {key_combination}")
    try:
        function_name = inspect.currentframe().f_code.co_name # type: ignore
        await browser_manager.take_screenshots(f"{function_name}_start", page)
        # Split the key combination if it's a combination of keys
        # (single keys yield a one-element list, so the modifier loops are no-ops).
        keys = key_combination.split('+')

        # If it's a combination, hold down the modifier keys
        for key in keys[:-1]: # All keys except the last one are considered modifier keys
            await page.keyboard.down(key)

        # Press the last key in the combination
        await page.keyboard.press(keys[-1])

        # Release the modifier keys
        for key in keys[:-1]:
            await page.keyboard.up(key)

    except Exception as e:
        # Any keyboard/screenshot failure is reported as a simple boolean;
        # the "_end" screenshot below is intentionally skipped on failure.
        logger.error(f"Error executing press_key_combination \"{key_combination}\": {e}")
        return False

    await browser_manager.take_screenshots(f"{function_name}_end", page)

    return True
|
||||
|
29
Agent_E/ae/core/skills/skill_registry.py
Normal file
29
Agent_E/ae/core/skills/skill_registry.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# skill_registry.py
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
# Define the type of the functions that will be registered as skills
|
||||
SkillType = Callable[..., Any]
|
||||
|
||||
# Global registry to store private skill functions and their metadata
|
||||
skill_registry: list[dict[str, Any]] = []
|
||||
|
||||
def skill(description: str, name: str|None = None) -> Callable[[SkillType], SkillType]:
|
||||
"""
|
||||
Decorator for registering private skills.
|
||||
|
||||
Parameters:
|
||||
- description: A string describing the skill's function.
|
||||
- name: Optional name to register the skill with. If not provided, the function's name will be used.
|
||||
|
||||
Returns:
|
||||
- A decorator function that registers the skill in the global registry.
|
||||
"""
|
||||
def decorator(func: SkillType) -> SkillType:
|
||||
skill_registry.append({
|
||||
"name": name if name else func.__name__, # Use provided name or fallback to function name
|
||||
"func": func,
|
||||
"description": description
|
||||
})
|
||||
return func
|
||||
return decorator
|
227
Agent_E/ae/core/system_orchestrator.py
Normal file
227
Agent_E/ae/core/system_orchestrator.py
Normal file
|
@ -0,0 +1,227 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import Agent_E.ae.core.playwright_manager as browserManager
|
||||
from Agent_E.ae.config import SOURCE_LOG_FOLDER_PATH
|
||||
from Agent_E.ae.core.agents_llm_config import AgentsLLMConfig
|
||||
from Agent_E.ae.core.autogen_wrapper import AutogenWrapper
|
||||
from Agent_E.ae.utils.cli_helper import async_input # type: ignore
|
||||
from Agent_E.ae.utils.formatting_helper import str_to_bool
|
||||
from Agent_E.ae.utils.http_helper import make_post_request
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
class SystemOrchestrator:
    """
    Orchestrates the system's operation, handling input from both a command prompt and a web interface,
    and coordinating between the Autogen wrapper and the Playwright manager.

    Attributes:
        agent_scenario (str): The agent scenario to use for command processing. Defaults to "user_proxy,browser_nav_agent".
        input_mode (str): The input mode of the system, determining whether command prompt input is enabled. Defaults to "GUI_ONLY".
        browser_manager (PlaywrightManager): The Playwright manager instance for web interaction.
        autogen_wrapper (AutogenWrapper): The Autogen wrapper instance for agent-based command processing.
        is_running (bool): Flag indicating whether the system is currently processing a command.
        shutdown_event (asyncio.Event): Event to wait for an exit command to be processed.
    """

    def __init__(self, agent_scenario:str="user,planner_agent,browser_nav_agent,browser_nav_executor", input_mode:str="GUI_ONLY",
                 planner_max_chat_round: int = 50, browser_nav_max_chat_round: int = 10):
        """
        Initializes the system orchestrator with the specified agent scenario and input mode.

        Args:
            agent_scenario (str, optional): The agent scenario to use for command processing. Defaults to "user_proxy,browser_nav_agent".
            input_mode (str, optional): The input mode of the system. Defaults to "GUI_ONLY".
            planner_max_chat_round (int, optional): The maximum number of chat rounds for the planner. Defaults to 50.
            browser_nav_max_chat_round (int, optional): The maximum number of chat rounds for the browser navigation agent. Defaults to 10.
        """
        load_dotenv()
        self.planner_number_of_rounds = planner_max_chat_round
        self.browser_number_of_rounds = browser_nav_max_chat_round

        self.agent_scenario = agent_scenario
        self.input_mode = input_mode
        self.browser_manager = None
        self.autogen_wrapper = None
        self.is_running = False

        # NOTE(review): when SAVE_CHAT_LOGS_TO_FILE is unset, os.getenv returns the
        # non-string default True (a bool) — assumes str_to_bool accepts it; confirm.
        self.save_chat_logs_to_files = str_to_bool(os.getenv('SAVE_CHAT_LOGS_TO_FILE', True))

        # The external orchestrator service is enabled only when both its API key
        # and gateway URL are configured in the environment.
        if os.getenv('ORCHESTRATOR_API_KEY', None) is not None and os.getenv('ORCHESTRATOR_GATEWAY', None) is not None:
            self.__populate_orchestrator_info()
            logger.info(f"Orchestrator endpoint: {self.orchestrator_endpoint}")
        else:
            self.use_orchestrator = False

        self.__parse_user_and_browser_agent_names()
        self.shutdown_event = asyncio.Event() #waits for an exit command to be processed


    def __populate_orchestrator_info(self):
        """
        Populates the orchestrator information by retrieving the API key, gateway, and endpoint from environment variables.
        """
        self.orchestrator_api_key = os.getenv('ORCHESTRATOR_API_KEY')
        self.orchestrator_gateway = os.getenv('ORCHESTRATOR_GATEWAY')
        self.orchestrator_endpoint = f"{self.orchestrator_gateway}/api/orchestrate"
        self.use_orchestrator = True


    def __parse_user_and_browser_agent_names(self):
        """
        Parse the user and browser agent names from agent_scenario
        """
        self.agent_names = self.agent_scenario.split(',')
        for agent_name in self.agent_names:
            if 'user' in agent_name:
                # NOTE(review): 'ser_agent_name' looks like a typo for 'user_agent_name';
                # left unchanged because unseen code may reference this attribute — confirm.
                self.ser_agent_name = agent_name
            elif 'planner' in agent_name:
                self.planner_agent_name = agent_name
            elif 'browser' in agent_name:
                self.browser_agent_name = agent_name

    async def initialize(self):
        """
        Initializes the components required for the system's operation, including the Autogen wrapper and the Playwright manager.
        """
        # Load the configuration using AgentsLLMConfig
        llm_config = AgentsLLMConfig()

        # Retrieve planner agent and browser nav agent configurations
        self.planner_agent_config = llm_config.get_planner_agent_config()
        self.browser_nav_agent_config = llm_config.get_browser_nav_agent_config()

        self.autogen_wrapper = await AutogenWrapper.create(self.planner_agent_config, self.browser_nav_agent_config, agents_needed=self.agent_names,
                                                           save_chat_logs_to_files=self.save_chat_logs_to_files,
                                                           planner_max_chat_round=self.planner_number_of_rounds, browser_nav_max_chat_round=self.browser_number_of_rounds)

        self.browser_manager = browserManager.PlaywrightManager(gui_input_mode=self.input_mode == "GUI_ONLY")
        await self.browser_manager.async_initialize()

        if self.input_mode == "GUI_ONLY":
            # Expose a browser-side hook so the web overlay can submit commands.
            browser_context = await self.browser_manager.get_browser_context()
            await browser_context.expose_function('process_task', self.receive_command) # type: ignore

    async def start(self):
        """
        Starts the system orchestrator, initializing components and starting the command prompt loop if necessary.
        """
        await self.initialize()

        if self.input_mode != "GUI_ONLY":
            await self.command_prompt_loop()

        await self.wait_for_exit()

    async def command_prompt_loop(self):
        """
        Continuously reads and processes commands from the command prompt until an 'exit' command is received.
        """
        while not self.is_running:
            command: str = await async_input("Enter your command (or type 'exit' to quit): ") # type: ignore
            await self.process_command(command) # type: ignore

    async def receive_command(self, command: str):
        """
        Callback function to process commands received from the web interface.

        Args:
            command (str): The command received from the web interface.
        """
        await self.process_command(command)

    async def __orchestrate_command(self, command: str):
        # Pass the raw command through unchanged when no external orchestrator is configured.
        if not self.use_orchestrator:
            return command

        orch_response = make_post_request(self.orchestrator_endpoint, {"query": command}, self.orchestrator_api_key, api_key_header_name="X-API-Key") # type: ignore

        if not orch_response:
            return command

        if "user_notification" in orch_response:
            await self.browser_manager.notify_user(orch_response["user_notification"]) # type: ignore
        if "is_terminating" in orch_response and orch_response["is_terminating"]:
            logger.info("Orchestrator indicated command execution completed.")
            return None
        if "reformulated_query" in orch_response:
            logger.info(f"Orchestrator reformulated command to: {orch_response['reformulated_query']}")
            return orch_response["reformulated_query"]
        # NOTE(review): falls through and implicitly returns None when the response has
        # neither 'is_terminating' nor 'reformulated_query'; the caller then runs the
        # original command — confirm this fallback is intended.


    async def process_command(self, command: str):
        """
        Processes a given command, coordinating with the Autogen wrapper for execution and handling special commands like 'exit'.

        Args:
            command (str): The command to process.
        """
        logger.info(f"Received command: {command}")
        if command.lower() == 'exit':
            await self.shutdown()
            return

        if command:
            self.is_running = True
            start_time = time.time()
            current_url = await self.browser_manager.get_current_url() if self.browser_manager else None
            self.browser_manager.ui_manager.clear_conversation_history() # type: ignore
            self.browser_manager.log_user_message(command) # type: ignore
            result = None
            logger.info(f"Processing command: {command}")
            if self.autogen_wrapper:
                await self.browser_manager.update_processing_state("processing") # type: ignore
                orchestrated_command = await self.__orchestrate_command(command)
                # NOTE(review): a None orchestrated command (orchestrator said
                # "terminating") still executes the original command below — confirm.
                if orchestrated_command is not None:
                    result = await self.autogen_wrapper.process_command(orchestrated_command, current_url)
                else:
                    result = await self.autogen_wrapper.process_command(command, current_url)

            await self.browser_manager.update_processing_state("done") # type: ignore
            end_time = time.time()
            elapsed_time = round(end_time - start_time, 2)
            logger.info(f"Command \"{command}\" took: {elapsed_time} seconds.")
            await self.save_planner_chat_messages()
            if result is not None:
                chat_history= result.chat_history # type: ignore
                last_message = chat_history[-1] if chat_history else None # type: ignore
                # Surface the final answer to the user when the planner signalled termination.
                if last_message and "terminate" in last_message and last_message["terminate"]=="yes":
                    await self.browser_manager.notify_user(last_message, "answer") # type: ignore

            await self.browser_manager.notify_user(f"Task Completed ({elapsed_time}s).", "info") # type: ignore
            await self.browser_manager.command_completed(command, elapsed_time) # type: ignore
            self.is_running = False

    async def save_planner_chat_messages(self):
        """
        Saves the chat messages from the Autogen wrapper's agents to a JSON file.
        """

        messages = self.autogen_wrapper.agents_map[self.planner_agent_name].chat_messages # type: ignore
        # JSON object keys must be strings; chat_messages is keyed by agent objects.
        messages_str_keys = {str(key): value for key, value in messages.items()} # type: ignore
        if self.save_chat_logs_to_files:
            with open(os.path.join(SOURCE_LOG_FOLDER_PATH, 'chat_messages.json'), 'w', encoding='utf-8') as f:
                json.dump(messages_str_keys, f, ensure_ascii=False, indent=4)
            logger.debug("Chat messages saved")
        else:
            logger.info("Planner chat log: ", extra={"planner_chat_log": messages_str_keys}) # type: ignore

    async def wait_for_exit(self):
        """
        Waits for an exit command to be processed, keeping the system active in the meantime.
        """
        await self.shutdown_event.wait() # Wait until the shutdown event is set

    async def shutdown(self):
        """
        Shuts down the system orchestrator, stopping the Playwright manager and exiting the command prompt loop.
        """
        logger.info("Shutting down System Orchestrator...")
        if self.browser_manager:
            await self.browser_manager.stop_playwright()
        self.shutdown_event.set() # Signal the shutdown event to stop waiting in wait_for_exit
|
221
Agent_E/ae/core/ui_manager.py
Normal file
221
Agent_E/ae/core/ui_manager.py
Normal file
|
@ -0,0 +1,221 @@
|
|||
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from playwright.async_api import Frame
|
||||
from playwright.async_api import Page
|
||||
|
||||
from Agent_E.ae.config import PROJECT_SOURCE_ROOT
|
||||
from Agent_E.ae.utils.js_helper import escape_js_message
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
|
||||
class UIManager:
    """
    Manages the UI overlay for this application. The application uses playwright for the browser driver.
    This includes handling navigation events, showing or hiding overlays, and maintaining
    a conversation history within the UI overlay.

    Attributes:
        overlay_is_collapsed (bool): Indicates if the overlay is currently collapsed.
        conversation_history (list[dict[str, str]]): The chat history between user and system. Each entry contains 'from' and 'message' keys.
        __update_overlay_chat_history_running (bool): A flag to prevent concurrent updates to the chat history.
    """

    overlay_is_collapsed: bool = True

    overlay_processing_state: str = "init" #init: initialised, processing: processing is ongoing, done: processing is done
    overlay_show_details:bool = True

    conversation_history:list[dict[str, str]] = []
    __update_overlay_chat_history_running: bool = False


    def __init__(self):
        """
        Initializes the UIManager instance by adding default system messages to the conversation history.
        """
        # Give each instance its own history list. Relying on the class-level
        # default would share (and mutate) one list across all UIManager instances.
        self.conversation_history = []
        self.add_default_system_messages()


    async def handle_navigation(self, frame: Frame):
        """
        Handles navigation events by injecting JavaScript code into the frame to manage the overlay state
        and updating the overlay chat history.

        Args:
            frame (Frame): The Playwright Frame object to inject JavaScript into and manage.
        """
        try:
            await frame.wait_for_load_state("load")
            overlay_injection_file = os.path.join(PROJECT_SOURCE_ROOT, "ui", "injectOverlay.js")
            with open(overlay_injection_file, 'r') as file: # noqa: UP015
                js_code = file.read()

            # Inject the JavaScript code into the page
            await frame.evaluate(js_code)
            js_bool = str(self.overlay_show_details).lower()
            # Re-render the overlay in whatever state it was in before navigation.
            if self.overlay_is_collapsed:
                await frame.evaluate(f"showCollapsedOverlay('{self.overlay_processing_state}', {js_bool});")
            else:
                await frame.evaluate(f"showExpandedOverlay('{self.overlay_processing_state}', {js_bool});")

            #update chat history in the overlay
            await self.update_overlay_chat_history(frame)

        except Exception as e:
            # Frames routinely detach during navigation; that case is benign.
            if "Frame was detached" not in str(e):
                raise e


    async def show_overlay(self, page: Page):
        """
        Displays the overlay in an expanded state on the given page if it's currently collapsed.

        Args:
            page (Page): The Playwright Page object on which to show the overlay.
        """
        if not self.overlay_is_collapsed:
            logger.debug("Overlay is already expanded, ignoring show_overlay call")
            return
        await page.evaluate("showExpandedOverlay();")
        # BUGFIX: the overlay was just expanded, so it is no longer collapsed.
        # (Previously this was set to True, leaving the state flag wrong.)
        self.overlay_is_collapsed = False


    def update_overlay_state(self, is_collapsed: bool):
        """
        Updates the state of the overlay to either collapsed or expanded.

        Args:
            is_collapsed (bool): True to collapse the overlay, False to expand it.
        """
        self.overlay_is_collapsed = is_collapsed


    async def update_overlay_show_details(self, show_details: bool, page: Page):
        """
        Updates the state of the overlay to either show steps or not.

        Args:
            show_details (bool): True to show steps, False to hide them.
        """
        self.overlay_show_details = show_details
        await self.update_overlay_chat_history(page)


    async def update_processing_state(self, state: str, page: Page):
        """
        Updates the processing state of the overlay.

        Args:
            state (str): The processing state to update.
        """
        self.overlay_processing_state = state
        try:
            js_bool = str(self.overlay_is_collapsed).lower()
            await page.evaluate(f"updateOverlayState('{self.overlay_processing_state}', {js_bool});")
        except Exception as e:
            # Best-effort UI update; the page may be navigating.
            logger.debug(f"JavaScript error: {e}")

    async def update_overlay_chat_history(self, frame_or_page: Frame | Page):
        """
        Updates the chat history in the overlay. If the overlay is expanded and not currently being updated,
        it clears existing messages and adds them fresh from the conversation history.

        Args:
            frame_or_page (Frame | Page): The Playwright Frame or Page object to update the chat history in.
        """
        logger.debug("Updating overlay chat history")

        if self.overlay_is_collapsed:
            logger.debug("Overlay is collapsed, not updating chat history")
            return
        if self.__update_overlay_chat_history_running:
            logger.debug("update_overlay_chat_history is already running, returning" + frame_or_page.url)
            return

        self.__update_overlay_chat_history_running = True
        #update chat history in the overlay by removing all messages and adding them again fresh
        try:
            await frame_or_page.evaluate("clearOverlayMessages();")
            for message in self.conversation_history:
                safe_message = escape_js_message(message["message"])
                safe_message_type = escape_js_message(message.get("message_type", MessageType.STEP.value))
                if message["from"] == "user":
                    await frame_or_page.evaluate(f"addUserMessage({safe_message});")
                else:
                    #choose which message types to be shown depending on UI setting
                    if not self.overlay_show_details:
                        if message["message_type"] not in (MessageType.PLAN.value, MessageType.QUESTION.value, MessageType.ANSWER.value, MessageType.INFO.value):
                            continue
                    else:
                        # BUGFIX: compare against MessageType.INFO.value (a string) —
                        # the stored message_type values are enum .value strings, so the
                        # bare enum member previously never matched and INFO was dropped.
                        if message["message_type"] not in (MessageType.PLAN.value, MessageType.QUESTION.value, MessageType.ANSWER.value, MessageType.INFO.value, MessageType.STEP.value):
                            continue

                    js_code = f"addSystemMessage({safe_message}, is_awaiting_user_response=false, message_type={safe_message_type});"
                    await frame_or_page.evaluate(js_code)
            logger.debug("Chat history updated in overlay, removing update lock flag")
        except Exception:
            traceback.print_exc()
        finally:
            # Always release the re-entrancy lock, even if the JS evaluation failed.
            self.__update_overlay_chat_history_running = False

    def clear_conversation_history(self):
        """
        Clears the conversation history.
        """
        self.conversation_history = []
        self.add_default_system_messages()

    def get_conversation_history(self):
        """
        Returns the current conversation history.

        Returns:
            list[dict[str, str]]: The conversation history.
        """
        return self.conversation_history


    def new_user_message(self, message: str):
        """
        Adds a new user message to the conversation history.

        Args:
            message (str): The message text to add.
        """

        self.conversation_history.append({"from":"user", "message":message})


    def new_system_message(self, message: str, type:MessageType=MessageType.STEP):
        """
        Adds a new system message to the conversation history.

        Args:
            message (str): The message text to add.
            type (MessageType, optional): Category of the message; stored as its .value string.
        """

        self.conversation_history.append({"from":"system", "message":message, "message_type":type.value})
        # NOTE(review): debug print left as-is to preserve behavior; consider logger.debug.
        print(f"Adding system message: {message}")

    def add_default_system_messages(self):
        """
        Adds default system messages to the conversation history to greet the user or provide initial instructions.
        """
        pass

    async def command_completed(self, page: Page, command: str, elapsed_time: float|None = None):
        """
        Handles the completion of a command, focusing on the overlay input and indicating that the command has finished.

        Args:
            page (Page): The Playwright Page object where the command was executed.
            command (str): The command that was completed.
            elapsed_time (float | None, optional): The time taken to complete the command, if relevant.
        """
        if not self.overlay_is_collapsed:
            await page.evaluate("focusOnOverlayInput();")
            await page.evaluate("commandExecutionCompleted();")
|
0
Agent_E/ae/server/__init__.py
Normal file
0
Agent_E/ae/server/__init__.py
Normal file
191
Agent_E/ae/server/api_routes.py
Normal file
191
Agent_E/ae/server/api_routes.py
Normal file
|
@ -0,0 +1,191 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from queue import Empty
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
import Agent_E.ae.core.playwright_manager as browserManager
|
||||
from Agent_E.ae.config import SOURCE_LOG_FOLDER_PATH
|
||||
from Agent_E.ae.core.agents_llm_config import AgentsLLMConfig
|
||||
from Agent_E.ae.core.autogen_wrapper import AutogenWrapper
|
||||
from Agent_E.ae.utils.formatting_helper import is_terminating_message
|
||||
from Agent_E.ae.utils.ui_messagetype import MessageType
|
||||
|
||||
# Single shared Playwright browser manager for the whole server process
# (headful: headless=False).
browser_manager = browserManager.PlaywrightManager(headless=False)

APP_VERSION = "1.0.0"
APP_NAME = "Agent-E Web API"
API_PREFIX = "/api"
IS_DEBUG = False
# Bind address and port are overridable via environment variables.
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", 8080))
WORKERS = 1

# Identifier for this container/process; if empty, a UUID is generated at
# startup (see startup_event) and written back to the environment.
container_id = os.getenv("CONTAINER_ID", "")

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("uvicorn")
|
||||
|
||||
|
||||
class CommandQueryModel(BaseModel):
    """Request payload for /execute_task: a web-navigation command plus optional tuning knobs."""
    command: str = Field(..., description="The command related to web navigation to execute.") # Required field with description
    llm_config: dict[str,Any] | None = Field(None, description="The LLM configuration string to use for the agents.")
    planner_max_chat_round: int = Field(50, description="The maximum number of chat rounds for the planner.")
    browser_nav_max_chat_round: int = Field(10, description="The maximum number of chat rounds for the browser navigation agent.")
    clientid: str | None = Field(None, description="Client identifier, optional")
    request_originator: str | None = Field(None, description="Optional id of the request originator")
|
||||
|
||||
|
||||
def get_app() -> FastAPI:
    """Build and return the configured FastAPI application instance."""
    application = FastAPI(title=APP_NAME, version=APP_VERSION, debug=IS_DEBUG)

    # Permissive CORS: any origin/method/header, credentials allowed.
    application.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

    return application
|
||||
|
||||
|
||||
# Module-level ASGI application object; referenced by uvicorn ("main:app").
app = get_app()
|
||||
|
||||
|
||||
@app.on_event("startup") # type: ignore
async def startup_event():
    """
    Startup event handler to initialize browser manager asynchronously.
    """
    # NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
    # lifespan handlers — confirm the pinned FastAPI version before migrating.
    global container_id

    # Ensure this process has a stable container id; fall back to a fresh UUID
    # and export it so child processes / notifications can see it.
    if container_id.strip() == "":
        container_id = str(uuid.uuid4())
        os.environ["CONTAINER_ID"] = container_id
    await browser_manager.async_initialize()
|
||||
|
||||
|
||||
@app.post("/execute_task", description="Execute a given command related to web navigation and return the result.")
async def execute_task(request: Request, query_model: CommandQueryModel):
    """
    Execute a web-navigation command and stream progress notifications back as
    server-sent events (SSE).

    Args:
        request (Request): Incoming request; used later to detect client disconnect.
        query_model (CommandQueryModel): Command text plus optional LLM config and round limits.

    Returns:
        StreamingResponse: 'text/event-stream' of JSON notifications produced by run_task.
    """
    notification_queue = Queue() # type: ignore  # per-request queue of UI notifications
    # Reuse the caller-supplied client id as the transaction id when present.
    transaction_id = str(uuid.uuid4()) if query_model.clientid is None else query_model.clientid
    # NOTE(review): the listener registered here is never unregistered, so
    # listeners (and their queues) accumulate across requests — verify whether
    # NotificationManager exposes a removal hook and call it when the stream ends.
    register_notification_listener(notification_queue)
    return StreamingResponse(run_task(request, transaction_id, query_model.command, browser_manager, notification_queue, query_model.request_originator,query_model.llm_config,
                                      planner_max_chat_round=query_model.planner_max_chat_round,
                                      browser_nav_max_chat_round=query_model.browser_nav_max_chat_round), media_type="text/event-stream")
|
||||
|
||||
|
||||
def run_task(request: Request, transaction_id: str, command: str, playwright_manager: browserManager.PlaywrightManager, notification_queue: Queue, request_originator: str|None = None, llm_config: dict[str,Any]|None = None, # type: ignore
            planner_max_chat_round: int = 50, browser_nav_max_chat_round: int = 10):
    """
    Run the task to process the command and generate events.

    Args:
        request (Request): The request object to detect client disconnect.
        transaction_id (str): The transaction ID to identify the request.
        command (str): The command to execute.
        playwright_manager (PlaywrightManager): The manager handling browser interactions and notifications.
        notification_queue (Queue): The queue to hold notifications for this request.
        request_originator (str|None): The originator of the request.
        llm_config (dict[str,Any]|None): The LLM configuration to use for the agents.
        planner_max_chat_round (int, optional): The maximum number of chat rounds for the planner. Defaults to 50.
        browser_nav_max_chat_round (int, optional): The maximum number of chat rounds for the browser navigation agent. Defaults to 10.

    Yields:
        str: JSON-encoded string representing a notification.
    """

    async def event_generator():
        # Run the actual command processing concurrently so this generator can
        # drain the notification queue while the agents work.
        task = asyncio.create_task(process_command(command, playwright_manager, planner_max_chat_round, browser_nav_max_chat_round, llm_config))
        task_detail = f"transaction_id={transaction_id}, request_originator={request_originator}, command={command}"

        try:
            # Keep streaming until the worker task is done AND the queue is drained.
            while not task.done() or not notification_queue.empty():
                # Stop working as soon as the SSE client goes away.
                if await request.is_disconnected():
                    logger.info(f"Client disconnected. Cancelling the task: {task_detail}")
                    task.cancel()
                    break
                try:
                    # Non-blocking read so we can keep polling the disconnect flag.
                    notification = notification_queue.get_nowait() # type: ignore
                    notification["transaction_id"] = transaction_id # Include the transaction ID in the notification
                    notification["request_originator"] = request_originator # Include the request originator in the notification
                    yield f"data: {json.dumps(notification)}\n\n" # Using 'data: ' to follow the SSE format
                except Empty:
                    # Nothing queued yet; yield control briefly before polling again.
                    await asyncio.sleep(0.1)
                except asyncio.CancelledError:
                    logger.info(f"Task was cancelled due to client disconnection. {task_detail}")
                except Exception as e:
                    logger.error(f"An error occurred while processing task: {task_detail}. Error: {e}")

            # Propagate any exception raised by the worker task.
            await task
        except asyncio.CancelledError:
            logger.info(f"Task was cancelled due to client disconnection. {task_detail}")
            await task

    return event_generator()
|
||||
|
||||
|
||||
|
||||
async def process_command(command: str, playwright_manager: browserManager.PlaywrightManager, planner_max_chat_round: int, browser_nav_max_chat_round: int, llm_config:dict[str,Any]|None = None):
    """
    Process the command and send notifications.

    Args:
        command (str): The command to process.
        playwright_manager (PlaywrightManager): The manager handling browser interactions and notifications.
        planner_max_chat_round (int): Maximum number of chat rounds for the planner agent.
        browser_nav_max_chat_round (int): Maximum number of chat rounds for the browser navigation agent.
        llm_config (dict[str,Any]|None): Optional LLM configuration; falls back to defaults when None.
    """
    await playwright_manager.go_to_homepage() # Go to the homepage before processing the command
    current_url = await playwright_manager.get_current_url()
    await playwright_manager.notify_user("Processing command", MessageType.INFO)

    # Load the configuration using AgentsLLMConfig
    normalized_llm_config = None
    if llm_config is None:
        # No override supplied: use the default (environment/file-based) config.
        normalized_llm_config = AgentsLLMConfig()
    else:
        normalized_llm_config = AgentsLLMConfig(llm_config=llm_config)
        logger.info("Applied LLM config received via API.")

    # Retrieve planner agent and browser nav agent configurations
    planner_agent_config = normalized_llm_config.get_planner_agent_config()
    browser_nav_agent_config = normalized_llm_config.get_browser_nav_agent_config()

    ag = await AutogenWrapper.create(planner_agent_config, browser_nav_agent_config, planner_max_chat_round=planner_max_chat_round,
                                     browser_nav_max_chat_round=browser_nav_max_chat_round)
    command_exec_result = await ag.process_command(command, current_url) # type: ignore
    messages=ag.agents_map["planner_agent"].chat_messages
    # JSON object keys must be strings; chat_messages may be keyed by agent objects.
    messages_str_keys = {str(key): value for key, value in messages.items()} # type: ignore

    # Persist the full planner transcript for debugging/auditing.
    with open(os.path.join(SOURCE_LOG_FOLDER_PATH, 'chat_messages.json'), 'w', encoding='utf-8') as f:
        json.dump(messages_str_keys, f, ensure_ascii=False, indent=4)
    logger.debug("Chat messages saved")

    # A terminating summary means the agents finished; otherwise the round
    # limit was exhausted.
    if is_terminating_message(command_exec_result.summary):
        await playwright_manager.notify_user("DONE", MessageType.DONE)
    else:
        await playwright_manager.notify_user("Max turns reached", MessageType.MAX_TURNS_REACHED)
|
||||
|
||||
|
||||
def register_notification_listener(notification_queue: Queue): # type: ignore
    """
    Attach a listener to the browser manager's NotificationManager that tags
    each notification with this process's container id and enqueues it.
    """

    def enqueue(notification: dict[str, str]) -> None:
        # Stamp the notification with the container ID (or UUID) before queuing.
        notification["container_id"] = container_id
        notification_queue.put(notification) # type: ignore

    browser_manager.notification_manager.register_listener(enqueue)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    logger.info("**********Application Started**********")
    # NOTE(review): the import string "main:app" assumes this module is named
    # `main`; if this file lives at ae/server/api_routes.py, confirm the module
    # path or uvicorn will fail to locate the app when reloading/spawning workers.
    uvicorn.run("main:app", host=HOST, port=PORT, workers=WORKERS, reload=IS_DEBUG, log_level="info")
|
801
Agent_E/ae/ui/injectOverlay.js
Normal file
801
Agent_E/ae/ui/injectOverlay.js
Normal file
|
@ -0,0 +1,801 @@
|
|||
// Module-level flag; set while the agent is waiting for the user's reply.
let awaitingUserResponse = false; // flag to check if the agent is awaiting user response
|
||||
|
||||
// disabled and enabled styles as injected style element
|
||||
/**
 * Injects the overlay's CSS into the host page by appending a <style> element
 * to document.head. All selectors are prefixed with "agente-" to avoid
 * colliding with the host page's own styles.
 */
function injectOveralyStyles() {
    // Create a new style element
    let style = document.createElement('style');
    // Set the styles
    style.textContent = `
    @import url(https://fonts.googleapis.com/earlyaccess/notosanssc.css);

    ::-webkit-scrollbar {
        width: 6px;
        border: solid 3px transparent;
    }

    ::-webkit-scrollbar-track {
        background-color: transparent;
    }

    ::-webkit-scrollbar-thumb {
        background-color: rgba(255, 255, 255, 0.4);
        border-radius: 4px;
    }

    ::-webkit-scrollbar-thumb:hover {
        background-color: rgba(255, 255, 255, 0.6);
    }

    .agente-pre-line {
        white-space: pre-line; !important;
    }

    #agente-closebutton{
        width:30px;
        height:30px;
        min-width:30px;
        min-height:30px;
        margin-left: auto;
        color:darkgray;
        cursor: pointer;
        background: transparent;
        transition: transform 0.2s ease;
        border: None;
    }
    #agente-closebutton:hover{
        transform: scale(1.1);
    }

    #agente-closebutton:active{
        transform: scale(0.8);
    }

    @keyframes agente-gradient-animation {
        0% {background-position: 100% 0%}
        100% {background-position: 15% 100%}
    }

    @keyframes agente-rotate {
        100% {
            transform: rotate(1turn);
        }
    }

    @keyframes automation_highlight_fadeout_animation {
        0% { border-color: rgba(128, 0, 128, 1); }
        50% { border-color: rgba(128, 0, 128, 1); }
        100% { border-color: rgba(128, 0, 128, 0); }
    }

    .agente-ui-automation-highlight {
        border-width: 2px !important;
        border-style: solid !important;
        animation: automation_highlight_fadeout_animation 5s linear 1 forwards !important;
    }

    .agente-processing{
        background: linear-gradient(90deg,
            rgba(255, 0, 255, 1) 0%, /* Bright Magenta */
            rgba(0, 191, 255, 1) 100% /* Deep Sky Blue */
        );
        background-size: 100% 200%;
        animation: agente-rotate 1s linear infinite;
    }

    .agente-init{
        background: darkgray;
        box-shadow: rgba(120, 120, 120, 0.3) 0px 0px 20px
    }

    .agente-done{
        background: lightgreen;
    }

    .agente-processingLine {
        background: linear-gradient(45deg,
            rgba(255, 0, 0, 1) 0%, /* Red */
            rgba(255, 127, 0, 1) 25%, /* Orange */
            rgba(0, 255, 0, 1) 50%, /* Green */
            rgba(0, 0, 255, 1) 75%, /* Blue */
            rgba(255, 0, 0, 1) 90%, /* Red */
            rgba(255, 0, 0, 1) 100% /* Red */
        );
        background-size: 500% 100%;
        animation: agente-gradient-animation 3s linear infinite;
    }

    .agente-initStateLine{
        background: lightgray;
    }

    .agente-doneStateLine{
        background: lightgreen;
    }

    .agente-collapsed{
        cursor: pointer;
        background-color: rgba(0, 0, 0, 0.1);
        background-repeat: no-repeat;
        background-position: center;
        background-size: cover;
        width: 6vh;
        height: 6vh;
        border-radius: 50%;
        right: 1.5vw;
        bottom: 1.5vw;
        box-shadow: rgba(0, 0, 0, 0.3) 0px 0px 20px
    }

    .agente-chat-container {
        margin:1%,1%,1%,1%;
        width: 30vw;
        min-width: 350px;
        height:70vh;
        bottom: 2vh;
        position: relative;
        display: flex;
        flex-direction: column;
        top: 6%;
        padding: 1%;
        box-sizing: border-box;
    }

    .agente-chat-input{
        display: flex;
        flex-direction: row;
        align-items: center;
        width: 95%;
        margin-top:1.5vh;
    }

    .agente-agent{
        justify-content: flex-start;
    }

    .agente-user{
        justify-content: flex-end;
    }

    #agente-user-input {
        flex: 1;
        padding: 3px 3px;
        border: transparent;
        width:100%;
        resize: none;
        font-family: 'Noto Sans SC';
        font-size: 1.6vh;
        min-font-size: 12px;
        line-height: 1.5;
        display: flex;
        vertical-align: middle;
        text-align: middle;
        align-items: center;
        justify-content: center;
        border-color: #ccc;
        background: white;
        color:black;
        min-height: calc(1.2em * 2);
        scrollbar-width: thin;
    }

    #agente-user-input:focus {
        outline: none !important;
        border:0px solid transparent !important;
        box-shadow: none !important;
    }

    #agente-send-btn {
        cursor: pointer;
        transition: transform 0.2s ease;
    }

    #agente-send-btn:hover{
        transform: scale(1.1);
    }

    .agente-highlight_overlay{
        box-shadow: 1px 1px 1px 1px rgb(50 50 50 / 40%);
        border-radius: 16px;
        border: 1px solid #E1DEE2;
        bottom:3px;
        right:5px;
        background: #FBFAFA;
    }

    #agente-chat-box {
        overflow-y: auto;
        scrollbar-width: thin;
        height: 90%;
        display: flex;
        flex-direction: column;
        gap:1%;
        margin:1% 5%;
        padding-bottom:1%;
        margin-top:10%;
    }

    #agente-overlay {
        position: fixed;
        min-width: 50px;
        min-height: 50px;
        margin-left: auto;
        margin-right: auto;
        z-index:20000000;
        scrollbar-color: gray lightgray;
        margin-bottom: 1%;
        display: flex;
        flex-direction: column;
    }

    .agente-input-container {
        display: flex;
        flex-direction: column;
        margin: 1% 3%;
        padding: 1%;
        height:20%;
        background: white;
        border: 1px solid #E1DEE2;
        border-radius: 8px;
    }

    .agente-chat{
        width: 80%;
        color: black;
        overflow-wrap: break-word;
        font-family: 'Noto Sans SC';

    }

    .agente-systemMessage{
        text-align: left;
        justify-content: flex-start;
        font-family: 'Noto Sans SC';
        padding: 2% 4%;
        font-size: 1.5vh;
        min-font-size: 12px;
        min-height: 30px;
        background: #EEEEEF;
        line-height: 1.7;
        border-radius: 10px;
        width:auto;
        max-width: 90%;
    }

    .agente-usertext{
        text-align: right;
        justify-content: flex-end;
        align-items: flex-end;
        font-family: 'Noto Sans SC';
        font-size: 1.5vh;
        min-font-size: 12px;
        padding: 2% 4%;
        line-height: 1.7;
        min-height: 30px;
        width:auto;
        background: #ECEBF3;
        border-radius: 10px;
        color: black;
    }

    .agente-agentstep{
        color: #4B4B4B;
    }
    .agente-agentplan{
        color: #4B4B4B;
    }
    .agente-agentanswer{
        color: black;
    }


    .agente-toggle {
        -webkit-appearance: none;
        -moz-appearance: none;
        appearance: none;
        margin: 0;
        display: inline-block;
        position: relative;
        border-radius: 50px;
        overflow: hidden;
        outline: none;
        border: none;
        cursor: pointer;
        background-color: #E1DEE2;
        transition: background-color ease 0.3s;
        align-self: center;
    }
    .agente-toggle:focus {
        border: none; !important;
        outline: none; !important;
    }
    .agente-toggle:before {
        content: "";
        display: block;
        position: absolute;
        z-index: 2;
        width: 20px;
        height: 20px;
        background: #fff;
        left: 2px;
        top: 2px;
        border-radius: 50%;
        color: #fff;
        text-shadow: -1px -1px rgba(0,0,0,0.15);
        white-space: nowrap;
        box-shadow: 0 1px 2px rgba(0,0,0,0.2);
        transition: all cubic-bezier(0.3, 1.5, 0.7, 1) 0.3s;
    }

    .agente-toggle:checked {
        background-color: #786E96;
    }

    .agente-toggle:checked:before {
        left: 20px;
    }
    `;
    // Append the style element to the head of the document
    document.head.appendChild(style);
}
|
||||
// Selection saved before the overlay steals focus (restored later); null when none.
let savedSelection = null;
// Whether the "Show Details" toggle is on (stream step-by-step agent output).
let show_details = true;
|
||||
|
||||
|
||||
/**
 * Renders the collapsed (small round button) form of the Agent-E overlay.
 * Replaces any existing overlay, notifies the host via
 * window.overlay_state_changed(true), and wires hover/click handlers so a
 * click re-expands the overlay preserving the current processing state.
 *
 * @param {string} processing_state - "init" | "processing" | "done"; drives the border style.
 * @param {boolean} steps - Current "show details" toggle value, carried over to the expanded view.
 */
function showCollapsedOverlay(processing_state = "processing", steps) {
    show_details = steps;
    removeOverlay();
    window.overlay_state_changed(true);
    // Root collapsed container: a small fixed circle (styled by .agente-collapsed).
    let collapsed_agente = document.createElement("div");
    collapsed_agente.id = "agente-overlay";
    collapsed_agente.classList.add("agente-collapsed");
    collapsed_agente.style.backgroundColor = "transparent";
    collapsed_agente.setAttribute("aria-hidden", "true");
    collapsed_agente.style.justifyContent = "center";
    // Wrapper holds the animated border ring and the logo, stacked absolutely.
    let wrapper = document.createElement("div");
    wrapper.style.position = "relative";
    wrapper.style.width = "100%";
    wrapper.style.height = "100%";
    wrapper.style.justifyContent = "center";
    let logodiv= document.createElement("div");
    logodiv.style.width = "90%";
    logodiv.style.height = "90%";
    logodiv.style.left = "5%";
    logodiv.style.top = "5%";
    let borderdiv = document.createElement("div");
    borderdiv.style.width = "100%";
    borderdiv.style.height = "100%";
    borderdiv.style.borderRadius = "50%";

    // Inline robot-head SVG rendered as a data-URI background image.
    let logo = `<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><rect x="6.5" y="7.5" width="11" height="11" rx="0.5" stroke="#827C8C"/><rect x="-0.5" y="0.5" width="3" height="5" rx="0.5" transform="matrix(-1 0 0 1 6 10)" stroke="#827C8C"/><rect x="-0.5" y="0.5" width="3" height="5" rx="0.5" transform="matrix(-1 0 0 1 20 10)" stroke="#827C8C"/><path d="M12 4V7.5" stroke="#827C8C" stroke-linecap="round"/><rect x="8.5" y="11.5" width="7" height="3" rx="1.5" stroke="#827C8C"/></svg>`;
    let encodedSvg = encodeURIComponent(logo);
    let svgUrl = 'data:image/svg+xml;utf8,' + encodedSvg;
    logodiv.style.backgroundImage = `url("${svgUrl}")`;
    logodiv.style.backgroundRepeat = "no-repeat";
    logodiv.style.backgroundSize = "contain";
    logodiv.style.borderRadius = "50%";
    logodiv.style.backgroundPosition = "center";
    logodiv.style.backgroundColor = "white";
    logodiv.style.alignSelf = "center";
    borderdiv.style.position = "absolute";
    borderdiv.style.top = "0";
    borderdiv.style.left = "0";
    // The border ring carries the state class (agente-init/-processing/-done).
    borderdiv.id="AgentEOverlayBorder";
    logodiv.style.position = "absolute";
    logodiv.style.justifySelf = "center";
    wrapper.appendChild(borderdiv);
    wrapper.appendChild(logodiv);
    collapsed_agente.appendChild(wrapper);
    document.body.appendChild(collapsed_agente);

    // Apply the requested processing state to the collapsed ring.
    updateOverlayState(processing_state, true);

    let element = document.getElementById('agente-overlay');
    // Slight grow/shrink feedback on hover.
    document.getElementById('agente-overlay').addEventListener('mouseover', function () {
        this.style.transform = 'scale(1.1)';
    });

    document.getElementById('agente-overlay').addEventListener('mouseout', function () {
        this.style.transform = 'scale(1)';
    });
    // Clicking re-expands the overlay; current state is read back off the
    // border ring's classes so it survives the collapse/expand round trip.
    document.getElementById('agente-overlay').addEventListener('click', function () {
        let ui_state = document.getElementById("AgentEOverlayBorder").classList.contains("agente-init") ? "init" : document.getElementById("AgentEOverlayBorder").classList.contains("agente-processing") ? "processing" : "done";
        showExpandedOverlay(ui_state, show_details);
    });
}
|
||||
|
||||
/**
 * Removes the Agent-E overlay element from the page, if present.
 */
function removeOverlay() {
    const overlay = document.getElementById("agente-overlay");
    if (overlay !== null) {
        overlay.parentNode.removeChild(overlay);
    }
}
|
||||
|
||||
/**
 * Empties the overlay chat box of all messages. Errors are swallowed (with a
 * console log) so later messages are not affected.
 *
 * @param {boolean} keep_default - Accepted for API compatibility; currently unused.
 */
function clearOverlayMessages(keep_default = false) {
    try {
        const chatBox = document.getElementById('agente-chat-box');
        if (!chatBox) {
            return;
        }
        // Remove children one at a time until the box is empty.
        while (chatBox.firstChild) {
            chatBox.removeChild(chatBox.firstChild);
        }
    } catch (error) {
        //No action can be taken at this point. Just ensure subsequent messages are not affected
        console.error("Error clearing chat box", error);
    }
}
|
||||
|
||||
/**
 * Applies the given processing state to whichever overlay form is showing.
 * Collapsed: swaps the state class on the border ring. Expanded: swaps the
 * state class on the animation line and enables/disables the input accordingly.
 *
 * @param {string} processing_state - "init" | "processing" | "done" (anything else is a no-op).
 * @param {boolean} is_collapsed - true for the collapsed circle, false for the expanded panel.
 */
function updateOverlayState(processing_state, is_collapsed)
{
    if (is_collapsed) {
        const ring = document.getElementById("AgentEOverlayBorder");
        if (processing_state === "init") {
            ring.classList.add("agente-init");
            ring.classList.remove("agente-processing", "agente-done");
        }
        else if (processing_state === "processing") {
            ring.classList.add("agente-processing");
            ring.classList.remove("agente-init", "agente-done");
        }
        else if (processing_state === "done") {
            ring.classList.add("agente-done");
            ring.classList.remove("agente-init", "agente-processing");
        }
    } else {
        const line = document.getElementById("AgentEExpandedAnimation");
        if (processing_state === "init") {
            line.classList.add("agente-initStateLine");
            line.classList.remove("agente-processingLine", "agente-doneStateLine");
            enableOverlay();
        }
        else if (processing_state === "processing") {
            line.classList.add("agente-processingLine");
            line.classList.remove("agente-initStateLine", "agente-doneStateLine");
            disableOverlay();
        }
        else if (processing_state === "done") {
            line.classList.add("agente-doneStateLine");
            line.classList.remove("agente-processingLine", "agente-initStateLine");
            enableOverlay();
        }
    }
}
|
||||
|
||||
function showExpandedOverlay(processing_state = "init", show_steps=true) {
|
||||
ui_state = processing_state;
|
||||
show_details = show_steps;
|
||||
let agente_logo = `<svg width="85" height="12" viewBox="0 0 85 12" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M0 11.8027L3.43562 0.213699H8.35069L11.8027 11.8027H9.3863L8.23562 7.85753H3.53425L2.38356 11.8027H0ZM4.10959 5.86849H7.66027L6.18082 0.80548H5.58904L4.10959 5.86849Z" fill="#6B6673"/><path d="M19.0946 12C15.6096 12 13.7028 9.56712 13.7028 6.09863C13.7028 2.4 15.9055 0 19.4562 0C22.4151 0 24.5685 1.70959 24.9631 4.35616H22.6124C22.3822 2.87671 21.2151 1.9726 19.5713 1.9726C17.3192 1.9726 16.0535 3.58356 16.0535 6.09863C16.0535 8.35068 17.0726 10.011 19.637 10.011C21.7576 10.011 22.974 8.94247 22.974 7.15068H19.374V5.40822H23.9768C24.8151 5.40822 25.2918 5.85205 25.2918 6.69041V11.8027H23.0069V10.7671L23.4672 8.92603H22.8589C22.8754 9.6 22.4973 12 19.0946 12Z" fill="#6B6673"/><path d="M28.7192 11.8027V0.213699H37.3987V2.20274H31.0206V5.04658H36.5768V6.95342H31.0206V9.8137H37.3987V11.8027H28.7192Z" fill="#6B6673"/><path d="M40.533 11.8027V0.213699H45.0536L49.1631 11.211H49.7385L49.3275 9.76438V0.213699H51.6125V11.8027H47.0919L42.9823 0.80548H42.3905L42.8179 2.25205V11.8027H40.533Z" fill="#6B6673"/><path d="M54.4378 0.213699H64.4159V2.20274H60.5693V11.8027H58.2844V2.20274H54.4378V0.213699Z" fill="#6B6673"/><path d="M63.9401 6.6411H72.4551V8.30137H63.9401V6.6411Z" fill="#6B6673"/><path d="M75.3643 11.8027V0.213699H84.0438V2.20274H77.6657V5.04658H83.2219V6.95342H77.6657V9.8137H84.0438V11.8027H75.3643Z" fill="#6B6673"/></svg>`;
|
||||
let close_icon = `<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M5 10L10 10L10 5" stroke="#827C8C"/><path d="M19 14L14 14L14 19" stroke="#827C8C"/><path d="M14 5L14 10L19 10" stroke="#827C8C"/><path d="M10 19L10 14L5 14" stroke="#827C8C"/><path d="M6.35355 5.64645C6.15829 5.45118 5.84171 5.45118 5.64645 5.64645C5.45118 5.84171 5.45118 6.15829 5.64645 6.35355L6.35355 5.64645ZM10.3536 9.64645L6.35355 5.64645L5.64645 6.35355L9.64645 10.3536L10.3536 9.64645Z" fill="#827C8C"/><path d="M17.6464 18.3536C17.8417 18.5488 18.1583 18.5488 18.3536 18.3536C18.5488 18.1583 18.5488 17.8417 18.3536 17.6464L17.6464 18.3536ZM13.6464 14.3536L17.6464 18.3536L18.3536 17.6464L14.3536 13.6464L13.6464 14.3536Z" fill="#827C8C"/><path d="M18.3536 6.35355C18.5488 6.15829 18.5488 5.84171 18.3536 5.64645C18.1583 5.45119 17.8417 5.45119 17.6464 5.64645L18.3536 6.35355ZM14.3536 10.3536L18.3536 6.35355L17.6464 5.64645L13.6464 9.64645L14.3536 10.3536Z" fill="#827C8C"/><path d="M5.64645 17.6464C5.45118 17.8417 5.45118 18.1583 5.64645 18.3536C5.84171 18.5488 6.15829 18.5488 6.35355 18.3536L5.64645 17.6464ZM9.64645 13.6464L5.64645 17.6464L6.35355 18.3536L10.3536 14.3536L9.64645 13.6464Z" fill="#827C8C"/></svg>`;
|
||||
let icon = `<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><rect x="6.5" y="7.5" width="11" height="11" rx="0.5" stroke="#827C8C"/><rect x="-0.5" y="0.5" width="3" height="5" rx="0.5" transform="matrix(-1 0 0 1 6 10)" stroke="#827C8C"/><rect x="-0.5" y="0.5" width="3" height="5" rx="0.5" transform="matrix(-1 0 0 1 20 10)" stroke="#827C8C"/><path d="M12 4V7.5" stroke="#827C8C" stroke-linecap="round"/><rect x="8.5" y="11.5" width="7" height="3" rx="1.5" stroke="#827C8C"/></svg>`;
|
||||
removeOverlay();
|
||||
window.overlay_state_changed(false);
|
||||
let newDiv = document.createElement("div");
|
||||
newDiv.id = "agente-overlay";
|
||||
newDiv.classList.add("agente-highlight_overlay");
|
||||
newDiv.setAttribute("aria-hidden", "true");
|
||||
newDiv.setAttribute("tabindex", "0");
|
||||
|
||||
let header = document.createElement("div");
|
||||
header.style.display = "flex";
|
||||
header.style.flexDirection = "row";
|
||||
header.style.margin = "4%";
|
||||
|
||||
let logoIcon= document.createElement("div");
|
||||
logoIcon.style.width = "25px";
|
||||
logoIcon.style.height = "25px";
|
||||
logoIcon.style.backgroundImage = `url('data:image/svg+xml;utf8,${encodeURIComponent(icon)}')`;
|
||||
logoIcon.style.backgroundRepeat = "no-repeat";
|
||||
logoIcon.style.backgroundSize = "contain";
|
||||
logoIcon.style.backgroundPosition = "bottom";
|
||||
logoIcon.style.order = 1;
|
||||
logoIcon.style.alignSelf = "flex-end";
|
||||
logoIcon.style.marginRight = "1%";
|
||||
|
||||
let logoDiv = document.createElement("div");
|
||||
logoDiv.style.width = "100px";
|
||||
logoDiv.style.height = "25px";
|
||||
logoDiv.style.backgroundImage = `url('data:image/svg+xml;utf8,${encodeURIComponent(agente_logo)}')`;
|
||||
logoDiv.style.backgroundRepeat = "no-repeat";
|
||||
logoDiv.style.backgroundSize = "contain";
|
||||
logoDiv.style.backgroundPosition = "bottom";
|
||||
// Style the logoDiv and button
|
||||
logoDiv.style.order = 1;
|
||||
|
||||
|
||||
let closeButton = document.createElement("button");
|
||||
closeButton.id = "agente-closebutton";
|
||||
closeButton.style.backgroundImage = `url('data:image/svg+xml;utf8,${encodeURIComponent(close_icon)}')`;
|
||||
closeButton.style.backgroundRepeat = "no-repeat";
|
||||
closeButton.style.backgroundSize = "contain";
|
||||
closeButton.style.backgroundPosition = "bottom";
|
||||
closeButton.onclick = function () {
|
||||
let ui_state = document.getElementById("AgentEExpandedAnimation").classList.contains("agente-initStateLine") ? "init" : document.getElementById("AgentEExpandedAnimation").classList.contains("agente-processingLine") ? "processing" : "done";
|
||||
showCollapsedOverlay(ui_state, show_details);
|
||||
};
|
||||
closeButton.style.order = 3;
|
||||
header.appendChild(logoIcon);
|
||||
header.appendChild(logoDiv);
|
||||
let animation = document.createElement("div");
|
||||
animation.id = "AgentEExpandedAnimation";
|
||||
animation.style.height = "2px";
|
||||
animation.style.width = "100%";
|
||||
|
||||
header.appendChild(closeButton);
|
||||
// Append the close button to the newDiv
|
||||
newDiv.appendChild(header);
|
||||
|
||||
|
||||
newDiv.appendChild(animation);
|
||||
let chatContainer = document.createElement("div");
|
||||
chatContainer.className = "agente-chat-container";
|
||||
|
||||
let chatBox = document.createElement("div");
|
||||
chatBox.id = "agente-chat-box";
|
||||
|
||||
let chatInput = document.createElement("div");
|
||||
chatInput.className = "agente-chat-input";
|
||||
chatBox.appendChild(chatInput);
|
||||
|
||||
let inputContainer = document.createElement("div");
|
||||
inputContainer.className = "agente-input-container";
|
||||
inputContainer.id = "agente-input-container";
|
||||
let userInput = document.createElement("textarea");
|
||||
userInput.id = "agente-user-input";
|
||||
userInput.placeholder = "What can I help you solve today?";
|
||||
userInput.addEventListener('input', function(event) {
|
||||
let text = event.target.value;
|
||||
if (text.trim() == "") {
|
||||
let button_disabled_svg =`<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg"><rect width="40" height="40" rx="4" fill="#EEEEEF"/><path d="M15 20H25" stroke="#AEA9B4" stroke-width="1.5"/><path d="M20 15L25 20L20 25" stroke="#AEA9B4" stroke-width="1.5"/></svg>`;
|
||||
let sendBtn = document.getElementById('agente-send-btn');
|
||||
sendBtn.style.backgroundImage = `url('data:image/svg+xml;utf8,${encodeURIComponent(button_disabled_svg)}')`;
|
||||
}
|
||||
else{
|
||||
let button_enabled_svg= `<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg"><rect width="40" height="40" rx="4" fill="#252539"/><path d="M15 20H25" stroke="white" stroke-width="1.5"/><path d="M20 15L25 20L20 25" stroke="white" stroke-width="1.5"/></svg>`;
|
||||
let sendBtn = document.getElementById('agente-send-btn');
|
||||
sendBtn.style.backgroundImage = `url('data:image/svg+xml;utf8,${encodeURIComponent(button_enabled_svg)}')`;
|
||||
}
|
||||
});
|
||||
let userinput_footer = document.createElement("div");
|
||||
userinput_footer.style.display = "flex";
|
||||
userinput_footer.style.flexDirection = "row";
|
||||
userinput_footer.style.justifyContent = "space-between";
|
||||
userinput_footer.style.alignItems = "center";
|
||||
userinput_footer.style.height = "40%";
|
||||
userinput_footer.style.margin = "2% 1%";
|
||||
userinput_footer.id="userinput_section"
|
||||
|
||||
let toggleLabel = document.createElement("label"); // Create a new label element
|
||||
toggleLabel.textContent = "Show Details"; // Set the text content of the label
|
||||
toggleLabel.style.color = "#6B6673"; // Set the color of the label
|
||||
toggleLabel.style.fontFamily = "Noto Sans SC"; // Set the font of the label
|
||||
toggleLabel.style.fontSize = "14px"; // Set the font size of the label
|
||||
toggleLabel.style.fontWeight = "400"; // Set the font weight of the label
|
||||
toggleLabel.style.margin = "0px"; // Add some margin to the right of the label
|
||||
toggleLabel.style.marginRight = "10px"; // Add some margin to the right of the label
|
||||
|
||||
let toggleSwitch = document.createElement("input");
|
||||
|
||||
toggleSwitch.type = "checkbox";
|
||||
toggleSwitch.className = "agente-toggle";
|
||||
toggleSwitch.style.width = "44px";
|
||||
toggleSwitch.style.height = "24px";
|
||||
toggleSwitch.style.margin = "0px";
|
||||
|
||||
if (show_details){
|
||||
toggleSwitch.checked = true;
|
||||
}
|
||||
else{
|
||||
toggleSwitch.checked = false;
|
||||
}
|
||||
|
||||
toggleSwitch.addEventListener('change', function() {
|
||||
if(this.checked) {
|
||||
show_details = true;
|
||||
window.show_steps_state_changed(true)
|
||||
} else {
|
||||
show_details = false;
|
||||
window.show_steps_state_changed(false)
|
||||
}
|
||||
});
|
||||
|
||||
let sendicon =`<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg"><rect width="40" height="40" rx="4" fill="#EEEEEF"/><path d="M15 20H25" stroke="#AEA9B4" stroke-width="1.5"/><path d="M20 15L25 20L20 25" stroke="#AEA9B4" stroke-width="1.5"/></svg>`;
|
||||
let sendBtn = document.createElement("div");
|
||||
sendBtn.id = "agente-send-btn";
|
||||
sendBtn.style.backgroundImage = `url('data:image/svg+xml;utf8,${encodeURIComponent(sendicon)}')`;
|
||||
sendBtn.style.backgroundRepeat = "no-repeat";
|
||||
sendBtn.style.backgroundSize = "contain";
|
||||
sendBtn.style.backgroundPosition = "right";
|
||||
sendBtn.style.width = "8%";
|
||||
sendBtn.style.height = "100%";
|
||||
sendBtn.style.marginLeft = "auto";
|
||||
|
||||
userinput_footer.appendChild(toggleLabel); // Add the label to the div
|
||||
userinput_footer.appendChild(toggleSwitch);
|
||||
userinput_footer.appendChild(sendBtn);
|
||||
|
||||
inputContainer.appendChild(userInput);
|
||||
inputContainer.appendChild(userinput_footer);
|
||||
|
||||
chatContainer.appendChild(chatBox);
|
||||
chatContainer.appendChild(inputContainer);
|
||||
|
||||
newDiv.appendChild(chatContainer);
|
||||
|
||||
let disclaimer = document.createElement("p");
|
||||
disclaimer.style.fontFamily = "Noto Sans SC";
|
||||
disclaimer.style.fontSize = "12px";
|
||||
disclaimer.style.color = "#6B6673";
|
||||
disclaimer.style.alignSelf = "center";
|
||||
disclaimer.style.position = "absolute";
|
||||
disclaimer.style.bottom = "0%";
|
||||
disclaimer.style.margin = "0% 0% 1% 0%";
|
||||
disclaimer.textContent = "Agent-E may make mistakes. Verify key info.";
|
||||
|
||||
newDiv.appendChild(disclaimer);
|
||||
|
||||
document.body.appendChild(newDiv);
|
||||
updateOverlayState(processing_state, false);
|
||||
document.getElementById('agente-send-btn').addEventListener('click', function () {
|
||||
let task = document.getElementById('agente-user-input').value
|
||||
let task_trimmed = task.trim();
|
||||
if (task_trimmed && !isDisabled() && task_trimmed.length > 0) {
|
||||
if (awaitingUserResponse) {
|
||||
addUserMessage(task);
|
||||
document.getElementById('agente-user-input').value = "";
|
||||
} else {
|
||||
clearOverlayMessages();
|
||||
addUserMessage(task);
|
||||
disableOverlay();
|
||||
window.process_task(task)
|
||||
document.getElementById('agente-user-input').value = "";
|
||||
}
|
||||
}
|
||||
else {
|
||||
console.log("Empty message no task to send");
|
||||
}
|
||||
});
|
||||
|
||||
userInput.addEventListener('focus', function() {
|
||||
if (window.getSelection().rangeCount > 0) {
|
||||
let selectedText = window.getSelection().toString();
|
||||
if (selectedText) {
|
||||
document.getElementById('agente-user-input').value = selectedText + '\n';
|
||||
setTimeout(function() {
|
||||
userInput.selectionStart = userInput.selectionEnd = userInput.value.length;
|
||||
userInput.scrollTop = userInput.scrollHeight;
|
||||
}, 0);
|
||||
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
userInput.addEventListener('blur', function() {
|
||||
if (savedSelection) {
|
||||
let selection = window.getSelection();
|
||||
selection.removeAllRanges();
|
||||
selection.addRange(savedSelection);
|
||||
}
|
||||
});
|
||||
|
||||
document.getElementById('agente-user-input').addEventListener('keydown', function (event) {
|
||||
// Check if the pressed key is the Enter key
|
||||
if (event.key === "Enter") {
|
||||
event.preventDefault();
|
||||
|
||||
let targetElement = document.getElementById('agente-send-btn');
|
||||
|
||||
// Create a new click event
|
||||
let clickEvent = new MouseEvent('click', {
|
||||
bubbles: true,
|
||||
cancelable: true
|
||||
});
|
||||
|
||||
// Dispatch the click event on the send button
|
||||
targetElement.dispatchEvent(clickEvent);
|
||||
}
|
||||
});
|
||||
focusOnOverlayInput();
|
||||
}
|
||||
|
||||
|
||||
// Move keyboard focus to the overlay's task input textarea.
function focusOnOverlayInput() {
    const inputField = document.getElementById('agente-user-input');
    inputField.focus();
}
|
||||
|
||||
// Render one chat message into the overlay transcript.
//
// message      - raw text; if it parses as JSON, the parsed value is displayed.
//                NOTE(review): when the parsed value is an object, textContent
//                will render it as "[object Object]" — confirm this is intended.
// sender       - "system" (agent output) or "user" (typed input); any other
//                value yields an unstyled, empty-text bubble.
// message_type - refines system-message styling: "step", "plan", "question",
//                "answer", or "info".
//
// Side effects: re-enables the overlay input when the task completes or the
// agent asks a question; auto-scrolls the chat box; and, when the agent was
// awaiting a reply, forwards the user's text to the host via window.user_response.
function addMessage(message, sender, message_type = "plan") {
    let newDiv = document.createElement("div");
    newDiv.classList.add("agente-chat-input");
    let chatDiv = document.createElement("div");
    chatDiv.classList.add("agente-chat");

    let parsedMessage = message;

    // Best-effort JSON parse; plain text falls through unchanged.
    try {
        parsedMessage = JSON.parse(message);
    } catch (e) {
        console.log("Message is not in JSON format, using original message.");
    }

    // Customize based on the sender
    if (sender === "system") {
        newDiv.classList.add("agente-agent");
        chatDiv.classList.add("agente-systemMessage", "agente-pre-line");
        if (message_type === "step") {
            chatDiv.classList.add("agente-agentstep");
        }
        else if (message_type === "plan" || message_type === "question") {
            chatDiv.classList.add("agente-agentplan");
        }

        else if (message_type === "answer") {
            chatDiv.classList.add("agente-agentanswer");
        }
        // Completion notices and questions hand control back to the user.
        if ((message_type === "info" && message.includes("Task Completed")) || message_type==="question") {
            enableOverlay();
        }
        chatDiv.textContent = parsedMessage;
    } else if (sender === "user") {
        newDiv.classList.add("agente-user")
        chatDiv.classList.add("agente-usertext", "agente-pre-line");
        chatDiv.textContent = parsedMessage;
    }
    newDiv.appendChild(chatDiv);
    let chatBox = document.getElementById('agente-chat-box');
    chatBox.appendChild(newDiv);
    // Keep the newest message visible.
    chatBox.scrollTop = chatBox.scrollHeight;
    newDiv.scrollIntoView({ behavior: 'instant' });

    if (sender === "user" && awaitingUserResponse) {
        awaitingUserResponse = false;
        // Notify the server that the user has responded to the agent's prompt
        window.user_response(message);
    }

}
|
||||
|
||||
// Queue a system-originated chat message for display on the next animation
// frame, recording whether the agent is now waiting for a user reply.
function addSystemMessage(message, is_awaiting_user_response = false, message_type = "plan") {
    requestAnimationFrame(() => {
        awaitingUserResponse = is_awaiting_user_response;
        addMessage(message, "system", message_type);
    });
}
|
||||
|
||||
// Append a user-authored message to the chat transcript.
function addUserMessage(message) {
    addMessage(message, "user");
}
|
||||
|
||||
// Put the overlay input into its busy state. isDisabled() keys off this exact
// placeholder string, so it must stay in sync with the check there.
function disableOverlay() {
    const field = document.getElementById("agente-user-input");
    if (!field) {
        return;
    }
    field.placeholder = "Processing...";
}
|
||||
|
||||
// Report whether the overlay input is in its busy ("Processing...") state,
// as set by disableOverlay(). Previously this returned undefined when the
// input element was missing; it now always returns a boolean so callers
// like the send-button handler get a consistent type.
function isDisabled() {
    let input_field= document.getElementById("agente-user-input");
    if(input_field){
        return input_field.placeholder === "Processing...";
    }
    // Input not in the DOM: treat the overlay as not disabled.
    return false;
}
|
||||
|
||||
|
||||
// Restore the overlay input to its idle state by resetting the placeholder
// prompt (the string isDisabled() treats as "not busy").
function enableOverlay() {
    const field = document.getElementById("agente-user-input");
    if (field) {
        field.placeholder = "What can I help you solve today?";
    }
}
|
||||
|
||||
// Host-invoked callback fired when a dispatched command finishes.
// Currently only logs; kept as a hook for future completion handling.
function commandExecutionCompleted() {
    console.log("Command execution completed");
}
|
||||
|
||||
injectOveralyStyles();
|
11
Agent_E/ae/user_preferences/user_preferences.txt
Normal file
11
Agent_E/ae/user_preferences/user_preferences.txt
Normal file
|
@ -0,0 +1,11 @@
|
|||
Personal Info:
|
||||
First Name: John
|
||||
Last Name: Doe
|
||||
Date of birth: 10/10/2010
|
||||
Occupation: Software Engineer
|
||||
Address: 49 Featherstone Street, LONDON, EC1Y 8SY, UNITED KINGDOM
|
||||
Email: myemail@gmail.com
|
||||
Phone Number: 123-456-7890
|
||||
Here are some of my preferences:
|
||||
Favorite news source: www.bbc.com
|
||||
Favorite flight booking site to use with every flight related query: https://www.google.com/travel/flights
|
0
Agent_E/ae/utils/__init__.py
Normal file
0
Agent_E/ae/utils/__init__.py
Normal file
52
Agent_E/ae/utils/anthropic_llm_helper.py
Normal file
52
Agent_E/ae/utils/anthropic_llm_helper.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import os
|
||||
|
||||
import anthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
class AnthropicLLMHelper:
    """Thin async wrapper around the Anthropic Messages API."""

    def __init__(self):
        # Pull ANTHROPIC_API_KEY (and anything else) from a local .env file.
        load_dotenv()
        self.client = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

    async def get_chat_completion_response_async(self, system_msg:str, user_msgs:list[str], model_name:str="claude-3-opus-20240229", temperature:float=0.1, max_tokens:int=256, top_p:int=1, top_k: int=1) -> str:
        """Send one (possibly multi-part) user turn and return the text reply.

        Args:
            system_msg: System prompt, passed via the API's dedicated field.
            user_msgs: Text fragments combined into a single user message.
            model_name, temperature, max_tokens, top_p, top_k: Sampling
                controls forwarded to the API. (top_p/top_k were previously
                accepted but silently ignored; they are now forwarded.)

        Raises:
            Exception: wraps connection, rate-limit, and API-status errors
                with a message identifying the model that failed.
        """
        # The Messages API expects a list of typed content parts.
        formatted_user_msgs: list[dict[str, str]] = [{"type": "text", "text": user_msg} for user_msg in user_msgs]

        try:
            message = await self.client.messages.create(
                model=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                system=system_msg,
                messages=[
                    {
                        "role": "user",
                        "content": formatted_user_msgs # type: ignore
                    }
                ]
            )
            # Debug print of the full response removed; return the first text part.
            return message.content[0].text
        except anthropic.APIConnectionError as e:
            print("The server could not be reached")
            print(e.__cause__)  # an underlying Exception, likely raised within httpx.
            raise Exception(f"Calling {model_name} LLM failed. The server could not be reached. Error: {e}") # noqa: B904
        except anthropic.RateLimitError as e:
            print("A 429 status code was received; we should back off a bit.")
            raise Exception(f"Calling {model_name} LLM failed. Rate limit error. Error: {e}") # noqa: B904
        except anthropic.APIStatusError as e:
            print(e.status_code)
            print(e.response)
            raise Exception(f"Calling {model_name} LLM failed. Error: {e}") # noqa: B904
|
||||
|
||||
# async def main():
|
||||
# from ae.core.prompts import LLM_PROMPTS
|
||||
# helper = AnthropicLLMHelper()
|
||||
# response = await helper.get_chat_completion_response_async(LLM_PROMPTS["SKILLS_HARVESTING_PROMPT"], ["What is the weather like today?"], temperature=0, max_tokens=4000)
|
||||
# print("*******\nResponse: ", response, "\n*******\n")
|
||||
|
||||
# asyncio.run(main())
|
85
Agent_E/ae/utils/autogen_sequential_function_call.py
Normal file
85
Agent_E/ae/utils/autogen_sequential_function_call.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from autogen import Agent # type: ignore
|
||||
from autogen import UserProxyAgent # type: ignore
|
||||
|
||||
|
||||
class UserProxyAgent_SequentialFunctionExecution(UserProxyAgent):
    """UserProxyAgent variant that executes a message's tool calls strictly in
    order, and skips the rest of the batch once one call reports that the web
    page changed underneath it."""

    def __init__(self, *args, **kwargs): # type: ignore
        super().__init__(*args, **kwargs) # type: ignore
        #position = 2 allows termination check to be called earlier, this helps detect loops.
        self.register_reply(Agent, UserProxyAgent_SequentialFunctionExecution.sequential_generate_tool_calls_reply, position=2) # type: ignore

    def sequential_generate_tool_calls_reply( # type: ignore
        self,
        messages: list[dict] | None = None, # type: ignore
        sender: Agent | None = None,
        config: Any | None = None,
    ) -> tuple[bool, dict[str, Any] | None]:
        """Generate a reply using tool call.

        Executes the tool calls attached to the last message one at a time.
        Once a call's response contains "as a consequence of this action"
        (i.e. the page changed), the remaining calls are not executed and are
        answered with a canned warning instead.
        """
        if config is None:
            config = self
        if messages is None:
            messages = self._oai_messages[sender] # type: ignore
        message = messages[-1] # type: ignore
        tool_returns = []
        # Set once a tool response indicates the page changed; suppresses
        # execution of every subsequent call in this batch.
        skip_flag:bool = False
        for tool_call in message.get("tool_calls", []): # type: ignore
            function_call = tool_call.get("function", {}) # type: ignore
            func = self._function_map.get(function_call.get("name", None), None) # type: ignore
            func_return = None
            if inspect.iscoroutinefunction(func): # type: ignore
                try:
                    # get the running loop if it was already created
                    loop = asyncio.get_running_loop()
                    close_loop = False
                except RuntimeError:
                    # create a loop if there is no running loop
                    loop = asyncio.new_event_loop()
                    close_loop = True
                if (not skip_flag):
                    _, func_return = loop.run_until_complete(self.a_execute_function(function_call)) # type: ignore
                    # NOTE(review): the private loop is closed only when the
                    # call actually ran — confirm this nesting is intended.
                    if close_loop:
                        loop.close()
            else:
                if (not skip_flag):
                    _, func_return = self.execute_function(function_call) # type: ignore
            if func_return is None: # type: ignore
                if skip_flag:
                    # Call was skipped: tell the LLM why and what to do next.
                    content = "VERY IMPORTANT: This function could not be executed since previous function resulted in a Webpage change. You must get all_fields DOM and repeat the function if needed."
                else:
                    content = ""
            else:
                content = func_return.get("content", "") # type: ignore

            if content is None:
                content = ""

            # Skill responses embed this phrase when the action navigated or
            # mutated the page; everything after it would act on stale state.
            if ("as a consequence of this action" in content.lower()): # type: ignore
                skip_flag = True

            tool_call_id = tool_call.get("id", None) # type: ignore
            if tool_call_id is not None:
                tool_call_response = { # type: ignore
                    "tool_call_id": tool_call_id,
                    "role": "tool",
                    "content": content,
                }
            else:
                tool_call_response = { # type: ignore
                    "role": "tool",
                    "content": content,
                }
            tool_returns.append(tool_call_response) # type: ignore

        if tool_returns:
            return True, {
                "role": "tool",
                "tool_responses": tool_returns,
                "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]), # type: ignore
            }
        # No tool calls in the message: defer to the next registered reply.
        return False, None
|
34
Agent_E/ae/utils/cli_helper.py
Normal file
34
Agent_E/ae/utils/cli_helper.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
import asyncio
|
||||
from asyncio import Future
|
||||
|
||||
|
||||
def async_input(prompt: str) -> Future: # type: ignore
    """
    Prompt the user and collect their reply without blocking the event loop.

    The blocking built-in ``input`` call is delegated to the default
    thread-pool executor, so awaiting the returned future yields control
    until the user submits a line.

    Parameters:
    - prompt: The message to display to the user.

    Returns:
    - A Future object that will be fulfilled with the user's input.
    """
    return asyncio.get_event_loop().run_in_executor(None, input, prompt)
|
||||
|
||||
|
||||
async def answer_questions_over_cli(questions: list[str]) -> dict[str, str]:
    """
    Ask each question over the command line and await the user's responses.

    Parameters:
    - questions: A list of questions to ask the user, e.g., ["What is your favorite site?", "What do you want to search for?"].

    Returns:
    - A dictionary mapping each question to the user's typed response.
    """
    print("*********************************")
    # Questions are asked strictly in order, one at a time.
    answers: dict[str, str] = {
        question: await async_input("Question: " + str(question) + " : ")
        for question in questions
    }
    print("*********************************")
    return answers
|
46
Agent_E/ae/utils/detect_llm_loops.py
Normal file
46
Agent_E/ae/utils/detect_llm_loops.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
from typing import Any
|
||||
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
def is_agent_stuck_in_loop(messages: list[dict[str, Any]]) -> bool:
    """
    Detects loops in the agent's responses by iterating over the last N responses.

    A loop is reported when, within the most recent window of turns, every
    assistant tool call is identical AND every tool response is identical.

    Parameters
    ----------
    messages : list[dict[str, Any]]
        A list of dictionaries representing the agent's messages.

    Returns
    -------
    bool
        True if a loop is detected, False otherwise.
    """
    number_of_turns_to_check_for_loops: int = 6
    number_of_rounds_to_check_for_loops: int = number_of_turns_to_check_for_loops // 2 #integer division since we are checking for pairs of messages and can't have fractions

    # Too little history to judge.
    if len(messages) <= number_of_turns_to_check_for_loops:
        return False

    recent = messages[-number_of_turns_to_check_for_loops:]
    logger.debug(f"More than {number_of_turns_to_check_for_loops} messages in the conversation. Checking for loops..")

    # Assistant messages carry the tool calls.
    tool_calls = [item for item in recent if item.get("role") == "assistant"]
    if not tool_calls:
        return False

    tool_functions = [item.get("tool_calls", [{}])[0].get("function") for item in tool_calls]
    logger.debug(f"Last {number_of_rounds_to_check_for_loops} tool calls: {tool_functions}")
    if any(func != tool_functions[0] for func in tool_functions):
        return False

    logger.debug(f"Last {number_of_rounds_to_check_for_loops} tool calls are identical. Checking Tool responses..")
    tool_responses = [item for item in recent if item.get("role") == "tool"]
    if not tool_responses:
        return False

    assistant_contents = [item.get("content") for item in tool_responses]
    logger.debug(f"Last N tool responses: {assistant_contents}")
    if all(content == assistant_contents[0] for content in assistant_contents):
        logger.debug(f"Last {number_of_rounds_to_check_for_loops} tool responses are identical. Terminating")
        logger.info("Terminating browser executor since a loop was detected...")
        return True

    return False
|
45
Agent_E/ae/utils/dom_helper.py
Normal file
45
Agent_E/ae/utils/dom_helper.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
import asyncio
|
||||
|
||||
from playwright.async_api import ElementHandle
|
||||
from playwright.async_api import Page
|
||||
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
async def wait_for_non_loading_dom_state(page: Page, max_wait_millis: int):
    """Poll the page's document.readyState until it leaves 'loading'.

    Polls every 50 ms and gives up silently once max_wait_millis
    milliseconds have elapsed.
    """
    deadline = asyncio.get_event_loop().time() + max_wait_millis / 1000
    while asyncio.get_event_loop().time() < deadline:
        dom_state = await page.evaluate("document.readyState")
        if dom_state != "loading":
            logger.debug(f"DOM state is not 'loading': {dom_state}")
            break  # Exit the loop if the DOM state is not 'loading'
        await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
async def get_element_outer_html(element: ElementHandle, page: Page, element_tag_name: str|None = None) -> str:
    """
    Constructs the opening tag of an HTML element along with its attributes.

    Args:
        element (ElementHandle): The element to retrieve the opening tag for.
        page (Page): The page object associated with the element.
        element_tag_name (str, optional): The tag name of the element. Defaults to None. If not passed, it will be retrieved from the element.

    Returns:
        str: The opening tag of the HTML element, including a select set of attributes.
    """
    if element_tag_name:
        tag_name = element_tag_name
    else:
        tag_name = await page.evaluate("element => element.tagName.toLowerCase()", element)

    # Only surface attributes that help identify or interact with the element.
    attributes_of_interest: list[str] = ['id', 'name', 'aria-label', 'placeholder', 'href', 'src', 'aria-autocomplete', 'role', 'type',
                                         'data-testid', 'value', 'selected', 'aria-labelledby', 'aria-describedby', 'aria-haspopup']

    pieces: list[str] = [f'<{tag_name}']
    for attr in attributes_of_interest:
        attr_value: str = await element.get_attribute(attr) # type: ignore
        if attr_value:
            pieces.append(f' {attr}="{attr_value}"')
    pieces.append('>')

    return "".join(pieces)
|
88
Agent_E/ae/utils/dom_mutation_observer.py
Normal file
88
Agent_E/ae/utils/dom_mutation_observer.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Callable # noqa: UP035
|
||||
|
||||
from playwright.async_api import Page
|
||||
|
||||
# Create an event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Registry of subscriber callbacks invoked on every detected DOM mutation.
DOM_change_callback: list[Callable[[str], None]] = []

def subscribe(callback: Callable[[str], None]) -> None:
    """Register a callback to be invoked on every detected DOM mutation."""
    DOM_change_callback.append(callback)

def unsubscribe(callback: Callable[[str], None]) -> None:
    """Remove a previously registered DOM-mutation callback."""
    DOM_change_callback.remove(callback)
|
||||
|
||||
|
||||
async def add_mutation_observer(page:Page):
    """
    Adds a mutation observer to the page to detect changes in the DOM.
    When changes are detected, the observer calls the dom_mutation_change_detected function in the browser context.
    These changes can be detected by subscribing to the dom_mutation_change_detected function by individual skills.

    Current implementation only detects when a new node is added to the DOM.
    However, in many cases, the change could be a change in the style or class of an existing node (e.g. toggle visibility of a hidden node).
    """

    # The observer runs in the page's JS context; it filters out script/style
    # nodes and anything inside the agent's own overlay, then forwards the
    # remaining changes as a JSON string to the Python-exposed binding.
    await page.evaluate("""
        console.log('Adding a mutation observer for DOM changes');
        new MutationObserver((mutationsList, observer) => {
            let changes_detected = [];
            for(let mutation of mutationsList) {
                if (mutation.type === 'childList') {
                    let allAddedNodes=mutation.addedNodes;
                    for(let node of allAddedNodes) {
                        if(node.tagName && !['SCRIPT', 'NOSCRIPT', 'STYLE'].includes(node.tagName) && !node.closest('#agentDriveAutoOverlay')) {
                            let visibility=true;
                            let content = node.innerText.trim();
                            if(visibility && node.innerText.trim()){
                                if(content) {
                                    changes_detected.push({tag: node.tagName, content: content});
                                }
                            }
                        }
                    }
                } else if (mutation.type === 'characterData') {
                    let node = mutation.target;
                    if(node.parentNode && !['SCRIPT', 'NOSCRIPT', 'STYLE'].includes(node.parentNode.tagName) && !node.parentNode.closest('#agentDriveAutoOverlay')) {
                        let visibility=true;
                        let content = node.data.trim();
                        if(visibility && content && window.getComputedStyle(node.parentNode).display !== 'none'){
                            if(content && !changes_detected.some(change => change.content.includes(content))) {
                                changes_detected.push({tag: node.parentNode.tagName, content: content});
                            }
                        }
                    }
                }
            }
            if(changes_detected.length > 0) {
                window.dom_mutation_change_detected(JSON.stringify(changes_detected));
            }
        }).observe(document, {subtree: true, childList: true, characterData: true});
    """)
|
||||
|
||||
|
||||
async def handle_navigation_for_mutation_observer(page:Page):
    # Navigation wipes the page's JS context, so the observer must be re-armed.
    await add_mutation_observer(page)
|
||||
|
||||
async def dom_mutation_change_detected(changes_detected: str):
    """
    Detects changes in the DOM (new nodes added) and emits the event to all subscribed callbacks.
    The changes_detected is a string in JSON format containing the tag and content of the new nodes added to the DOM.

    e.g. The following will be detected when autocomplete recommendations show up when one types Nelson Mandela on google search
    [{'tag': 'SPAN', 'content': 'nelson mandela wikipedia'}, {'tag': 'SPAN', 'content': 'nelson mandela movies'}]
    """
    # Strip literal tabs/newlines before parsing; the payload arrives as a
    # single JSON string serialized in the browser context.
    changes = json.loads(changes_detected.replace('\t', '').replace('\n', ''))
    if len(changes) == 0:
        return
    # Emit the event to every subscriber, supporting both async and sync callbacks.
    for callback in DOM_change_callback:
        if asyncio.iscoroutinefunction(callback):
            await callback(changes)
        else:
            callback(changes)
|
56
Agent_E/ae/utils/formatting_helper.py
Normal file
56
Agent_E/ae/utils/formatting_helper.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def str_to_bool(s: str | bool) -> bool:
|
||||
"""
|
||||
Convert a string representation of truth to True or False.
|
||||
|
||||
Parameters:
|
||||
s (str | bool): The string to convert, or a boolean.
|
||||
|
||||
Returns:
|
||||
bool: True if the string represents a truth value, False otherwise.
|
||||
"""
|
||||
if isinstance(s, bool):
|
||||
return s
|
||||
return s.lower() in ['true', '1', 't', 'y', 'yes']
|
||||
|
||||
def str_to_json(s: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Convert a string representation of a JSON object to a dictionary.
|
||||
|
||||
Parameters:
|
||||
s (str): The string to convert.
|
||||
|
||||
Returns:
|
||||
dict[str, Any] | None: The dictionary representation of the JSON object. If the parsing fails, returns None.
|
||||
"""
|
||||
s_fixed = re.sub(r'(?<!\\)\n', '\\n', s) #escape newline characters as long as they are not already escaped
|
||||
|
||||
# Now you can safely load it using json.loads
|
||||
try:
|
||||
obj = json.loads(s_fixed)
|
||||
return obj
|
||||
except json.JSONDecodeError as e:
|
||||
return None
|
||||
|
||||
def is_terminating_message(message: str) -> bool:
    """
    Check if a message is a terminating message.

    Parameters:
    message (str): The message to check.

    Returns:
    bool: True if the message is a terminating message, False otherwise.
    """
    parsed = str_to_json(message)
    if parsed is not None:
        return parsed.get("terminate") == "yes"
    # Not valid JSON: fall back to a literal substring check.
    return '"terminate": "yes"' in message
|
77
Agent_E/ae/utils/gemini_llm_helper.py
Normal file
77
Agent_E/ae/utils/gemini_llm_helper.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import google.generativeai as genai # type: ignore
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
GCP_BLOCK_NONE_SAFETY_SETTINGS: list[dict[str, str]] = [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_NONE"
|
||||
},
|
||||
]
|
||||
|
||||
# Pre-compile the regular expression pattern for removing json markers from LLM response
|
||||
llm_json_or_python_begin_response_pattern = re.compile(r"^```(python|json)?\n?")
|
||||
llm_end_response_pattern = re.compile(r"```$")
|
||||
|
||||
class GeminiLLMHelper:
    """Async helper around Google's Gemini generative AI chat completions."""

    def __init__(self):
        # Pull GEMINI_API_KEY from a local .env file and configure the SDK.
        load_dotenv()
        genai.configure(api_key=os.environ.get("GEMINI_API_KEY")) # type: ignore

    def process_llm_response(self, response: str):
        """Strip markdown code fences (```json / ```python ... ```) that the
        model sometimes wraps around its output."""
        if response:
            # Use the compiled regex to replace the patterns with an empty string
            response = llm_json_or_python_begin_response_pattern.sub("", response)
            response = llm_end_response_pattern.sub("", response)
        return response

    async def get_chat_completion_response_async(self, system_msg:str, user_msgs:list[str], model_name:str="gemini-1.5-pro-latest", temperature:float=0.1,
                                                 max_tokens:int=256, top_p:int=1, top_k: int=1, safety_settings:list[dict[str, str]]=GCP_BLOCK_NONE_SAFETY_SETTINGS) -> str|None:
        """Send a prompt to Gemini and return the fence-stripped text reply.

        The system message is sent as a leading "user" turn because this call
        does not use a dedicated system role. Returns None when the response
        contains no text (e.g. blocked by safety settings) or when no
        response was received at all.
        """
        formatted_msgs: list[dict[str, Any]] = [{"role": "user", "parts": [system_msg]}]
        user_msgs_parts: list[str] = []
        for user_msg in user_msgs:
            user_msgs_parts.append(user_msg)

        # All user fragments are combined into a single multi-part turn.
        formatted_msgs.append({"role": "user", "parts": user_msgs_parts})
        response = None
        try:
            model = genai.GenerativeModel(model_name)
            response = model.generate_content(formatted_msgs, stream=False, # type: ignore
                                              generation_config=genai.types.GenerationConfig(
                                                  max_output_tokens=max_tokens,
                                                  temperature=temperature,
                                                  top_p=top_p,
                                                  top_k=top_k),
                                              safety_settings=safety_settings)
            return self.process_llm_response(response.text)
        except ValueError:
            # Accessing response.text raises ValueError when the candidate
            # holds no text; distinguish that from receiving no response.
            if response:
                logger.error(f"Response from GCP Gen AI did not contain text. prompt: {system_msg} and user messages: {user_msgs}. Candidates: {response.candidates}")
            else:
                logger.error(f"There was no response from GCP Gen AI for prompt: {system_msg} and user messages: {user_msgs}")
            return None
|
||||
|
||||
# async def main():
|
||||
# from Agent_E.ae.core.prompts import LLM_PROMPTS
|
||||
# helper = GeminiLLMHelper()
|
||||
# response = await helper.get_chat_completion_response_async(LLM_PROMPTS["SKILLS_HARVESTING_PROMPT"], ["What is the weather like today?", "And How are you?"], temperature=0, max_tokens=4000)
|
||||
# print("*******\nResponse: ", response, "\n*******\n")
|
||||
|
||||
# asyncio.run(main())
|
529
Agent_E/ae/utils/get_detailed_accessibility_tree.py
Normal file
529
Agent_E/ae/utils/get_detailed_accessibility_tree.py
Normal file
|
@ -0,0 +1,529 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
|
||||
from playwright.async_api import Page
|
||||
|
||||
from Agent_E.ae.config import SOURCE_LOG_FOLDER_PATH
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
# Pattern for mmid values rendered as runs of digits and spaces (e.g. "1 23 4").
space_delimited_mmid = re.compile(r'^[\d ]+$')


def is_space_delimited_mmid(s: str) -> bool:
    """
    Check if the given string matches the mmid pattern of number space repeated.

    Parameters:
    - s (str): The string to check against the pattern.

    Returns:
    - bool: True if the string matches the pattern, False otherwise.
    """
    # fullmatch ensures the whole string, not just a prefix, fits the pattern.
    match = space_delimited_mmid.fullmatch(s)
    return match is not None
|
||||
|
||||
|
||||
async def __inject_attributes(page: Page):
    """
    Injects 'mmid' and 'aria-keyshortcuts' into all DOM elements. If an element already has an 'aria-keyshortcuts',
    it renames it to 'orig-aria-keyshortcuts' before injecting the new 'aria-keyshortcuts'
    This will be captured in the accessibility tree and thus make it easier to reconcile the tree with the DOM.
    'aria-keyshortcuts' is chosen because it is not a widely used aria attribute.
    """

    # The script returns the highest id assigned, i.e. how many elements were tagged.
    last_mmid = await page.evaluate("""() => {
        const allElements = document.querySelectorAll('*');
        let id = 0;
        allElements.forEach(element => {
            const origAriaAttribute = element.getAttribute('aria-keyshortcuts');
            const mmid = `${++id}`;
            element.setAttribute('mmid', mmid);
            element.setAttribute('aria-keyshortcuts', mmid);
            //console.log(`Injected 'mmid'into element with tag: ${element.tagName} and mmid: ${mmid}`);
            if (origAriaAttribute) {
                element.setAttribute('orig-aria-keyshortcuts', origAriaAttribute);
            }
        });
        return id;
    }""")
    logger.debug(f"Added MMID into {last_mmid} elements")
|
||||
|
||||
|
||||
async def __fetch_dom_info(page: Page, accessibility_tree: dict[str, Any], only_input_fields: bool):
    """
    Iterates over the accessibility tree, fetching additional information from the DOM based on 'mmid',
    and constructs a new JSON structure with detailed information.

    Args:
        page (Page): The page object representing the web page.
        accessibility_tree (dict[str, Any]): The accessibility tree JSON structure.
        only_input_fields (bool): Flag indicating whether to include only input fields in the new JSON structure.

    Returns:
        dict[str, Any]: The pruned tree with detailed information from the DOM.
    """

    logger.debug("Reconciling the Accessibility Tree with the DOM")
    # Define the attributes to fetch for each element
    attributes = ['name', 'aria-label', 'placeholder', 'mmid', "id", "for", "data-testid"]
    backup_attributes = []  # if the attributes are not found, then try to get these attributes
    tags_to_ignore = ['head', 'style', 'script', 'link', 'meta', 'noscript', 'template', 'iframe', 'g', 'main', 'c-wiz', 'svg', 'path']
    attributes_to_delete = ["level", "multiline", "haspopup", "id", "for"]
    ids_to_ignore = ['agentDriveAutoOverlay']

    # Recursive function to process each node in the accessibility tree.
    # Children are processed before the node itself (post-order), so a parent can
    # rely on its children having been enriched/marked already.
    async def process_node(node: dict[str, Any]):
        if 'children' in node:
            for child in node['children']:
                await process_node(child)

        # Use 'keyshortcuts' from the accessibility node as the 'mmid' — this is the
        # value __inject_attributes mirrored into aria-keyshortcuts.
        mmid_temp: str = node.get('keyshortcuts')  # type: ignore

        # If the name has multiple mmids, take the last one
        if (mmid_temp and is_space_delimited_mmid(mmid_temp)):
            # TODO: consider if we should grab each of the mmids and process them separately as seperate nodes copying this node's attributes
            mmid_temp = mmid_temp.split(' ')[-1]

        # focusing on nodes with mmid, which is the attribute we inject
        try:
            mmid = int(mmid_temp)
        except (ValueError, TypeError):
            # Nodes without a numeric mmid are left untouched; the bare name is returned.
            # logger.error(f"'name attribute contains \"{node.get('name')}\", which is not a valid numeric mmid. Adding node as is: {node}")
            return node.get('name')

        # Menu items are kept as-is (no DOM enrichment).
        if node['role'] == 'menuitem':
            return node.get('name')

        if node.get('role') == 'dialog' and node.get('modal') == True:  # noqa: E712
            node["important information"] = "This is a modal dialog. Please interact with this dialog and close it to be able to interact with the full page (e.g. by pressing the close button or selecting an option)."

        if mmid:
            # Determine if we need to fetch 'innerText' based on the absence of 'children' in the accessibility node
            should_fetch_inner_text = 'children' not in node

            # In-page JS: given an mmid, return a dict of interesting attributes for that
            # element, or null to signal that the node should be dropped.
            js_code = """
            (input_params) => {
                const should_fetch_inner_text = input_params.should_fetch_inner_text;
                const mmid = input_params.mmid;
                const attributes = input_params.attributes;
                const tags_to_ignore = input_params.tags_to_ignore;
                const ids_to_ignore = input_params.ids_to_ignore;

                const element = document.querySelector(`[mmid="${mmid}"]`);

                if (!element) {
                    console.log(`No element found with mmid: ${mmid}`);
                    return null;
                }

                if (ids_to_ignore.includes(element.id)) {
                    console.log(`Ignoring element with id: ${element.id}`, element);
                    return null;
                }
                //Ignore "option" because it would have been processed with the select element
                if (tags_to_ignore.includes(element.tagName.toLowerCase()) || element.tagName.toLowerCase() === "option") return null;

                let attributes_to_values = {
                    'tag': element.tagName.toLowerCase() // Always include the tag name
                };

                // If the element is an input, include its type as well
                if (element.tagName.toLowerCase() === 'input') {
                    attributes_to_values['tag_type'] = element.type; // This will capture 'checkbox', 'radio', etc.
                }
                else if (element.tagName.toLowerCase() === 'select') {
                    attributes_to_values["mmid"] = element.getAttribute('mmid');
                    attributes_to_values["role"] = "combobox";
                    attributes_to_values["options"] = [];

                    for (const option of element.options) {
                        let option_attributes_to_values = {
                            "mmid": option.getAttribute('mmid'),
                            "text": option.text,
                            "value": option.value,
                            "selected": option.selected
                        };
                        attributes_to_values["options"].push(option_attributes_to_values);
                    }
                    return attributes_to_values;
                }

                for (const attribute of attributes) {
                    let value = element.getAttribute(attribute);

                    if(value){
                        /*
                        if(attribute === 'href'){
                            value = value.split('?')[0]
                        }
                        */
                        attributes_to_values[attribute] = value;
                    }
                }

                if (should_fetch_inner_text && element.innerText) {
                    attributes_to_values['description'] = element.innerText;
                }

                let role = element.getAttribute('role');
                if(role==='listbox' || element.tagName.toLowerCase()=== 'ul'){
                    let children=element.children;
                    let filtered_children = Array.from(children).filter(child => child.getAttribute('role') === 'option');
                    console.log("Listbox or ul found: ", filtered_children);
                    let attributes_to_include = ['mmid', 'role', 'aria-label','value'];
                    attributes_to_values["additional_info"]=[]
                    for (const child of children) {
                        let children_attributes_to_values = {};

                        for (let attr of child.attributes) {
                            // If the attribute is not in the predefined list, add it to children_attributes_to_values
                            if (attributes_to_include.includes(attr.name)) {
                                children_attributes_to_values[attr.name] = attr.value;
                            }
                        }

                        attributes_to_values["additional_info"].push(children_attributes_to_values);
                    }
                }
                // Check if attributes_to_values contains more than just 'name', 'role', and 'mmid'
                const keys = Object.keys(attributes_to_values);
                const minimalKeys = ['tag', 'mmid'];
                const hasMoreThanMinimalKeys = keys.length > minimalKeys.length || keys.some(key => !minimalKeys.includes(key));

                if (!hasMoreThanMinimalKeys) {
                    //If there were no attributes found, then try to get the backup attributes
                    for (const backupAttribute of input_params.backup_attributes) {
                        let value = element.getAttribute(backupAttribute);
                        if(value){
                            attributes_to_values[backupAttribute] = value;
                        }
                    }

                    //if even the backup attributes are not found, then return null, which will cause this element to be skipped
                    if(Object.keys(attributes_to_values).length <= minimalKeys.length) {
                        if (element.tagName.toLowerCase() === 'button') {
                            attributes_to_values["mmid"] = element.getAttribute('mmid');
                            attributes_to_values["role"] = "button";
                            attributes_to_values["additional_info"] = [];
                            let children=element.children;
                            let attributes_to_exclude = ['width', 'height', 'path', 'class', 'viewBox', 'mmid']

                            // Check if the button has no text and no attributes
                            if (element.innerText.trim() === '') {

                                for (const child of children) {
                                    let children_attributes_to_values = {};

                                    for (let attr of child.attributes) {
                                        // If the attribute is not in the predefined list, add it to children_attributes_to_values
                                        if (!attributes_to_exclude.includes(attr.name)) {
                                            children_attributes_to_values[attr.name] = attr.value;
                                        }
                                    }

                                    attributes_to_values["additional_info"].push(children_attributes_to_values);
                                }
                                console.log("Button with no text and no attributes: ", attributes_to_values);
                                return attributes_to_values;
                            }
                        }

                        return null; // Return null if only minimal keys are present
                    }
                }
                return attributes_to_values;
            }
            """

            # Fetch attributes and possibly 'innerText' from the DOM element by 'mmid'
            element_attributes = await page.evaluate(js_code,
                                                     {"mmid": mmid, "attributes": attributes, "backup_attributes": backup_attributes,
                                                      "should_fetch_inner_text": should_fetch_inner_text,
                                                      "tags_to_ignore": tags_to_ignore,
                                                      "ids_to_ignore": ids_to_ignore})

            if 'keyshortcuts' in node:
                del node['keyshortcuts']  # remove keyshortcuts since it is not needed

            node["mmid"] = mmid

            # Update the node with fetched information
            if element_attributes:
                node.update(element_attributes)

                # check if 'name' and 'mmid' are the same
                if node.get('name') == node.get('mmid') and node.get('role') != "textbox":
                    del node['name']  # Remove 'name' from the node

                if 'name' in node and 'description' in node and (node['name'] == node['description'] or node['name'] == node['description'].replace('\n', ' ') or node['description'].replace('\n', '') in node['name']):
                    del node['description']  # if the name is same as description, then remove the description to avoid duplication

                if 'name' in node and 'aria-label' in node and node['aria-label'] in node['name']:
                    del node['aria-label']  # if the name is same as the aria-label, then remove the aria-label to avoid duplication

                if 'name' in node and 'text' in node and node['name'] == node['text']:
                    del node['text']  # if the name is same as the text, then remove the text to avoid duplication

                if node.get('tag') == "select":  # children are not needed for select menus since "options" attriburte is already added
                    node.pop("children", None)
                    node.pop("role", None)
                    node.pop("description", None)

                # role and tag can have the same info. Get rid of role if it is the same as tag
                if node.get('role') == node.get('tag'):
                    del node['role']

                # avoid duplicate aria-label
                if node.get("aria-label") and node.get("placeholder") and node.get("aria-label") == node.get("placeholder"):
                    del node["aria-label"]

                if node.get("role") == "link":
                    del node["role"]
                    if node.get("description"):
                        node["text"] = node["description"]
                        del node["description"]

                # textbox just means a text input and that is expressed well enough with the rest of the attributes returned
                # if node.get('role') == "textbox":
                #    del node['role']

                if node.get('role') == "textbox":
                    # get the id attribute of this field from the DOM
                    if "id" in element_attributes and element_attributes["id"]:
                        # find if there is an element in the DOM that has this id in aria-labelledby.
                        # NOTE(review): this js_code is built but never passed to page.evaluate —
                        # it looks like leftover/incomplete aria-labelledby resolution; confirm intent.
                        js_code = """
                        (inputParams) => {
                            let referencingElements = [];
                            const referencedElement = document.querySelector(`[aria-labelledby="${inputParams.aria_labelled_by_query_value}"]`);
                            if(referencedElement) {
                                const mmid = referencedElement.getAttribute('mmid');
                                if (mmid) {
                                    return {"mmid": mmid, "tag": referencedElement.tagName.toLowerCase()};
                                }
                            }
                            return null;
                        }
                        """
                # textbox just means a text input and that is expressed well enough with the rest of the attributes returned
                # del node['role']

                # remove attributes that are not needed once processing of a node is complete
                for attribute_to_delete in attributes_to_delete:
                    if attribute_to_delete in node:
                        node.pop(attribute_to_delete, None)
            else:
                logger.debug(f"No element found with mmid: {mmid}, deleting node: {node}")
                node["marked_for_deletion_by_mm"] = True

    # Process each node in the tree starting from the root
    await process_node(accessibility_tree)

    # Drop nodes flagged above (and other prune conditions) in a second pass.
    pruned_tree = __prune_tree(accessibility_tree, only_input_fields)

    logger.debug("Reconciliation complete")
    return pruned_tree
|
||||
|
||||
|
||||
async def __cleanup_dom(page: Page) -> None:
    """
    Cleans up the DOM by removing injected 'aria-keyshortcuts' attributes and restoring any original
    'aria-keyshortcuts' from 'orig-aria-keyshortcuts'.

    Note: the injected 'mmid' attribute itself is intentionally left in place — it is the
    selector used here and by later DOM lookups.

    Args:
        page (Page): The Playwright page previously tagged by __inject_attributes.
    """
    logger.debug("Cleaning up the DOM's previous injections")
    await page.evaluate("""() => {
        const allElements = document.querySelectorAll('*[mmid]');
        allElements.forEach(element => {
            element.removeAttribute('aria-keyshortcuts');
            const origAriaLabel = element.getAttribute('orig-aria-keyshortcuts');
            if (origAriaLabel) {
                element.setAttribute('aria-keyshortcuts', origAriaLabel);
                element.removeAttribute('orig-aria-keyshortcuts');
            }
        });
    }""")
    logger.debug("DOM cleanup complete")
|
||||
|
||||
|
||||
def __prune_tree(node: dict[str, Any], only_input_fields: bool) -> dict[str, Any] | None:
    """
    Recursively prunes a tree starting from `node`, based on pruning conditions and handling of 'unraveling'.

    The function has two main jobs:
    1. Pruning: Remove nodes that don't meet certain conditions, like being marked for deletion.
    2. Unraveling: For nodes marked with 'marked_for_unravel_children', we replace them with their children,
       effectively removing the node and lifting its children up a level in the tree.

    This happens in place, meaning we modify the tree as we go, which is efficient but means you should
    be cautious about modifying the tree outside this function during a prune operation.

    Args:
    - node (Dict[str, Any]): The node we're currently looking at. We'll check this node, its children,
      and so on, recursively down the tree.
    - only_input_fields (bool): If True, we're only interested in pruning input-related nodes (like form fields).
      This lets you narrow the focus if, for example, you're only interested in cleaning up form-related parts
      of a larger tree.

    Returns:
    - dict[str, Any] | None: The pruned version of `node`, or None if `node` was pruned away. When we 'unravel'
      a node, we directly replace it with its children in the parent's list of children, so the return value
      will be the parent, updated in place.

    Notes:
    - 'marked_for_deletion_by_mm' is our flag for nodes that should definitely be removed.
    - Unraveling is neat for flattening the tree when a node is just a wrapper without semantic meaning.
    - We use a while loop with manual index management to safely modify the list of children as we iterate over it.
    """
    if "marked_for_deletion_by_mm" in node:
        return None

    if 'children' in node:
        i = 0
        while i < len(node['children']):
            child = node['children'][i]
            if 'marked_for_unravel_children' in child:
                # Replace the current child with its children
                if 'children' in child:
                    node['children'] = node['children'][:i] + child['children'] + node['children'][i+1:]
                    i += len(child['children']) - 1  # Adjust the index for the new children
                    # NOTE(review): the spliced-in children are skipped by this pass
                    # (not recursively pruned). 'marked_for_unravel_children' is not set
                    # anywhere in this module, so this path looks dormant — confirm.
                else:
                    # If the node marked for unraveling has no children, remove it
                    node['children'].pop(i)
                    i -= 1  # Adjust the index since we removed an element
            else:
                # Recursively prune the child if it's not marked for unraveling
                pruned_child = __prune_tree(child, only_input_fields)
                if pruned_child is None:
                    # If the child is pruned, remove it from the children list
                    node['children'].pop(i)
                    i -= 1  # Adjust the index since we removed an element
                else:
                    # Update the child with the pruned version
                    node['children'][i] = pruned_child
            i += 1  # Move to the next child

        # After processing all children, if the children array is empty, remove it
        if not node['children']:
            del node['children']

    # Apply existing conditions to decide if the current node should be pruned
    return None if __should_prune_node(node, only_input_fields) else node
|
||||
|
||||
|
||||
def __should_prune_node(node: dict[str, Any], only_input_fields: bool):
    """
    Determines if a node should be pruned based on its 'role' and other attributes.

    Args:
        node (dict[str, Any]): The node to be evaluated.
        only_input_fields (bool): Flag indicating whether only input fields should be considered.

    Returns:
        bool: True if the node should be pruned, False otherwise.
    """
    role = node.get('role')

    # In input-only mode, keep the WebArea root plus anything input-like
    # (by tag or by button role); everything else is pruned.
    input_like = node.get("tag") in ("input", "button", "textarea") or role == "button"
    if only_input_fields and role != "WebArea" and not input_like:
        return True

    # A childless, nameless 'generic' wrapper carries no information.
    # (The 'children' key may already have been deleted by the caller.)
    if role == 'generic' and 'children' not in node and not node.get('name'):
        return True

    # Purely visual roles are never useful.
    if role in ('separator', 'LineBreak'):
        return True

    # Normalize the name: strip commas, colons, newlines and surrounding
    # whitespace; anything shorter than 3 characters is treated as empty.
    cleaned_name = ""
    if 'name' in node:
        raw_name: str = node.get('name')  # type: ignore
        cleaned_name = raw_name.replace(',', '').replace(':', '').replace('\n', '').strip()
        if len(cleaned_name) < 3:
            cleaned_name = ""

    # A node holding nothing but {name, role} is noise — unless it is a text
    # node whose cleaned name is still meaningful.
    if len(node) == 2 and 'name' in node and 'role' in node:
        if not (role == "text" and cleaned_name != ""):
            return True

    return False
|
||||
|
||||
async def get_node_dom_element(page: Page, mmid: str):
    """
    Evaluate a DOM lookup for the element whose injected 'mmid' attribute equals *mmid*.

    NOTE(review): page.evaluate serializes its return value and DOM nodes are not
    JSON-serializable — presumably callers use this for probing/debugging only;
    verify what Playwright returns here before relying on the value.
    """
    return await page.evaluate("""
        (mmid) => {
            return document.querySelector(`[mmid="${mmid}"]`);
        }
    """, mmid)
|
||||
|
||||
|
||||
async def get_element_attributes(page: Page, mmid: str, attributes: list[str]):
    """
    Fetch the requested attribute values from the element with the given 'mmid'.

    Args:
        page (Page): The Playwright page to query.
        mmid (str): The injected mmid identifying the element.
        attributes (list[str]): Attribute names to read.

    Returns:
        dict | None: Mapping of attribute name -> value (null/None for attributes the
        element lacks), or None if no element with that mmid exists.
    """
    return await page.evaluate("""
        (inputParams) => {
            const mmid = inputParams.mmid;
            const attributes = inputParams.attributes;
            const element = document.querySelector(`[mmid="${mmid}"]`);
            if (!element) return null; // Return null if element is not found

            let attrs = {};
            for (let attr of attributes) {
                attrs[attr] = element.getAttribute(attr);
            }
            return attrs;
        }
    """, {"mmid": mmid, "attributes": attributes})
|
||||
|
||||
|
||||
async def get_dom_with_accessibility_info() -> Annotated[dict[str, Any] | None, "A minified representation of the HTML DOM for the current webpage"]:
    """
    Retrieves, processes, and minifies the Accessibility tree of the active page in a browser instance.
    Strictly follow the name and role tag for any interaction with the nodes.

    Returns:
    - The minified JSON content of the browser's active page.
    """
    logger.debug("Executing Get Accessibility Tree Command")
    # Create and use the PlaywrightManager
    # NOTE(review): instantiating PlaywrightManager here appears to attach to the
    # already-running browser rather than launching a new one — confirm against
    # PlaywrightManager's implementation (it is presumably a singleton).
    browser_manager = PlaywrightManager(browser_type='chromium', headless=False)
    page = await browser_manager.get_current_page()
    if page is None:  # type: ignore
        raise ValueError('No active page found')

    return await do_get_accessibility_info(page)
|
||||
|
||||
|
||||
async def do_get_accessibility_info(page: Page, only_input_fields: bool = False):
    """
    Retrieves the accessibility information of a web page and saves it as JSON files.

    Args:
        page (Page): The page object representing the web page.
        only_input_fields (bool, optional): If True, only retrieves accessibility information for input fields.
            Defaults to False.

    Returns:
        dict[str, Any] or None: The enhanced accessibility tree as a dictionary, or None if an error occurred.
    """
    # Tag every DOM element with an 'mmid' (mirrored into aria-keyshortcuts) so the
    # accessibility snapshot below can be reconciled back to concrete DOM nodes.
    await __inject_attributes(page)
    accessibility_tree: dict[str, Any] = await page.accessibility.snapshot(interesting_only=True)  # type: ignore

    # Persist the raw snapshot for debugging before any enrichment.
    with open(os.path.join(SOURCE_LOG_FOLDER_PATH, 'json_accessibility_dom.json'), 'w', encoding='utf-8') as f:
        f.write(json.dumps(accessibility_tree, indent=2))
    logger.debug("json_accessibility_dom.json saved")

    # Restore original aria-keyshortcuts now; the 'mmid' attribute itself stays in
    # place and is what __fetch_dom_info queries against.
    await __cleanup_dom(page)
    try:
        enhanced_tree = await __fetch_dom_info(page, accessibility_tree, only_input_fields)

        logger.debug("Enhanced Accessibility Tree ready")

        with open(os.path.join(SOURCE_LOG_FOLDER_PATH, 'json_accessibility_dom_enriched.json'), 'w', encoding='utf-8') as f:
            f.write(json.dumps(enhanced_tree, indent=2))
        logger.debug("json_accessibility_dom_enriched.json saved")

        return enhanced_tree
    except Exception as e:
        # Enrichment is best-effort: log and return None rather than crashing the caller.
        logger.error(f"Error while fetching DOM info: {e}")
        traceback.print_exc()
        return None
|
43
Agent_E/ae/utils/http_helper.py
Normal file
43
Agent_E/ae/utils/http_helper.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def make_post_request(url: str, data: dict[str, Any], api_key: str, api_key_header_name: str = "apikey") -> dict[str, Any]|None:
    """
    Makes a POST request to the specified URL with a JSON body and an API key header.

    Args:
        url (str): The URL to send the POST request to.
        data (Dict[str, Any]): The JSON data to include in the POST request body.
        api_key (str): The API key to include in the request headers.
        api_key_header_name (str): The name of the header to include the API key in. Defaults to "apikey".

    Returns:
        Optional[Dict[str, Any]]: The JSON response from the server if the request was successful and the
        response is in JSON format.
        None: If the request failed (any requests.exceptions.RequestException, including non-2xx statuses
        via raise_for_status) or the response is not in JSON format. Errors are printed, not raised.
    """
    # Define the headers for the request
    headers = {
        'Content-Type': 'application/json',
        api_key_header_name: api_key
    }

    try:
        # Make the POST request with the given URL, data, and headers
        response = requests.post(url, json=data, headers=headers)

        # Check if the request was successful (raises HTTPError for 4xx/5xx)
        response.raise_for_status()

        # Attempt to return the JSON response.
        # NOTE(review): on newer requests versions, .json() failures raise
        # requests.exceptions.JSONDecodeError, which subclasses RequestException as well
        # as ValueError — so they may be caught by the first handler, not the second.
        return response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None
    except ValueError:
        print("Error: Response is not in JSON format")
        return None
|
34
Agent_E/ae/utils/js_helper.py
Normal file
34
Agent_E/ae/utils/js_helper.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
import json
|
||||
import re
|
||||
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
def escape_js_message(message: str) -> str:
    """
    Escape a message for safe embedding in JavaScript code.

    Args:
        message (str): The message to escape.

    Returns:
        str: The message as a double-quoted JS string literal with special
        characters escaped.
    """
    # JSON string syntax is a subset of JavaScript string syntax, so json.dumps
    # produces a valid, fully escaped JS string literal (quotes included).
    escaped: str = json.dumps(message)
    return escaped
|
||||
|
||||
|
||||
def beautify_plan_message(message:str) -> str:
    """
    Add a newline between each numbered step in the plan message if it does not already exist.

    Args:
        message (str): The plan message.

    Returns:
        str: The plan message with newlines added between each numbered step.
    """
    logger.debug(f"beautify_plan_message original:\n{message}")
    # Negative lookbehind: insert a break before " N." only when it is not
    # already preceded by a newline.
    step_pattern = re.compile(r'(?<!\n)( \d+\.)')
    formatted = step_pattern.sub(r'\n\1', message)
    logger.debug(f"beautify_plan_message modified:\n{formatted}")
    return formatted
|
71
Agent_E/ae/utils/logger.py
Normal file
71
Agent_E/ae/utils/logger.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pythonjsonlogger import jsonlogger
|
||||
|
||||
# Load environment variables from a .env file
load_dotenv()

# Module-level logger shared across the project; configure_logger() below
# attaches its handler and level.
logger = logging.getLogger(__name__)
|
||||
|
||||
# Custom function to configure the logger
def configure_logger(level: str = "INFO") -> None:
    """
    (Re)configure the module logger: set its level and install a single stream
    handler whose format is JSON or plain text depending on the
    LOG_MESSAGES_FORMAT environment variable ("json" or anything else for text).

    Args:
        level (str): Logging level name (case-insensitive), e.g. "INFO", "DEBUG".
    """
    log_format = os.getenv("LOG_MESSAGES_FORMAT", "text").lower()

    # Set log level for the main logger
    logger.setLevel(level.upper())

    # Create a handler for logging
    handler = logging.StreamHandler()

    if log_format == "json":
        # JSON format
        formatter = jsonlogger.JsonFormatter(
            fmt='%(asctime)s %(name)s %(levelname)s %(message)s %(filename)s %(lineno)d',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    else:
        # Text format
        formatter = logging.Formatter(
            fmt='[%(asctime)s] %(levelname)s {%(filename)s:%(lineno)d} - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    handler.setFormatter(formatter)
    logger.handlers = []  # Clear existing handlers so repeated calls don't duplicate output
    logger.addHandler(handler)

    # Ensure other loggers have the same handler.
    # NOTE: these library loggers are forced to DEBUG regardless of the `level`
    # argument passed to this function.
    http_loggers = ["openai", "autogen"]
    for http_logger in http_loggers:
        lib_logger = logging.getLogger(http_logger)
        lib_logger.setLevel(logging.DEBUG)
        lib_logger.handlers = []  # Clear any existing handlers
        lib_logger.addHandler(handler)  # Add the same handler
|
||||
|
||||
|
||||
# Call the configure logger function to set up the logger initially.
# NOTE: this runs at import time, so merely importing this module installs
# handlers and levels globally.
configure_logger(level="INFO")
|
||||
|
||||
# Function to set log level
def set_log_level(level: str) -> None:
    """
    Set the log level for the logger.

    Parameters:
    - level (str): A logging level such as 'debug', 'info', 'warning', 'error', or 'critical'.
    """
    # Delegates to configure_logger, which also rebuilds handlers/formatters,
    # not just the level.
    configure_logger(level)
|
||||
|
||||
# Set default log levels for other libraries
# logging.getLogger("httpcore").setLevel(logging.DEBUG)
# logging.getLogger("httpx").setLevel(logging.DEBUG)
# logging.getLogger("openai").setLevel(logging.DEBUG)
# logging.getLogger("autogen").setLevel(logging.DEBUG)
# Quiet noisy imaging/plotting library loggers down to WARNING.
logging.getLogger("matplotlib.pyplot").setLevel(logging.WARNING)
logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
logging.getLogger("PIL.Image").setLevel(logging.WARNING)

# Re-export the logger for ease of use
__all__ = ["logger", "set_log_level"]
|
51
Agent_E/ae/utils/openai_llm_helper.py
Normal file
51
Agent_E/ae/utils/openai_llm_helper.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import os
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
from dotenv import load_dotenv
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class OpenAILLMHelper:
    """Thin async wrapper around the OpenAI chat-completions API."""

    def __init__(self):
        # Load OPENAI_API_KEY (and friends) from a .env file if present.
        load_dotenv()
        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    async def get_chat_completion_response_async(self, system_msg:str, user_msgs:list[str], model_name:str="gpt-4-turbo-preview", temperature:float=0.1, max_tokens:int=256, frequency_penalty:float=0.0, top_p: float=1.0, top_k: int=1, presence_penalty: float=0.0):
        """
        Send one system message plus a sequence of user messages and return the
        first choice's text content.

        Returns:
            str | None: The assistant's reply text, or None if the response had
            no non-empty content.

        Raises:
            Exception: wrapping OpenAI connection, rate-limit, or API-status errors.
        """
        # NOTE(review): `top_k` is accepted but never forwarded to the API call
        # below — presumably a leftover from another provider's API; confirm.
        formatted_msgs: list[dict[str, Any]] = [{"role": "system", "content": system_msg}]

        for user_msg in user_msgs:
            formatted_msgs.append({"role": "user", "content": user_msg})

        try:
            response = await self.client.chat.completions.create(
                model=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                frequency_penalty=frequency_penalty,
                top_p=top_p,
                presence_penalty=presence_penalty,
                messages=formatted_msgs  # type: ignore
            )
            print(">>> openai response:", response)
            # Return the first choice's content only if it is present and non-empty.
            if response.choices and len(response.choices) > 0 and response.choices[0].message and response.choices[0].message.content:
                return response.choices[0].message.content
            return None
        except openai.APIConnectionError as e:
            print("The server could not be reached")
            print(e.__cause__)  # an underlying Exception, likely raised within httpx.
            raise Exception(f"Calling {model_name} LLM failed. The server could not be reached.") from e
        except openai.RateLimitError as e:
            print("A 429 status code was received; we should back off a bit.")
            raise Exception(f"Calling {model_name} LLM failed. Rate limit error.") from e
        except openai.APIStatusError as e:
            print(e.status_code)
            print(e.response)
            raise Exception(f"Calling {model_name} LLM failed. Error: {e}") from e
|
||||
|
||||
# async def main():
|
||||
# helper = OpenAILLMHelper()
|
||||
# response = await helper.get_chat_completion_response_async(LLM_PROMPTS["SKILLS_HARVESTING_PROMPT"], ["What is the weather like today?"], temperature=0, max_tokens=4000)
|
||||
# print("*******\nResponse: ", response, "\n*******\n")
|
||||
|
||||
# asyncio.run(main())
|
60
Agent_E/ae/utils/response_parser.py
Normal file
60
Agent_E/ae/utils/response_parser.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
|
||||
|
||||
def parse_response(message: str) -> dict[str, Any]:
    """
    Parse the response from the browser agent and return it as a dictionary.

    The message may be wrapped in a Markdown code fence (``` or ```json); the
    fence is stripped before parsing. If the payload is valid JSON it is parsed
    directly. Otherwise a best-effort string-matching fallback extracts the
    known fields ("plan", "next_step", "terminate", "final_response").

    Args:
        message (str): Raw LLM output, possibly fenced and possibly malformed.

    Returns:
        dict[str, Any]: The parsed fields; empty dict if nothing could be recovered.
    """
    json_response = {}
    # If message starts with ``` and ends with ``` then remove them, along with
    # an optional "json" language tag.
    if message.startswith("```"):
        message = message[3:]
    if message.endswith("```"):
        message = message[:-3]
    if message.startswith("json"):
        message = message[4:]

    message = message.strip()
    try:
        json_response: dict[str, Any] = json.loads(message)
    except Exception as e:
        # If the response is not valid JSON, try to parse it using string matching.
        # This should seldom be triggered.
        # (Fixed: logger.warn is the deprecated alias of logger.warning.)
        logger.warning(f"LLM response was not properly formed JSON. Will try to use it as is. LLM response: \"{message}\". Error: {e}")
        message = message.replace("\\n", "\n")
        message = message.replace("\n", " ")  # type: ignore
        # Each field's value is taken as the text between its key and the next
        # expected key, with quotes stripped.
        if ("plan" in message and "next_step" in message):
            start = message.index("plan") + len("plan")
            end = message.index("next_step")
            json_response["plan"] = message[start:end].replace('"', '').strip()
        if ("next_step" in message and "terminate" in message):
            start = message.index("next_step") + len("next_step")
            end = message.index("terminate")
            json_response["next_step"] = message[start:end].replace('"', '').strip()
        if ("terminate" in message and "final_response" in message):
            start = message.index("terminate") + len("terminate")
            end = message.index("final_response")
            matched_string = message[start:end].replace('"', '').strip()
            if ("yes" in matched_string):
                json_response["terminate"] = "yes"
            else:
                json_response["terminate"] = "no"

            start = message.index("final_response") + len("final_response")
            # Drop the trailing character (heuristically a closing brace/quote).
            end = len(message) - 1
            json_response["final_response"] = message[start:end].replace('"', '').strip()

        elif ("terminate" in message):
            start = message.index("terminate") + len("terminate")
            end = len(message) - 1
            matched_string = message[start:end].replace('"', '').strip()
            if ("yes" in matched_string):
                json_response["terminate"] = "yes"
            else:
                json_response["terminate"] = "no"

    return json_response
|
14
Agent_E/ae/utils/ui_messagetype.py
Normal file
14
Agent_E/ae/utils/ui_messagetype.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class MessageType(Enum):
    """String-valued categories used to label messages (values are the wire labels)."""
    PLAN = "plan"
    STEP = "step"
    ACTION ="action"
    ANSWER = "answer"
    QUESTION = "question"
    INFO = "info"
    FINAL = "final"
    # NOTE(review): DONE's value is "transaction_done", not "done" — callers matching
    # on raw strings must use the value, not the member name.
    DONE = "transaction_done"
    ERROR = "error"
    MAX_TURNS_REACHED = "max_turns_reached"
|
437
Agent_E/test/evaluators.py
Normal file
437
Agent_E/test/evaluators.py
Normal file
|
@ -0,0 +1,437 @@
|
|||
"""base class for evaluation"""
|
||||
import collections
|
||||
import html
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from Agent_E.test.test_utils import clean_answer
|
||||
from Agent_E.test.test_utils import evaluate_exact_match
|
||||
from Agent_E.test.test_utils import evaluate_fuzzy_match
|
||||
from Agent_E.test.test_utils import evaluate_must_include
|
||||
from Agent_E.test.test_utils import evaluate_ua_match
|
||||
from typing import Any
|
||||
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from playwright.sync_api import CDPSession
|
||||
from playwright.sync_api import Page
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
class Evaluator:
    """Abstract base for all evaluation strategies.

    Attributes:
        eval_tag (str): optional tag used to identify or categorize the evaluator.
    """

    def __init__(self, eval_tag: str = "") -> None:
        """Store the optional evaluation tag."""
        self.eval_tag = eval_tag

    async def __call__(self, task_config: dict[str, Any], page: Page, client: CDPSession, answer: str) -> dict[str, float|str]:
        """Score a finished task; concrete subclasses must override this.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("This method should be overridden by subclasses.")
|
||||
|
||||
|
||||
class StringEvaluator(Evaluator):
    """Evaluates string-based answers using various matching criteria.

    Supports exact matches, some matches, fuzzy matching using LLM, and unachievable task matching.
    """

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page | None = None,
        client: CDPSession | None = None,
        answer: str | None = None,

    ) -> dict[str, float|str]:
        last_action = answer or ""
        # Normalized prediction: stripped, surrounding quotes removed, lowercased.
        pred = clean_answer(last_action)

        # The score is multiplicative across every configured approach: any
        # failing approach drives the final score to 0.0.
        score = 1.0
        for approach, value in task_config["eval"]["reference_answers"].items():

            match approach:
                case "exact_match":
                    logger.info(f"Evaluating exact_match for answer: Predicted: {pred} , Reference: {value}")
                    score *= evaluate_exact_match(ref=value, pred=pred)

                case "must_include":
                    # (typo "expeced" lives in the runtime log string; left untouched here)
                    logger.info(f"Evaluating must_include for answer: \"{answer}\" to see if it includes the expeced values: \"{value}\"\n")
                    assert isinstance(value, list)
                    for must_value in value: # type: ignore
                        score *= evaluate_must_include(
                            ref=must_value, # type: ignore
                            pred=pred,
                            # Tokenized matching only when exactly one phrase is required.
                            tokenize=(len(value) == 1), # type: ignore
                        )
                case "some_matches":
                    # At least `min_required` of the listed phrases must appear in the prediction.
                    min_required_matches = value.get("min_required", 1)
                    matches = sum(evaluate_must_include(ref=phrase, pred=pred, tokenize=False) for phrase in value["phrases"])
                    score *= float(matches >= min_required_matches)
                case "fuzzy_match":
                    logger.info(f"Evaluating fuzzy_match for answer: {answer}")
                    intent = task_config["intent"]
                    if value == "N/A":
                        # if the instruction only asks the model to generate N/A when encountering an unachievable task
                        # without more concrete reasons
                        score *= evaluate_exact_match(ref=value, pred=pred)
                        # if the instruction also asks the model to generate the reason why the task is unachievable
                        # this should be the default as it will prevent false positive N/A`
                        # NOTE(review): this branch *reassigns* score from the ua-match result,
                        # discarding multipliers accumulated by earlier approaches — confirm intended.
                        if score != 1:
                            score = 1.0 * evaluate_ua_match(
                                intent=task_config["intent"],
                                ref=task_config["eval"]["string_note"],
                                pred=pred,
                            )
                    else:
                        logger.info(f"Evaluating generic for answer: {answer}")
                        assert isinstance(value, list)
                        for reference in value: # type: ignore
                            score *= evaluate_fuzzy_match(
                                ref=reference, pred=pred, intent=intent # type: ignore
                            )
                case _:
                    # Unknown approaches are logged but do not affect the score.
                    logger.info(f"Unknown approach value received: {approach}")
        return {"score": score}
|
||||
|
||||
|
||||
class URLEvaluator(Evaluator):
    """Evaluates if the given URL matches the expected URL criteria defined in the configuration.

    This includes checking if the base path of the URL and its query parameters match those specified in the reference URLs.
    """

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession | None = None,
        answer: str | None = None
    ) -> dict[str, float|str]:
        """Evaluates the current page URL against reference URLs specified in the config file.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession | None, optional): The Chrome DevTools Protocol session object. Not used in this evaluator.
            answer (str | None, optional): Not used in this evaluator.

        Returns:
            dict[str, float|str]: "score" 1.0 if the page URL matches any of the reference URLs, considering the matching rule; otherwise 0.0.

        Raises:
            ValueError: If an unknown matching rule is specified in the config file.
        """

        def clean_url(url: str) -> str:
            """Normalize a URL: lowercase and drop trailing slashes."""
            url = str(url)
            url = url.rstrip("/")
            url = url.lower()
            return url

        def parse_url(url: str) -> tuple[str, dict[str, list[str]]]:
            """Parse a URL into its base (netloc + path) and query components."""
            parsed_url = urllib.parse.urlparse(url)
            base_path = parsed_url.netloc + parsed_url.path
            query = urllib.parse.parse_qs(parsed_url.query)
            return base_path, query

        pred = clean_url(page.url)
        # The predicted URL is loop-invariant: parse it once instead of per reference.
        pred_base_path, pred_query = parse_url(pred)
        ref_urls = task_config["eval"]["reference_url"].split(" |OR| ")
        ref_urls = [clean_url(url) for url in ref_urls]
        matching_rule = task_config["eval"].get("url_note", "GOLD in PRED")
        if matching_rule == "GOLD in PRED":
            for ref_url in ref_urls:
                ref_base_path, ref_query = parse_url(ref_url)
                # Base match is substring containment of the reference base inside the predicted base.
                base_score = float(ref_base_path in pred_base_path)
                # Every reference query parameter must be matched by one of its allowed values.
                query_score = 1.0
                for k, possible_values in ref_query.items():  # type: ignore
                    if k in pred_query:
                        query_score *= float(
                            any(
                                possible_ref_value in pred_query.get(k, [])  # type: ignore
                                for possible_ref_value in possible_values  # type: ignore
                            )
                        )
                    else:
                        # If the key is not in pred_query, check if the reference URL has no query parameters
                        if not possible_values:
                            query_score *= 1.0  # No query parameters to match, so consider it a match
                        else:
                            query_score *= 0.0  # Reference URL has query parameters but predicted URL does not
                # Calculate final score for this ref_url
                score = base_score * query_score
                # Any fully matching reference URL is sufficient; return immediately.
                if score == 1.0:
                    return {"score": score}

        else:
            raise ValueError(f"Unknown matching rule: {matching_rule}")

        return {"score": 0.0}
|
||||
|
||||
|
||||
class HTMLContentEvaluator(Evaluator):
    """Evaluates if specified HTML content or elements appear on the webpage.

    This involves navigating to URLs specified in the configuration and checking for the presence of HTML elements or content using various strategies.
    """

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession | None = None,
        answer: str | None = None
    ) -> dict[str, float|str]:
        """Evaluates the presence of specified HTML content on the webpage.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession | None, optional): The Chrome DevTools Protocol session object. Not used in this evaluator.
            answer (str | None, optional): Not used in this evaluator.

        Returns:
            dict[str, float|str]: "score" A score between 0.0 and 1.0 representing the presence of required HTML content on the webpage.

        Raises:
            ValueError: If an unknown locator strategy is specified in the config file.
        """
        # NOTE(review): page.goto/page.content/page.evaluate are called without
        # await below, while one page.evaluate IS awaited — confirm whether `page`
        # is the sync or async Playwright Page here.
        targets = task_config["eval"]["program_html"]

        # Multiplicative score: every target must be satisfied for a final 1.0.
        score = 1.0
        for target in targets:
            target_url: str = target["url"] # which url to check
            if target_url.startswith("func"):
                func = target_url.split("func:")[1]
                func = func.replace("__last_url__", page.url)
                # SECURITY: eval() runs config-supplied code — acceptable only for trusted task configs.
                target_url = eval(func)

            locator: str = target["locator"] # js element locator

            # navigate to that url
            if target_url != "last":
                page.goto(target_url)
                # fixed settle delay before inspecting the page
                time.sleep(3)

            # empty, use the full page
            if not locator.strip():
                selected_element = page.content()
            # use JS to select the element
            elif locator.startswith("document.") or locator.startswith("[...document.") or locator.startswith("jsblock:"):
                if "prep_actions" in target:
                    # prep actions are best-effort; failures are deliberately ignored
                    try:
                        for prep_action in target["prep_actions"]:
                            page.evaluate(f"() => {prep_action}")
                    except Exception:
                        pass
                try:
                    if locator.startswith("jsblock:"):
                        locator = locator.split("jsblock:")[1]

                    selected_element = str(await page.evaluate(f"() => {locator}"))
                    if not selected_element:
                        selected_element = ""
                except Exception:
                    # the page is wrong, return empty
                    selected_element = ""
            # run program to call API
            elif locator.startswith("func:"): # a helper function
                func = locator.split("func:")[1]
                func = func.replace("__page__", "page")
                # SECURITY: eval() on config-supplied helper code — trusted configs only.
                selected_element = eval(func)
            else:
                raise ValueError(f"Unknown locator: {locator}")

            # Undo HTML entity escaping before string comparison.
            selected_element = html.unescape(selected_element)

            if "exact_match" in target["required_contents"]:
                required_contents = target["required_contents"]["exact_match"]
                cur_score = evaluate_exact_match(
                    ref=required_contents, pred=selected_element
                )
                score *= float(cur_score)
                # logger.info(f"[exact match] {cur_score}, selected element: {selected_element}, required contents: {required_contents}")
            elif "must_include" in target["required_contents"]:
                required_contents = target["required_contents"]["must_include"]
                assert isinstance(required_contents, list)
                for content in required_contents: # type: ignore
                    # each entry may list alternatives separated by " |OR| "; any one suffices
                    content_or = content.split(" |OR| ") # type: ignore
                    cur_score = any(
                        [
                            evaluate_must_include(
                                ref=content, # type: ignore
                                pred=selected_element,
                                tokenize=False,
                            )
                            for content in content_or # type: ignore
                        ]
                    )
                    score *= float(cur_score)
                    # logger.info(f"[must include] {cur_score}, selected element: {selected_element}, required contents: {content_or}")
            else:
                raise ValueError(
                    f"Unknown required_contents: {target['required_contents'].keys()}"
                )
        return {"score": score}
|
||||
|
||||
class ManualContentEvaluator(Evaluator):
    """Interactive evaluation route: a human annotator scores the task from the console."""

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession | None = None,
        answer: str | None = None
    ) -> dict[str, float|str]:
        """Prompt the user to rate the agent's answer as Pass, Fail or Skip.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession | None, optional): The Chrome DevTools Protocol session object. Not used in this evaluator.
            answer (str | None, optional): The agent's final answer shown to the annotator.

        Returns:
            dict[str, float|str]: "score" of 1.0 (pass), 0.0 (fail) or -0.1 (skip),
            plus an optional "reason" collected for fail/skip ratings.

        Raises:
            ValueError: when the annotator types anything other than Pass, Fail or Skip.
        """
        task = task_config["intent"]
        reference_answer = task_config["eval"]["reference_answers"]["manual_check"]["answer"]
        answer_type = task_config["eval"]["reference_answers"]["manual_check"]["type"]
        task_id = str(task_config["task_id"])
        task_index = str(task_config["task_index"])

        print(colored("\n\n***************************\n", "green", attrs=["bold"]))
        for label, value in (
            ("Task ID: ", task_id),
            ("Task Index: ", task_index),
            ("Task: ", task),
            ("Agent answer: ", str(answer or "")),
        ):
            print(colored(label, "blue", attrs=["bold"]) + value + "\n")

        kind = answer_type.strip().lower()
        if kind == "possible":
            print(colored("Possible answer (reference): ", "yellow") + f"~~~{reference_answer}~~~")
        elif kind == "golden":
            print(colored("Golden answer (reference): ", "yellow") + reference_answer)

        user_response = input(colored("Annotate the task as Pass, Fail or Skip (please use Skip sparingly)? ", "magenta", attrs=["bold"]))
        verdict_scores = {"pass": 1.0, "fail": 0.0, "skip": -0.1}
        verdict = user_response.lower()
        if verdict not in verdict_scores:
            print(colored(f"Received response: {user_response}", "red"))
            raise ValueError("Invalid user response. Please enter 'Pass', 'Fail' or 'Skip'.")

        eval_response: dict[str, float|str] = {"score": verdict_scores[verdict]}
        # Fail (0.0) and skip (-0.1) ratings both warrant an explanation.
        if eval_response["score"] <= 0:
            eval_response["reason"] = input("Reason for rating: ")

        return eval_response
|
||||
|
||||
class EvaluatorComb(Evaluator):
    """Runs several evaluators in sequence and multiplies their scores together.

    Attributes:
        evaluators (list[Evaluator]): the individual evaluators applied in order.
    """

    def __init__(self, evaluators: list[Evaluator]) -> None:
        """Store the evaluators that make up this composite.

        Parameters:
            evaluators (list[Evaluator]): The list of evaluators to include in the composite evaluation.
        """
        self.evaluators = evaluators


    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession,
        answer: str,
    ) -> dict[str, float|str]:
        """Apply every evaluator and aggregate their results.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession): The Chrome DevTools Protocol session object.
            answer (str): The answer or content to be evaluated.

        Returns:
            dict[str, float|str]: "score" - the product of all evaluators' scores;
            "reason" - newline-joined reasons from evaluators that supplied one, else None.
        """
        total: float = 1.0
        reasons: list[str] = []
        for evaluator in self.evaluators:
            eval_result = await evaluator(task_config, page, client, answer)
            total = total * eval_result["score"]  # type: ignore
            if "reason" in eval_result:
                reasons.append(eval_result["reason"])  # type: ignore
        combined_reason = "\n".join(reasons) if reasons else None
        return {"score": total, "reason": combined_reason}  # type: ignore
|
||||
|
||||
|
||||
def evaluator_router(task_config: dict[str, Any]) -> EvaluatorComb:
    """Build a composite evaluator from the eval types listed in the task config.

    Parameters:
        task_config dict[str, Any]: configuration specifying the evaluation types to use.

    Returns:
        EvaluatorComb: A composite evaluator configured with the specified types of individual evaluators.

    Raises:
        ValueError: If an unsupported evaluation type is specified in the configuration file.
    """
    # Dispatch table: eval type -> (log message, evaluator class).
    registry: dict[str, tuple[str, type]] = {
        "string_match": ("Adding string evaluator", StringEvaluator),
        "url_match": ("Adding URL evaluator", URLEvaluator),
        "program_html": ("Adding HTML evaluator", HTMLContentEvaluator),
        "manual": ("Adding manual evaluator", ManualContentEvaluator),
    }

    evaluators: list[Evaluator] = []
    for eval_type in task_config["eval"]["eval_types"]:
        if eval_type not in registry:
            raise ValueError(f"eval_type {eval_type} is not supported")
        log_message, evaluator_cls = registry[eval_type]
        logger.info(log_message)
        evaluators.append(evaluator_cls())

    return EvaluatorComb(evaluators)
|
263
Agent_E/test/test_utils.py
Normal file
263
Agent_E/test/test_utils.py
Normal file
|
@ -0,0 +1,263 @@
|
|||
"""Implements helper functions to assist evaluation cases where other evaluators are not suitable."""
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from nltk.tokenize import word_tokenize # type: ignore
|
||||
from openai import OpenAI
|
||||
|
||||
load_dotenv()
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
def llm_fuzzy_match(pred: str, reference: str, question: str) -> float:
    """
    LLM-as-judge semantic comparison of a predicted answer against a reference.

    Frames the comparison as a teacher grading a student's answer, so different
    phrasing of an equivalent answer still counts. Uses the GPT-4-turbo model.

    Parameters:
        pred (str): The student's predicted answer.
        reference (str): The reference answer to compare against.
        question (str): The question related to the answers.

    Returns:
        float: 1.0 when the judgement is "correct"; 0.0 for "incorrect" or
        "partially correct".
    """
    # Assemble the grading prompt.
    prompt = "".join([
        "Help a teacher to grade the answer of a student given a question. Keep in mind that the student may use different phrasing or wording to answer the question. The goal is to evaluate whether the answer is semantically equivalent to the reference answer.\n",
        f"question: {question}\n",
        f"reference answer: {reference}\n",
        "all the string 'N/A' that you see is a special sequence that means 'not achievable'\n",
        f"student answer: {pred}\n",
        "Conclude the judgement by correct/incorrect/partially correct.",
    ])
    chat = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": prompt},
    ]

    verdict = generate_from_openai_chat_completion(
        model="gpt-4-turbo-preview",
        messages=chat,
        temperature=0,
        max_tokens=768,
        top_p=1.0,
        context_length=0,
    ).lower()
    # "incorrect"/"partially correct" lose; anything else must contain "correct".
    if "partially correct" in verdict or "incorrect" in verdict:
        return 0.0
    assert "correct" in verdict
    return 1.0
|
||||
|
||||
|
||||
def llm_ua_match(pred: str, reference: str, question: str) -> float:
    """
    LLM-as-judge comparison of a reported unachievable-task reason against the actual one.

    Asks the model whether the reported reason aligns (even implicitly) with the
    actual reason the task could not be completed. Uses the GPT-4-turbo model.

    Parameters:
        pred (str): The reported unachievable reason by an individual.
        reference (str): The actual reason why the task is unachievable.
        question (str): The task in question.

    Returns:
        float: 1.0 when the reasons align ("same"), otherwise 0.0 ("different").
    """
    # Assemble the comparison prompt.
    prompt = "".join([
        f"task: {question}\n",
        f"actual unachievable reason: {reference}\n",
        f"reported unachievable reason: {pred}\n",
        "The task described above is inherently unachievable due to the reason specified under 'actual unachievable reason'. "
        "An individual previously attempted this task and was unable to complete it. They provided a reason for their failure, "
        "which is listed under 'reported unachievable reason'. Your role is to review both the actual and reported reasons. "
        "Determine if the reported reason aligns with the actual reason, even if implicitly. "
        "If the stated reason is in line with the actual reason, respond with 'same'. Otherwise, respond with 'different'.",
    ])
    chat = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": prompt},
    ]

    verdict = generate_from_openai_chat_completion(
        model="gpt-4-turbo-preview",
        messages=chat,
        temperature=0,
        max_tokens=768,
        top_p=1.0,
        context_length=0,
    ).lower()
    # "different" loses; anything else must contain "same".
    if "different" in verdict:
        return 0.0
    assert "same" in verdict
    return 1.0
|
||||
|
||||
|
||||
|
||||
def generate_from_openai_chat_completion(
    messages: list[dict[str, str]],
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
) -> str:
    """
    Call OpenAI chat completions for a conversation and return the generated text.

    Requires the 'OPENAI_API_KEY' environment variable; optionally reads
    'OPENAI_ORGANIZATION'.

    Parameters:
        messages (list[dict[str, str]]): Conversation messages for the request.
        model (str): Model name to use for the completion.
        temperature (float): Sampling temperature.
        max_tokens (int): Maximum number of tokens to generate.
        top_p (float): Nucleus sampling parameter.
        context_length (int): Maximum number of tokens from `messages` to use for context.
        stop_token (str, optional): Token at which to stop generating.

    Returns:
        str: The generated response text.

    Raises:
        ValueError: If the 'OPENAI_API_KEY' environment variable is not set.
    """
    # Guard clause: fail fast without credentials.
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )
    client.api_key = os.environ["OPENAI_API_KEY"]
    client.organization = os.environ.get("OPENAI_ORGANIZATION", "")

    completion = client.chat.completions.create(
        model=model,
        messages=messages,  # type: ignore
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        n=1,
        stop=[stop_token] if stop_token else None,
    )
    # First (and only, n=1) choice carries the generated text.
    return completion.choices[0].message.content  # type: ignore
|
||||
|
||||
def clean_answer(answer: str) -> str:
    """Normalize an answer for comparison.

    Trims surrounding whitespace, strips surrounding double then single quotes,
    and lowercases the result.

    Parameters:
        answer (str): The answer string to clean.

    Returns:
        str: The cleaned and lowercased answer string.
    """
    return answer.strip().strip('"').strip("'").lower()
|
||||
|
||||
def evaluate_exact_match(ref: str, pred: str) -> float:
    """Score an exact match between the normalized prediction and reference.

    Parameters:
        ref (str): The reference answer.
        pred (str): The predicted answer.

    Returns:
        float: 1.0 if the normalized answers are identical, otherwise 0.0.
    """
    return 1.0 if clean_answer(pred) == clean_answer(ref) else 0.0
|
||||
|
||||
def evaluate_must_include(ref: str, pred: str, tokenize: bool = False) -> float:
    """Checks if the predicted answer includes all phrases from the reference answer.

    Parameters:
        ref (str): The reference answer containing phrases that must be included.
        pred (str): The predicted answer to be evaluated.
        tokenize (bool, optional): Tokenizes the answers before evaluation if True. Default is False.

    Returns:
        float: 1.0 if all phrases are included, otherwise 0.0.
    """
    clean_ref = clean_answer(ref)
    clean_pred = clean_answer(pred)
    # NOTE(review): `len(clean_ref) == 1` counts *characters*, so the tokenized
    # membership test only ever fires for single-character references — confirm
    # this is intended rather than a single-token check.
    if tokenize and len(clean_ref) == 1:
        # Whole-token membership avoids matching a lone character inside a longer word.
        return float(clean_ref in word_tokenize(clean_pred))
    else:
        # Plain substring containment on the normalized strings.
        return float(clean_ref in clean_pred)
|
||||
|
||||
def evaluate_fuzzy_match(ref: str, pred: str, intent: str) -> float:
    """Delegate semantic-similarity grading of the prediction to the LLM judge.

    Parameters:
        ref (str): The reference answer.
        pred (str): The predicted answer.
        intent (str): The intent or context of the question.

    Returns:
        float: 1.0 when the answers are judged semantically equivalent, else 0.0.
    """
    return llm_fuzzy_match(pred, ref, intent)
|
||||
|
||||
def evaluate_ua_match(ref: str, pred: str, intent: str) -> float:
    """Delegate unachievable-reason alignment checking to the LLM judge.

    Parameters:
        ref (str): The reference reason why the task is unachievable.
        pred (str): The predicted reason reported by the model.
        intent (str): The intent or context of the task.

    Returns:
        float: 1.0 when the reported reason aligns with the actual one, else 0.0.
    """
    return llm_ua_match(pred, ref, intent)
|
||||
|
||||
|
||||
def load_config(config_file: Path | str) -> list[dict[str, Any]]:
|
||||
"""Load the confiufiguration for the test cases
|
||||
|
||||
Args:
|
||||
config_file (Path | str): Path to the config file
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: All the test cases in the config file
|
||||
"""
|
||||
with open(config_file, "r") as f: # noqa: UP015
|
||||
configs = json.load(f)
|
||||
return configs
|
||||
|
||||
def task_config_validator(task_config: dict[str, Any]) -> bool:
    """Validate that a task config carries the mandatory 'intent' field.

    Parameters:
        task_config (dict[str, Any]): The task configuration to validate.

    Returns:
        bool: True when the config contains a non-empty 'intent'.

    Raises:
        ValueError: when 'intent' is missing or empty.
    """
    # Guard clause: an absent or empty intent makes the task unrunnable.
    if not task_config.get('intent'):
        raise ValueError("Intent is missing in the task config file. Without it the task cannot be run.")

    return True
|
||||
|
||||
def get_formatted_current_timestamp(format: str = "%Y-%m-%d %H:%M:%S") -> str:
    """Return the current local time rendered with the given strftime format.

    Args:
        format (str, optional): The format of the timestamp. Defaults to "%Y-%m-%d %H:%M:%S".

    Returns:
        str: The current timestamp in the specified format.
    """
    return datetime.now().strftime(format)
|
409
Agent_E/test/tests_processor.py
Normal file
409
Agent_E/test/tests_processor.py
Normal file
|
@ -0,0 +1,409 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from Agent_E.ae.core.agents_llm_config import AgentsLLMConfig
|
||||
from Agent_E.test.test_utils import get_formatted_current_timestamp
|
||||
from Agent_E.test.test_utils import load_config
|
||||
from Agent_E.test.test_utils import task_config_validator
|
||||
from typing import Any
|
||||
|
||||
import Agent_E.ae.core.playwright_manager as browserManager
|
||||
import nltk # type: ignore
|
||||
from Agent_E.ae.config import PROJECT_TEST_ROOT
|
||||
from Agent_E.ae.core.autogen_wrapper import AutogenWrapper
|
||||
from Agent_E.ae.core.playwright_manager import PlaywrightManager
|
||||
from Agent_E.ae.utils.logger import logger
|
||||
from Agent_E.ae.utils.response_parser import parse_response
|
||||
from autogen.agentchat.chat import ChatResult # type: ignore
|
||||
from playwright.async_api import Page
|
||||
from tabulate import tabulate
|
||||
from termcolor import colored
|
||||
|
||||
from evaluation_harness.evaluators import evaluator_router
|
||||
|
||||
nltk.download('punkt') # type: ignore
|
||||
|
||||
# Module-level default for the agent's final answer; note that
# execute_single_task shadows this with a local of the same name.
last_agent_response = ""
|
||||
|
||||
def check_top_level_test_folders(test_log_dir, test_result_dir):
    """Ensure the top-level log and result folders exist, creating them if needed.

    Args:
        test_log_dir: Directory that will hold per-task execution logs.
        test_result_dir: Directory that will hold per-task score files.
    """
    # exist_ok=True closes the TOCTOU race between the existence check and
    # the creation call (e.g. when several runners start concurrently); the
    # check is kept only to preserve the "Created ..." log semantics.
    if not os.path.exists(test_log_dir):
        os.makedirs(test_log_dir, exist_ok=True)
        logger.info(f"Created log folder at: {test_log_dir}")

    if not os.path.exists(test_result_dir):
        os.makedirs(test_result_dir, exist_ok=True)
        logger.info(f"Created scores folder at: {test_result_dir}")
|
||||
|
||||
def create_task_log_folders(test_log_dir, task_id):
    """Create (if missing) the log and screenshot folders for one task.

    Args:
        test_log_dir: Root directory for all task logs.
        task_id: Identifier of the task; used as the sub-folder name.

    Returns:
        dict: {"task_log_folder": ..., "task_screenshots_folder": ...}
    """
    task_log_dir = os.path.join(test_log_dir, task_id)
    task_screenshots_dir = os.path.join(task_log_dir, 'snapshots')
    # exist_ok=True guards against the folder appearing between the check
    # and the creation call (race with a concurrent runner).
    if not os.path.exists(task_log_dir):
        os.makedirs(task_log_dir, exist_ok=True)
        logger.info(f"Created log dir for task {task_id} at: {task_log_dir}")
    if not os.path.exists(task_screenshots_dir):
        os.makedirs(task_screenshots_dir, exist_ok=True)
        logger.info(f"Created screenshots dir for task {task_id} at: {task_screenshots_dir}")

    return {"task_log_folder": task_log_dir, "task_screenshots_folder": task_screenshots_dir}
|
||||
|
||||
|
||||
def create_results_dir(test_file: str, test_results_id: str | None) -> str:
    """Resolve (and create if needed) the directory where results are written.

    Args:
        test_file: Path of the test file; its basename names the folder when
            no explicit results id is given.
        test_results_id: Optional unique id for this run.

    Returns:
        str: Path of the (now existing) results directory under TEST_RESULTS.
    """
    # The dead `results_dir = ""` pre-initialization was removed: both
    # branches below assign it unconditionally.
    if test_results_id:
        results_dir = os.path.join(TEST_RESULTS, f"results_for_{test_results_id}")
    else:
        test_file_base = os.path.basename(test_file)
        test_file_name = os.path.splitext(test_file_base)[0]
        results_dir = os.path.join(TEST_RESULTS, f"results_for_test_file_{test_file_name}")

    # exist_ok=True avoids a race when several runs start at the same time.
    if not os.path.exists(results_dir):
        os.makedirs(results_dir, exist_ok=True)
        logger.info(f"Created results directory: {results_dir}")

    return results_dir
|
||||
|
||||
|
||||
def dump_log(task_id: str, messages_str_keys: dict[str, str], logs_dir: str):
    """Write a task's chat messages to 'execution_logs_<task_id>.json' in logs_dir."""
    target = os.path.join(logs_dir, f'execution_logs_{task_id}.json')
    with open(target, 'w', encoding='utf-8') as log_file:
        json.dump(messages_str_keys, log_file, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def save_test_results(test_results: list[dict[str, str | int | float | None]], test_results_id: str):
    """Persist the aggregated run results as 'test_results_<id>.json' under TEST_RESULTS."""
    destination = os.path.join(TEST_RESULTS, f'test_results_{test_results_id}.json')
    with open(destination, 'w', encoding='utf-8') as out:
        json.dump(test_results, out, ensure_ascii=False, indent=4)
    logger.info(f"Test results dumped to: {destination}")
|
||||
|
||||
|
||||
def save_individual_test_result(test_result: dict[str, str | int | float | None], results_dir: str):
    """Write one task's result dict to '<results_dir>/<task_id>.json'."""
    task_id = test_result["task_id"]
    destination = os.path.join(results_dir, f'{task_id}.json')
    with open(destination, 'w', encoding='utf-8') as out:
        json.dump(test_result, out, ensure_ascii=False, indent=4)
    logger.info(f"Test result for task {task_id} dumped to: {destination}")
|
||||
|
||||
|
||||
def extract_last_response(messages: list[dict[str, Any]]) -> str:
    """Extract the last 'final_response' from chat history.

    Scans the messages newest-first and returns the first parsed
    'final_response' found; returns "" when none exists or parsing fails.

    Args:
        messages: Chat history entries, each potentially holding 'content'.

    Returns:
        str: The final response text, or "" when unavailable.
    """
    try:
        # Iterate in reverse order: the final answer is expected at (or
        # near) the end of the conversation.
        for message in reversed(messages):
            if message and 'content' in message:
                content = message.get('content', "")
                content_json = parse_response(content)
                final_answer = content_json.get('final_response', None)
                if final_answer:
                    return final_answer
        return ""
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # no longer swallowed; parsing errors still degrade to "".
        logger.error("Error extracting last response from chat history.")
        return ""
|
||||
|
||||
|
||||
def print_progress_bar(current: int, total: int, bar_length: int = 50) -> None:
    """
    Prints a progress bar to the console.

    Parameters:
    - current (int): The current progress of the task.
    - total (int): The total number of tasks to complete.
    - bar_length (int): The character length of the progress bar (default is 50).

    This function dynamically updates a single line in the console to reflect current progress.
    """
    percent = float(current) * 100 / total
    # At current == 0 the arrow multiplier is -1, yielding just ">" — intended.
    arrow = '-' * int(percent/100 * bar_length - 1) + '>'
    spaces = ' ' * (bar_length - len(arrow))

    # flush=True so the carriage-return rewrite is visible immediately even
    # when stdout is block-buffered (e.g. output piped to a file).
    print(f'\rProgress: [{arrow}{spaces}] {current}/{total} ({percent:.2f}%)', end='', flush=True)
|
||||
|
||||
def determine_status_and_color(score: float) -> tuple[str, str]:
    """Map a task score to a (status, color) pair.

    A score of 1 means 'Pass' (green), a negative score (e.g. -0.1) means
    'Skip' (yellow), and anything else means 'Fail' (red).
    """
    if score == 1:
        return 'Pass', 'green'
    return ('Skip', 'yellow') if score < 0 else ('Fail', 'red')
|
||||
|
||||
|
||||
def print_test_result(task_result: dict[str, str | int | float | None], index: int, total: int) -> None:
    """
    Print one task's evaluation result as a small grid table.

    Parameters:
    - task_result (dict): Evaluation results with task ID, intent, score, and timing.
    - index (int): Position of this test in the overall run.
    - total (int): Total number of tests being run (not shown in the table).
    """
    status, color = determine_status_and_color(task_result['score'])  # type: ignore

    cost = task_result.get("compute_cost", None)
    if cost is None:
        total_cost = None
        total_tokens = None
    else:
        total_cost = round(cost.get("cost", -1), 4)  # type: ignore
        total_tokens = cost.get("total_tokens", -1)  # type: ignore

    header = ['Test Index', 'Task ID', 'Intent', 'Status', 'Time Taken (s)', 'Total Tokens', 'Total Cost ($)']
    row = [index, task_result['task_id'], task_result['intent'], colored(status, color),  # type: ignore
           round(task_result['tct'], 2), total_tokens, total_cost]  # type: ignore
    print('\n' + tabulate([header, row], headers='firstrow', tablefmt='grid'))  # type: ignore
|
||||
|
||||
def get_command_exec_cost(command_exec_result: ChatResult):
    """Extract token/cost accounting from an autogen ChatResult.

    Prefers usage including cached inference, falling back to the
    excluding-cache variant.

    Args:
        command_exec_result: The chat result returned by autogen.

    Returns:
        dict: Keys "cost", "prompt_tokens", "completion_tokens",
        "total_tokens" when available; empty dict otherwise.
    """
    output: dict[str, Any] = {}
    try:
        cost = command_exec_result.cost  # type: ignore
        if "usage_including_cached_inference" in cost:
            usage: dict[str, Any] = cost["usage_including_cached_inference"]
        elif "usage_excluding_cached_inference" in cost:
            usage = cost["usage_excluding_cached_inference"]
        else:
            raise ValueError("Cost not found in the command execution result.")
        # Was a stray debugging `print("Usage: ", usage)`; routed to the
        # logger so normal runs keep stdout clean.
        logger.debug(f"Usage: {usage}")

        # The per-model entry is whichever nested dict carries "prompt_tokens".
        for key in usage.keys():
            if isinstance(usage[key], dict) and "prompt_tokens" in usage[key]:
                output["cost"] = usage[key]["cost"]
                output["prompt_tokens"] = usage[key]["prompt_tokens"]
                output["completion_tokens"] = usage[key]["completion_tokens"]
                output["total_tokens"] = usage[key]["total_tokens"]
    except Exception as e:
        # Best-effort: cost accounting must never fail the task run.
        logger.debug(f"Error getting command execution cost: {e}")
    return output
|
||||
|
||||
|
||||
async def execute_single_task(task_config_file: str, browser_manager: PlaywrightManager, ag: AutogenWrapper, page: Page, logs_dir: str) -> dict[str, Any]:
    """
    Executes a single test task based on a specified task configuration and evaluates its performance.

    Parameters:
    - task_config_file (str): Path to the task configuration JSON file.
    - browser_manager (PlaywrightManager): The manager handling browser interactions, responsible for page navigation and control.
    - ag (AutogenWrapper): The automation generator wrapper that processes commands and interacts with the web page.
    - page (Page): The Playwright page object representing the browser tab where the task is executed.
    - logs_dir (str): Directory where this task's execution logs are written.

    Returns:
    - dict: A dictionary containing the task's evaluation results, including task ID, intent, score, total command time (tct),
      the last statement from the chat agent, and the last URL accessed during the task.
    """
    command = ""
    start_url = None
    task_id = None

    start_ts = get_formatted_current_timestamp()

    # `with` closes the config file promptly; the original
    # `json.load(open(...))` leaked the handle until garbage collection.
    with open(task_config_file, "r") as f:
        task_config = json.load(f)

    # Raises ValueError when 'intent' is missing, aborting the task early.
    task_config_validator(task_config)

    command: str = task_config.get('intent', "")
    task_id = task_config.get('task_id')
    task_index = task_config.get('task_index')
    start_url = task_config.get('start_url')
    logger.info(f"Intent: {command}, Task ID: {task_id}")

    if start_url:
        await page.goto(start_url, wait_until='load', timeout=30000)

    start_time = time.time()
    current_url = await browser_manager.get_current_url()
    command_exec_result = await ag.process_command(command, current_url)
    end_time = time.time()

    evaluator_result: dict[str, float | str] = {}
    last_agent_response: str = ""
    command_cost: dict[str, Any] = {}
    single_task_result: dict[str, Any] = {}
    try:
        single_task_result = {
            "task_id": task_id,
            "task_index": task_index,
            "start_url": start_url,
            "intent": str(command),
            "last_url": page.url,
            "tct": end_time - start_time,
            "start_ts": start_ts,
            "completion_ts": get_formatted_current_timestamp()
        }

        # The planner agent owns the conversation when configured; otherwise
        # the browser navigation agent does.
        agent_name: str = "planner_agent" if ag.agents_map is not None and "planner_agent" in ag.agents_map else "browser_nav_agent"

        command_cost = get_command_exec_cost(command_exec_result)  # type: ignore
        print(f"Command cost: {command_cost}")
        single_task_result["compute_cost"] = command_cost

        logger.info(f"Command \"{command}\" took: {round(end_time - start_time, 2)} seconds.")
        logger.info(f"Task {task_id} completed.")

        messages = ag.agents_map[agent_name].chat_messages  # type: ignore
        messages_str_keys = {str(key): value for key, value in messages.items()}  # type: ignore
        agent_key = list(messages.keys())[0]  # type: ignore
        last_agent_response = extract_last_response(messages[agent_key])  # type: ignore

        dump_log(str(task_id), messages_str_keys, logs_dir)

        single_task_result["last_statement"] = last_agent_response

        # Evaluate using only the agent's final answer as the trajectory;
        # no live page/CDP client is passed to the evaluator.
        evaluator = evaluator_router(task_config_file)
        # cdp_session = await page.context.new_cdp_session(page)
        evaluator_result = evaluator(
            config_file=task_config_file,
            page=None,
            client=None,
            trajectory=[{"answer": last_agent_response}]
        )

        single_task_result["score"] = evaluator_result
    except Exception as e:
        # NOTE(review): this handler also covers log dumping and evaluation,
        # not just cost extraction, so the message below under-describes
        # some failures; the exception text is preserved in "error".
        logger.error(f"Error getting command cost: {e}")
        command_cost = {"cost": -1, "total_tokens": -1}
        single_task_result["compute_cost"] = command_cost
        single_task_result["error"] = str(e)

    return single_task_result
|
||||
|
||||
|
||||
async def run_tests(ag: AutogenWrapper, browser_manager: PlaywrightManager, task_ids,
                    logdir: str="", logname: str="", relative_task_dir: str="", test_results_id: str = "", wait_time_non_headless: int=5, take_screenshots: bool = False) -> list[dict[str, Any]]:
    """
    Runs a specified range of test tasks using Playwright for browser interactions and AutogenWrapper for task automation.
    It initializes necessary components, processes each task, handles exceptions, and compiles test results into a structured list.

    Parameters:
    - ag (AutogenWrapper): The AutoGen wrapper that processes commands.
    - browser_manager (PlaywrightManager): The manager handling browser interactions, responsible for page navigation and control.
    - logdir (str)
    - logname (str)
    - task_ids (List[str])
    - relative_task_dir (str)
    - wait_time_non_headless (int): Time to wait between tasks when running in non-headless mode, useful for live monitoring or debugging.
    - take_screenshots (bool): Whether to take screenshots during test execution. Defaults to False.

    Returns:
    - list[dict[str, Any]]: A list of dictionaries, each containing the results from executing a test task. Results include task ID, intent, score, total command time, etc.

    This function also manages logging and saving of test results, updates the progress bar to reflect test execution status, and prints a detailed summary report at the end of the testing session.
    """
    test_log_dir = os.path.join(logdir, logname)
    test_result_dir = os.path.join(logdir, logname, "results")
    check_top_level_test_folders(test_log_dir, test_result_dir)

    # Build the list of task config paths from task_ids; "all" expands to
    # every *.json under config_files/<relative_task_dir>.
    config_file_list = []
    if not relative_task_dir or relative_task_dir == "":
        relative_task_dir = "tasks"
    if task_ids == "all" or task_ids == ["all"]:
        task_ids = [filename[:-len(".json")] for filename in os.listdir(f"config_files/{relative_task_dir}") if filename.endswith(".json")]
    for task_id in task_ids:
        config_file_list.append(f"config_files/{relative_task_dir}/{task_id}.json")

    test_results: list[dict[str, str | int | float | None]] = []

    # NOTE(review): AgentsLLMConfig() is constructed even when `ag` is
    # already provided and the config therefore goes unused — confirm
    # whether its constructor has side effects before reordering.
    llm_config = AgentsLLMConfig()
    if not ag:
        ag = await AutogenWrapper.create(llm_config.get_planner_agent_config(), llm_config.get_browser_nav_agent_config())

    if not browser_manager:
        browser_manager = browserManager.PlaywrightManager(headless=True)
        await browser_manager.async_initialize()

    page=await browser_manager.get_current_page()
    # Re-initialization of test_results (already set above) — harmless.
    test_results = []
    total_tests = len(config_file_list)

    for index, task_config_file in enumerate(config_file_list):
        task_config = json.load(open(task_config_file, "r"))
        task_id = str(task_config.get('task_id'))
        # Resume support: skip tasks that already have a saved result file.
        if os.path.exists(os.path.join(test_result_dir, f'{task_id}.json')):
            continue

        log_folders = create_task_log_folders(test_log_dir, task_id)

        ag.set_chat_logs_dir(log_folders["task_log_folder"])

        browser_manager.set_take_screenshots(take_screenshots)
        if take_screenshots:
            browser_manager.set_screenshots_dir(log_folders["task_screenshots_folder"])

        print_progress_bar(index, total_tests)
        task_result = await execute_single_task(task_config_file, browser_manager, ag, page, log_folders["task_log_folder"])
        test_results.append(task_result)
        # Persist per-task immediately so an interrupted run can resume.
        save_individual_test_result(task_result, test_result_dir)
        print_test_result(task_result, index + 1, total_tests)

        if not browser_manager.isheadless: # no need to wait if we are running headless
            await asyncio.sleep(wait_time_non_headless) # give time for switching between tasks in case there is a human observer

    await browser_manager.take_screenshots("final", None)

    await browser_manager.close_except_specified_tab(page) # cleanup pages that are not the one we opened here

    print_progress_bar(total_tests, total_tests) # Complete the progress bar
    print('\n\nAll tests completed.')

    # Aggregate and print individual test results
    print("\nDetailed Test Results:")
    detailed_results_table = [['Test Index', 'Task ID', 'Intent', 'Status', 'Time Taken (s)', 'Total Tokens', 'Total Cost ($)']]
    for idx, result in enumerate(test_results, 1):
        status, color = determine_status_and_color(result['score']) # type: ignore

        cost: str | int | float | None = result.get("compute_cost", None)
        total_cost = None if cost is None else round(cost.get("cost", -1), 4) # type: ignore
        total_tokens = None if cost is None else cost.get("total_tokens", -1) # type: ignore

        detailed_results_table.append([
            idx, result['task_id'], result['intent'], colored(status, color), round(result['tct'], 2), # type: ignore
            total_tokens, total_cost
        ])

    print(tabulate(detailed_results_table, headers='firstrow', tablefmt='grid'))

    # Summary report

    # Calculate aggregated cost and token totals for all tests that have compute cost
    total_cost = 0
    total_tokens = 0

    for result in test_results:
        compute_cost = result.get("compute_cost",0) # type: ignore
        if compute_cost is not None and isinstance(compute_cost, dict):
            total_cost += compute_cost.get("cost", 0) # type: ignore
            total_tokens += compute_cost.get("total_tokens", 0) # type: ignore

    # Bucket results by outcome; scoring convention matches
    # determine_status_and_color (1 = pass, negative = skip, else fail).
    passed_tests = []
    skipped_tests = []
    failed_tests = []
    for result in test_results:
        if result["score"] == 1:
            passed_tests.append(result) # type: ignore
        elif result["score"] < 0: # type: ignore
            skipped_tests.append(result) # type: ignore
        else:
            failed_tests.append(result) # type: ignore

    # NOTE(review): if config_file_list is empty, total_tests == 0 and the
    # average below raises ZeroDivisionError — confirm whether an empty task
    # list is a supported invocation.
    summary_table = [ # type: ignore
        ['Total Tests', 'Passed', 'Failed', 'Skipped', 'Average Time Taken (s)', 'Total Time Taken (s)', 'Total Tokens', 'Total Cost ($)'],
        [total_tests, len(passed_tests), len(failed_tests), len(skipped_tests),
         round(sum(test['tct'] for test in test_results) / total_tests, 2), # type: ignore
         round(sum(test['tct'] for test in test_results), 2), # type: ignore
         total_tokens, total_cost]
    ]

    print('\nSummary Report:')
    print(tabulate(summary_table, headers='firstrow', tablefmt='grid')) # type: ignore

    return test_results
|
228
README.md
228
README.md
|
@ -1,17 +1,227 @@
|
|||
## My Project
|
||||
# AgentOccam
|
||||
Code for "[AgentOccam: A Simple Yet Strong Baseline for LLM-Based Web Agents]()".
|
||||
|
||||
TODO: Fill this README out!
|
||||

|
||||
|
||||
Be sure to:
|
||||
We work on automating web tasks! 🏄🏄🏄 We refine the LLM-based web agents by aligning their observation and action space with the capabilities of LLMs.
|
||||
|
||||
* Change the title in this README
|
||||
* Edit your repository description on GitHub
|
||||
The newly designed agent AgentOccam surpasses previous state-of-the-art methods and concurrent work significantly w/o in-context examples, new agent roles, online feedback or search strategies on [WebArena](https://webarena.dev), a benchmark featuring general-purpose web tasks. 🍺
|
||||
|
||||
## Security
|
||||
We shed light on LLMs' impressive zero-shot performance on web tasks, and the critical role of carefully tuning observation and action spaces for LLM-based agents. 🧙
|
||||
|
||||
See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
|
||||
You can let AgentOccam interact with other websites like Google per your requests by defining the task config files, as seen in the example in `config_files/tasks/standford_cs_head.json`. Have fun playing with it! :)
|
||||
|
||||
## License
|
||||
*Please check whether the Reddit post limit is exceeded, a login has expired, or any other WebArena simulator/website failure exists when you finish one round. You should restart the simulator/relogin to the websites and rerun those tasks before reporting your final success rate. Additionally, LLM policy varies even given the same task, as the generation temperature is set to >0 for more diverse exploration. Therefore, it is expected that you can get different traces when starting the same task multiple times. Try it out with the basic `config_files/tasks/standford_cs_head.json`!*
|
||||
|
||||
This project is licensed under the Apache-2.0 License.
|
||||
## WebArena Replication
|
||||
### Environment Setup
|
||||
```bash
|
||||
git clone https://github.com/web-arena-x/webarena.git
|
||||
cd webarena
|
||||
conda create -n webarena python=3.10; conda activate webarena
|
||||
pip install -r requirements.txt
|
||||
pip install --upgrade transformers
|
||||
pip install --upgrade openai
|
||||
pip install numpy==1.26.4
|
||||
playwright install
|
||||
pip install -e .
|
||||
cd ../AgentOccam
|
||||
pip install -r requirements.txt
|
||||
mkdir .auth
|
||||
```
|
||||
|
||||
### Experiments
|
||||
#### AgentOccam-Series and SteP-Replication
|
||||
* Connect to the WebArena host server.
|
||||
* Export the env configs:
|
||||
```bash
|
||||
export SHOPPING="http://<webarena_server_address>:7770"
|
||||
export SHOPPING_ADMIN="http://<webarena_server_address>:7780/admin"
|
||||
export REDDIT="http://<webarena_server_address>:9999"
|
||||
export GITLAB="http://<webarena_server_address>:8023"
|
||||
export MAP="http://<webarena_server_address>:3000"
|
||||
export WIKIPEDIA="http://<webarena_server_address>:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing"
|
||||
export HOMEPAGE="http://<webarena_server_address>:4399"
|
||||
export OPENAI_API_KEY="<openai_api_key>"
|
||||
export GEMINI_API_KEY="<gemini_api_key>"
|
||||
```
|
||||
* Login in:
|
||||
```bash
|
||||
python browser_env/auto_login.py
|
||||
```
|
||||
* Test AgentOccam:
|
||||
```bash
|
||||
python eval_webarena.py --config AgentOccam/configs/AgentOccam.yml # Replace the yml config with your target one.
|
||||
```
|
||||
*You can also directly run `bash script/run_config.sh` after replacing the experiment configurations.*
|
||||
#### WebArena-Replication
|
||||
```bash
|
||||
bash scripts/run_webarena.sh
|
||||
```
|
||||
|
||||
## WebVoyager Replication
|
||||
### Environment Setup
|
||||
```bash
|
||||
git clone https://github.com/EmergenceAI/Agent-E.git
|
||||
cd Agent-E
|
||||
./install.sh
|
||||
source .venv/bin/activate
|
||||
uv pip install beartype
|
||||
uv pip install gymnasium
|
||||
uv pip install lxml
|
||||
uv pip install text_generation
|
||||
uv pip install aiolimiter
|
||||
uv pip install boto3
|
||||
uv pip install transformers
|
||||
export OPENAI_API_KEY="<openai_api_key>"
|
||||
export AUTOGEN_MODEL_NAME="gpt-4-turbo"
|
||||
cd ../AgentOccam
|
||||
```
|
||||
### Experiments
|
||||
#### AgentOccam
|
||||
```bash
|
||||
python eval_webarena.py --config AgentOccam/configs/AgentOccam-WebVoyager.yml
|
||||
```
|
||||
#### Agent-E
|
||||
```bash
|
||||
python -m agente_replication --task_ids Allrecipes--3
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Agent Configuration Explanation
|
||||
|
||||
The following are compiled based on `AgentOccam/configs/AgentOccam.yml`.
|
||||
|
||||
### General
|
||||
|
||||
```yml
|
||||
logdir: "../AgentOccam-Trajectories"
|
||||
```
|
||||
|
||||
This determines where the trajectories will be saved. Use relative path.
|
||||
|
||||
```yml
|
||||
logname: "AgentOccam"
|
||||
agent:
|
||||
others:
|
||||
logname: "AgentOccam"
|
||||
```
|
||||
|
||||
All relevant online files (play series, trash series, and output/screenshot series) will use this log name to differentiate. Change them simultaneously.
|
||||
|
||||
### Agent
|
||||
#### Base
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
debug: 0
|
||||
verbose: 1
|
||||
number: 1
|
||||
critic:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
judge:
|
||||
mode: false
|
||||
debug: 0
|
||||
verbose: 1
|
||||
```
|
||||
|
||||
All roles have a `debug` key. When `debug==1`, it plays and you decide whether to take its action. When `debug==2`, you will have to generate the action yourself. The actor is always playing so there's no `mode` key for it. For other roles, you can disable them by changing `mode` to false.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
model: "gpt-4-turbo"
|
||||
```
|
||||
|
||||
determines which model to use.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
input: ["step", "objective", "previous plans", "interaction history", "current observation"]
|
||||
```
|
||||
arranges the input. The list element order matters here and this applies to all the following list input/output specifications.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
interaction_history:
|
||||
verbose: True
|
||||
type: ["text"]
|
||||
step_num: 3
|
||||
```
|
||||
determines the interaction history section input type and modality. You can use `type: ["text", "image"]` to enable multimodality inputs.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
current_observation:
|
||||
type: ["text"]
|
||||
```
|
||||
defines the current observation type.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
output: ["interaction history summary", "observation description", "reason", "action", "observation highlight"]
|
||||
```
|
||||
organizes the output specifications; capable LLMs should generate that content, which is parsed automatically by the code. You only need to add the description for each entry under `AgentOccam/prompts/output_specifications`.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
planning_command: ["branch", "prune"]
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
```
|
||||
defines the valid actions.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
play: ["step", "objective", "previous plans", "observation description", "reason", "action"]
|
||||
trash: ["objective", "step", "url", "instruction", "online input", "response", "alter ego response"]
|
||||
```
|
||||
designates the broadcasting content.
|
||||
|
||||
#### Advanced
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
number: 1
|
||||
```
|
||||
If you use best-of-**N**-actions with judge, the `number` here defines the **N**.
|
||||
|
||||
```yml
|
||||
agent:
|
||||
actor:
|
||||
identities:
|
||||
identity_0:
|
||||
name: "QA"
|
||||
model: "gpt-4-turbo"
|
||||
output: ["response"]
|
||||
identity_1:
|
||||
name: "planning"
|
||||
model: "gpt-4-turbo"
|
||||
planning_command: ["branch", "prune"]
|
||||
output: ["interaction history summary", "observation description", "reason", "plan", "observation highlight"]
|
||||
identity_2:
|
||||
name: "reflection"
|
||||
model: "gpt-4-turbo"
|
||||
planning_command: ["branch", "prune"]
|
||||
navigation_command: ["click", "type", "stop", "note", "go_back"]
|
||||
output: ["interaction history summary", "observation description", "reflection", "reason", "action", "observation highlight"]
|
||||
```
|
||||
defines different actors. If you don't want them, comment them.
|
||||
|
||||
## Environment
|
||||
```yml
|
||||
env:
|
||||
fullpage: true
|
||||
prune: true
|
||||
```
|
||||
If `fullpage==True`, the agent takes the entire page as the input. Remember to add `scroll` to the `navigation_action` list if `fullpage` is disabled.
|
||||
|
||||
If `prune==True`, the pipeline carries out observation space alignment.
|
29
agente_replication.py
Normal file
29
agente_replication.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
from Agent_E.test.tests_processor import run_tests
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create the parser
|
||||
parser = argparse.ArgumentParser(description='Run test suite for specified range of test tasks.')
|
||||
|
||||
# Add arguments
|
||||
parser.add_argument('-s', '--take_screenshots', type=bool, default=False,
|
||||
help='Take screenshots after every operation performed (default: False)')
|
||||
parser.add_argument('-wait', '--wait_time_non_headless', type=int, default=5,
|
||||
help='Time to wait between test tasks when running in non-headless mode (default: 10 seconds)')
|
||||
parser.add_argument("-ids", "--task_ids", type=str, nargs='+', help="List of task IDs to execute")
|
||||
parser.add_argument('-dir', '--logdir', type=str, default="../AgentOccam-Trajectories",
|
||||
help='Logdir.')
|
||||
parser.add_argument('-log', '--logname', type=str, default="Agent-E",
|
||||
help='Logname.')
|
||||
parser.add_argument('-id', '--test_results_id', type=str, default="",
|
||||
help='A unique identifier for the test results. If not provided, a timestamp is used.')
|
||||
parser.add_argument('-config', '--relative_task_dir', type=str, default="webvoyager",
|
||||
help='Path to the test configuration file.')
|
||||
|
||||
# Parse the command line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run the main function with the provided or default arguments, not passing browser_manager or AutoGenWrapper will cause the test processor to create new instances of them
|
||||
asyncio.run(run_tests(None, None, args.task_ids, logdir=args.logdir, logname=args.logname, relative_task_dir=args.relative_task_dir,
|
||||
take_screenshots=args.take_screenshots, wait_time_non_headless=args.wait_time_non_headless))
|
78
browser_env/__init__.py
Normal file
78
browser_env/__init__.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
import asyncio
|
||||
|
||||
from .actions import (
|
||||
Action,
|
||||
ActionParsingError,
|
||||
ActionTypes,
|
||||
action2create_function,
|
||||
action2str,
|
||||
create_check_action,
|
||||
create_click_action,
|
||||
create_focus_and_click_action,
|
||||
create_focus_and_type_action,
|
||||
create_go_back_action,
|
||||
create_go_forward_action,
|
||||
create_goto_url_action,
|
||||
create_hover_action,
|
||||
create_id_based_action,
|
||||
create_id_based_actions,
|
||||
create_key_press_action,
|
||||
create_keyboard_type_action,
|
||||
create_mouse_click_action,
|
||||
create_mouse_hover_action,
|
||||
create_new_tab_action,
|
||||
create_none_action,
|
||||
create_page_close_action,
|
||||
create_page_focus_action,
|
||||
create_playwright_action,
|
||||
create_random_action,
|
||||
create_scroll_action,
|
||||
create_select_option_action,
|
||||
create_stop_action,
|
||||
create_type_action,
|
||||
is_equivalent,
|
||||
)
|
||||
from .async_envs import AsyncScriptBrowserEnv
|
||||
from .envs import ScriptBrowserEnv
|
||||
from .processors import ObservationMetadata
|
||||
from .trajectory import Trajectory
|
||||
from .utils import DetachedPage, StateInfo
|
||||
|
||||
__all__ = [
|
||||
"ScriptBrowserEnv",
|
||||
"AsyncScriptBrowserEnv",
|
||||
"DetachedPage",
|
||||
"StateInfo",
|
||||
"ObservationMetadata",
|
||||
"Action",
|
||||
"ActionTypes",
|
||||
"action2str",
|
||||
"create_random_action",
|
||||
"create_focus_and_click_action",
|
||||
"create_focus_and_type_action",
|
||||
"is_equivalent",
|
||||
"create_mouse_click_action",
|
||||
"create_mouse_hover_action",
|
||||
"create_none_action",
|
||||
"create_keyboard_type_action",
|
||||
"create_page_focus_action",
|
||||
"create_new_tab_action",
|
||||
"create_go_back_action",
|
||||
"create_go_forward_action",
|
||||
"create_goto_url_action",
|
||||
"create_page_close_action",
|
||||
"action2create_function",
|
||||
"create_playwright_action",
|
||||
"create_id_based_action",
|
||||
"create_id_based_actions",
|
||||
"create_scroll_action",
|
||||
"create_key_press_action",
|
||||
"create_check_action",
|
||||
"create_click_action",
|
||||
"create_type_action",
|
||||
"create_hover_action",
|
||||
"create_select_option_action",
|
||||
"create_stop_action",
|
||||
"ActionParsingError",
|
||||
"Trajectory",
|
||||
]
|
1948
browser_env/actions.py
Normal file
1948
browser_env/actions.py
Normal file
File diff suppressed because it is too large
Load Diff
153
browser_env/async_envs.py
Normal file
153
browser_env/async_envs.py
Normal file
|
@ -0,0 +1,153 @@
|
|||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from gymnasium import Env
|
||||
from gymnasium.spaces import Box, Text
|
||||
from playwright.async_api import Page, ViewportSize, async_playwright
|
||||
|
||||
from .actions import Action, aexecute_action, get_action_space
|
||||
from .utils import DetachedPage, png_bytes_to_numpy
|
||||
|
||||
|
||||
class AsyncScriptBrowserEnv(Env[npt.NDArray[np.uint8], Action]):
    """
    Async Playwright-backed browser environment prototype.

    Observations are RGBA screenshots of the viewport (hence the 4-channel
    observation space); actions are executed via ``aexecute_action``.  The
    synchronous ``reset``/``step``/``close`` wrappers drive the async
    counterparts through ``asyncio.run`` so the environment can also be used
    from non-async code.
    """

    def __init__(
        self,
        max_page_length: int = 2048,
        headless: bool = True,
        slow_mo: int = 0,
        timeout: int = 30000,
        viewport_size: ViewportSize = {"width": 1280, "height": 720},
    ):
        # Screenshots are RGBA -> 4 channels.
        self.observation_space = Box(
            0,
            255,
            (viewport_size["height"], viewport_size["width"], 4),
            np.uint8,
        )
        # TODO: make Space[Action] = ActionSpace
        self.action_space = get_action_space()  # type: ignore[assignment]
        self.headless = headless
        self.slow_mo = slow_mo
        self.reset_finished = False
        self.timeout = timeout
        self.viewport_size = viewport_size

    async def setup(self, config_file: Path | None = None) -> None:
        """Start Playwright, launch a browser and open the first page.

        :param config_file: optional JSON file that may supply
            ``storage_state`` (cookie/session state), ``start_url`` and
            ``geolocation`` for the new browser context.
        """
        self.context_manager = async_playwright()
        self.playwright = await self.context_manager.__aenter__()
        self.browser = await self.playwright.chromium.launch(
            headless=self.headless, slow_mo=self.slow_mo
        )
        if config_file:
            with open(config_file, "r") as f:
                instance_config = json.load(f)
        else:
            instance_config = {}

        storage_state = instance_config.get("storage_state", None)
        start_url = instance_config.get("start_url", None)
        geolocation = instance_config.get("geolocation", None)

        self.context = await self.browser.new_context(
            viewport=self.viewport_size,
            storage_state=storage_state,
            geolocation=geolocation,
            device_scale_factor=1,
        )
        self.page = await self.context.new_page()
        if start_url:
            await self.page.goto(start_url)

    async def areset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, str] | None = None,
    ) -> tuple[npt.NDArray[np.uint8], dict[str, object]]:
        """
        Reset the environment.
        :param options: options for the environment. The options are:
            - config_file: path to a JSON instance config (see ``setup``)
        """
        super().reset(seed=seed, options=options)
        # Tear down any previous Playwright session before starting anew.
        if self.reset_finished:
            await self.context_manager.__aexit__()
        if options is not None and "config_file" in options:
            config_file = Path(options["config_file"])
            if config_file.exists():
                await self.setup(config_file=config_file)
            else:
                raise ValueError(f"Config state {config_file} does not exist.")
        else:
            await self.setup()
        self.reset_finished = True
        content = await self.page.content()
        screenshot = png_bytes_to_numpy(await self.page.screenshot())
        return (
            screenshot,
            {"page": DetachedPage(self.page.url, content)},
        )

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, str] | None = None,
    ) -> tuple[npt.NDArray[np.uint8], dict[str, object]]:
        """Synchronous wrapper around :meth:`areset`."""
        return asyncio.run(self.areset(seed=seed, options=options))

    async def aclose(self) -> None:
        """Shut down Playwright if a session was started."""
        if self.reset_finished:
            await self.context_manager.__aexit__()

    def close(self) -> None:
        """Synchronous wrapper around :meth:`aclose`."""
        asyncio.run(self.aclose())

    async def astep(
        self, action: Action
    ) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
        """Execute *action* and return (screenshot, reward, terminated,
        truncated, info).  Reward is 1.0 when the action executed without
        raising, 0.0 otherwise; the error text is reported in
        ``info["fail_error"]``.
        """
        if not self.reset_finished:
            raise RuntimeError("Call reset first before calling step.")
        success = False
        fail_error = ""
        try:
            self.page = await aexecute_action(action, self.page, self.context)
            success = True
        except Exception as e:
            fail_error = str(e)

        # The page may still be mid-navigation; wait for the load state and
        # retry once before giving up.  (Was a bare ``except:``, which would
        # also swallow KeyboardInterrupt/SystemExit.)
        try:
            content = await self.page.content()
            screenshot = png_bytes_to_numpy(await self.page.screenshot())
        except Exception:
            await self.page.wait_for_load_state("load")
            content = await self.page.content()
            screenshot = png_bytes_to_numpy(await self.page.screenshot())

        return (
            screenshot,
            float(success),
            False,
            False,
            {
                "page": DetachedPage(self.page.url, content),
                "fail_error": fail_error,
            },
        )

    def step(
        self, action: Action
    ) -> tuple[npt.NDArray[np.uint8], float, bool, bool, dict[str, object]]:
        """Synchronous wrapper around :meth:`astep`.

        The original passed ``debug=True`` to ``asyncio.run`` — a leftover
        that enables slow asyncio debug diagnostics on every step; removed.
        """
        return asyncio.run(self.astep(action))
|
207
browser_env/auto_login.py
Normal file
207
browser_env/auto_login.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
"""Script to automatically login each website"""
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import combinations
|
||||
from pathlib import Path
|
||||
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from browser_env.env_config import (
|
||||
ACCOUNTS,
|
||||
GITLAB,
|
||||
REDDIT,
|
||||
SHOPPING,
|
||||
SHOPPING_ADMIN,
|
||||
)
|
||||
|
||||
HEADLESS = True
|
||||
SLOW_MO = 0
|
||||
|
||||
|
||||
SITES = ["gitlab", "shopping", "shopping_admin", "reddit"]
|
||||
URLS = [
|
||||
f"{GITLAB}/-/profile",
|
||||
f"{SHOPPING}/wishlist/",
|
||||
f"{SHOPPING_ADMIN}/dashboard",
|
||||
f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account",
|
||||
]
|
||||
EXACT_MATCH = [True, True, True, True]
|
||||
KEYWORDS = ["", "", "Dashboard", "Delete"]
|
||||
|
||||
|
||||
def is_expired(
    storage_state: Path, url: str, keyword: str, url_exact: bool = True
) -> bool:
    """Return True when the stored auth cookie no longer grants access.

    A missing state file counts as expired.  Otherwise a headless browser is
    opened with the stored cookies and pointed at *url*; expiry is decided by
    *keyword* presence in the page content when given, else by comparing the
    landed URL against *url* (exactly or as a substring per *url_exact*).
    """
    if not storage_state.exists():
        return True

    with sync_playwright() as playwright:
        browser = playwright.chromium.launch(headless=True, slow_mo=SLOW_MO)
        context = browser.new_context(storage_state=storage_state)
        page = context.new_page()
        page.goto(url)
        time.sleep(1)
        landed_url = page.url
        content = page.content()

    if keyword:
        return keyword not in content
    if url_exact:
        return landed_url != url
    return url not in landed_url
|
||||
|
||||
|
||||
def _login_site(page, site: str) -> None:
    """Drive *page* through the login flow of one *site*, using ACCOUNTS.

    Shared by the per-site and the combined phases of ``renew_comb`` (the
    original duplicated all four flows twice).
    """
    if site == "shopping":
        username = ACCOUNTS["shopping"]["username"]
        password = ACCOUNTS["shopping"]["password"]
        page.goto(f"{SHOPPING}/customer/account/login/")
        page.get_by_label("Email", exact=True).fill(username)
        page.get_by_label("Password", exact=True).fill(password)
        page.get_by_role("button", name="Sign In").click()

    if site == "reddit":
        username = ACCOUNTS["reddit"]["username"]
        password = ACCOUNTS["reddit"]["password"]
        page.goto(f"{REDDIT}/login")
        page.get_by_label("Username").fill(username)
        page.get_by_label("Password").fill(password)
        page.get_by_role("button", name="Log in").click()

    if site == "shopping_admin":
        username = ACCOUNTS["shopping_admin"]["username"]
        password = ACCOUNTS["shopping_admin"]["password"]
        page.goto(f"{SHOPPING_ADMIN}")
        page.get_by_placeholder("user name").fill(username)
        page.get_by_placeholder("password").fill(password)
        page.get_by_role("button", name="Sign in").click()

    if site == "gitlab":
        username = ACCOUNTS["gitlab"]["username"]
        password = ACCOUNTS["gitlab"]["password"]
        page.goto(f"{GITLAB}/users/sign_in")
        page.get_by_test_id("username-field").click()
        page.get_by_test_id("username-field").fill(username)
        page.get_by_test_id("username-field").press("Tab")
        page.get_by_test_id("password-field").fill(password)
        page.get_by_test_id("sign-in-button").click()


def renew_comb(comb: list[str], auth_folder: str = "./.auth") -> None:
    """Renew the stored auth state for a combination of sites.

    Phase 1 logs into each site in its own browser and saves a per-site
    ``{site}_state.json``; phase 2 logs into all of them in one shared
    context and saves the combined ``{a.b...}_state.json``.

    Fixes vs. the original: every per-site Playwright instance is now
    closed (the original started one per site but exited only the last —
    a browser-process leak), and the stray ``debug.png`` screenshot in the
    gitlab flow was removed.
    """
    # Phase 1: one fresh browser per site, saving individual auth states.
    for c in comb:
        with sync_playwright() as playwright:
            browser = playwright.chromium.launch(headless=HEADLESS)
            context = browser.new_context()
            page = context.new_page()
            _login_site(page, c)
            context.storage_state(path=f"{auth_folder}/{c}_state.json")

    # Phase 2: one shared context logged into every requested site.
    # Order matches the original: shopping, reddit, shopping_admin, gitlab.
    with sync_playwright() as playwright:
        browser = playwright.chromium.launch(headless=HEADLESS)
        context = browser.new_context()
        page = context.new_page()
        for site in ("shopping", "reddit", "shopping_admin", "gitlab"):
            if site in comb:
                _login_site(page, site)
        context.storage_state(path=f"{auth_folder}/{'.'.join(comb)}_state.json")
|
||||
|
||||
|
||||
def get_site_comb_from_filepath(file_path: str) -> list[str]:
    """Recover the list of site names encoded in an auth-state filename.

    E.g. ``/.auth/gitlab.shopping_state.json`` -> ``["gitlab", "shopping"]``
    (the trailing ``_state.json`` suffix is dropped at the last underscore,
    the remainder split on dots).
    """
    filename = os.path.basename(file_path)
    prefix = filename.rsplit("_", 1)[0]
    return prefix.split(".")
|
||||
|
||||
|
||||
def main(auth_folder: str = "./.auth") -> None:
    """Renew auth states for all site pairs and singletons, then verify them.

    Runs ``renew_comb`` concurrently for every 2-site combination (minus the
    known-broken reddit+shopping[-admin] pairs) and for each single site,
    then re-opens every saved cookie file and asserts none is expired.
    """
    pairs = list(combinations(SITES, 2))

    max_workers = 8
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for pair in pairs:
            # TODO[shuyanzh] auth don't work on these two sites
            if "reddit" in pair and (
                "shopping" in pair or "shopping_admin" in pair
            ):
                continue
            executor.submit(
                renew_comb, list(sorted(pair)), auth_folder=auth_folder
            )

        for site in SITES:
            executor.submit(renew_comb, [site], auth_folder=auth_folder)

    # Verify: every site covered by every saved state file must still be
    # logged in.  Track (file, future) pairs so a failure reports the right
    # file — the original indexed cookie_files[i], which misattributes the
    # file as soon as any state file covers more than one site.
    checks: list[tuple[str, Future]] = []
    cookie_files = list(glob.glob(f"{auth_folder}/*.json"))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for c_file in cookie_files:
            comb = get_site_comb_from_filepath(c_file)
            for cur_site in comb:
                # Hoisted: look the index up once instead of three times.
                site_idx = SITES.index(cur_site)
                future = executor.submit(
                    is_expired,
                    Path(c_file),
                    URLS[site_idx],
                    KEYWORDS[site_idx],
                    EXACT_MATCH[site_idx],
                )
                checks.append((c_file, future))

    for c_file, future in checks:
        assert not future.result(), f"Cookie {c_file} expired."
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: renew auth state for all sites (default) or a specific combination.
    parser = argparse.ArgumentParser()
    parser.add_argument("--site_list", nargs="+", default=["all"])
    parser.add_argument("--auth_folder", type=str, default="./.auth")
    args = parser.parse_args()
    # NOTE(review): this branch looks dead — the default is ["all"], so
    # site_list is only falsy if an empty list is passed explicitly.
    if not args.site_list:
        main()
    else:
        if "all" in args.site_list:
            # Renew every pair and singleton of SITES.
            main(auth_folder=args.auth_folder)
        else:
            # Renew exactly the requested combination.
            renew_comb(args.site_list, auth_folder=args.auth_folder)
|
295
browser_env/constants.py
Normal file
295
browser_env/constants.py
Normal file
|
@ -0,0 +1,295 @@
|
|||
from typing import Literal
|
||||
|
||||
ROLES = (
|
||||
"alert",
|
||||
"alertdialog",
|
||||
"application",
|
||||
"article",
|
||||
"banner",
|
||||
"blockquote",
|
||||
"button",
|
||||
"caption",
|
||||
"cell",
|
||||
"checkbox",
|
||||
"code",
|
||||
"columnheader",
|
||||
"combobox",
|
||||
"complementary",
|
||||
"contentinfo",
|
||||
"definition",
|
||||
"deletion",
|
||||
"dialog",
|
||||
"directory",
|
||||
"document",
|
||||
"emphasis",
|
||||
"feed",
|
||||
"figure",
|
||||
"form",
|
||||
"generic",
|
||||
"grid",
|
||||
"gridcell",
|
||||
"group",
|
||||
"heading",
|
||||
"img",
|
||||
"insertion",
|
||||
"link",
|
||||
"list",
|
||||
"listbox",
|
||||
"listitem",
|
||||
"log",
|
||||
"main",
|
||||
"marquee",
|
||||
"math",
|
||||
"meter",
|
||||
"menu",
|
||||
"menubar",
|
||||
"menuitem",
|
||||
"menuitemcheckbox",
|
||||
"menuitemradio",
|
||||
"navigation",
|
||||
"none",
|
||||
"note",
|
||||
"option",
|
||||
"paragraph",
|
||||
"presentation",
|
||||
"progressbar",
|
||||
"radio",
|
||||
"radiogroup",
|
||||
"region",
|
||||
"row",
|
||||
"rowgroup",
|
||||
"rowheader",
|
||||
"scrollbar",
|
||||
"search",
|
||||
"searchbox",
|
||||
"separator",
|
||||
"slider",
|
||||
"spinbutton",
|
||||
"status",
|
||||
"strong",
|
||||
"subscript",
|
||||
"superscript",
|
||||
"switch",
|
||||
"tab",
|
||||
"table",
|
||||
"tablist",
|
||||
"tabpanel",
|
||||
"term",
|
||||
"textbox",
|
||||
"time",
|
||||
"timer",
|
||||
"toolbar",
|
||||
"tooltip",
|
||||
"tree",
|
||||
"treegrid",
|
||||
"treeitem",
|
||||
)
|
||||
|
||||
SPECIAL_LOCATORS = (
|
||||
"alt_text",
|
||||
"label",
|
||||
"placeholder",
|
||||
)
|
||||
|
||||
ASCII_CHARSET = "".join(chr(x) for x in range(32, 128))
|
||||
FREQ_UNICODE_CHARSET = "".join(chr(x) for x in range(129, 110000))
|
||||
UTTERANCE_MAX_LENGTH = 8192
|
||||
ATTRIBUTE_MAX_LENGTH = 256
|
||||
TEXT_MAX_LENGTH = 256
|
||||
TYPING_MAX_LENGTH = 64
|
||||
URL_MAX_LENGTH = 256
|
||||
MAX_ELEMENT_INDEX_IN_VIEWPORT = 10
|
||||
MAX_ELEMENT_ID = 1000
|
||||
MAX_ANSWER_LENGTH = 512
|
||||
|
||||
MIN_REF = -1000000
|
||||
MAX_REF = 1000000
|
||||
|
||||
WINDOW_WIDTH = 500
|
||||
WINDOW_HEIGHT = 240
|
||||
TASK_WIDTH = 160
|
||||
TASK_HEIGHT = 210
|
||||
|
||||
FLIGHT_WINDOW_WIDTH = 600
|
||||
FLIGHT_WINDOW_HEIGHT = 700
|
||||
FLIGHT_TASK_WIDTH = 375
|
||||
FLIGHT_TASK_HEIGHT = 667
|
||||
MAX_PAGE_NUMBER = 10
|
||||
|
||||
SPECIAL_KEYS = (
|
||||
"Enter",
|
||||
"Tab",
|
||||
"Control",
|
||||
"Shift",
|
||||
"Meta",
|
||||
"Backspace",
|
||||
"Delete",
|
||||
"Escape",
|
||||
"ArrowUp",
|
||||
"ArrowDown",
|
||||
"ArrowLeft",
|
||||
"ArrowRight",
|
||||
"PageDown",
|
||||
"PageUp",
|
||||
"Meta+a",
|
||||
)
|
||||
|
||||
SPECIAL_KEY_MAPPINGS = {
|
||||
"backquote": "Backquote",
|
||||
"minus": "Minus",
|
||||
"equal": "Equal",
|
||||
"backslash": "Backslash",
|
||||
"backspace": "Backspace",
|
||||
"meta": "Meta",
|
||||
"tab": "Tab",
|
||||
"delete": "Delete",
|
||||
"escape": "Escape",
|
||||
"arrowdown": "ArrowDown",
|
||||
"end": "End",
|
||||
"enter": "Enter",
|
||||
"home": "Home",
|
||||
"insert": "Insert",
|
||||
"pagedown": "PageDown",
|
||||
"pageup": "PageUp",
|
||||
"arrowright": "ArrowRight",
|
||||
"arrowup": "ArrowUp",
|
||||
"f1": "F1",
|
||||
"f2": "F2",
|
||||
"f3": "F3",
|
||||
"f4": "F4",
|
||||
"f5": "F5",
|
||||
"f6": "F6",
|
||||
"f7": "F7",
|
||||
"f8": "F8",
|
||||
"f9": "F9",
|
||||
"f10": "F10",
|
||||
"f11": "F11",
|
||||
"f12": "F12",
|
||||
}
|
||||
|
||||
RolesType = Literal[
|
||||
"alert",
|
||||
"alertdialog",
|
||||
"application",
|
||||
"article",
|
||||
"banner",
|
||||
"blockquote",
|
||||
"button",
|
||||
"caption",
|
||||
"cell",
|
||||
"checkbox",
|
||||
"code",
|
||||
"columnheader",
|
||||
"combobox",
|
||||
"complementary",
|
||||
"contentinfo",
|
||||
"definition",
|
||||
"deletion",
|
||||
"dialog",
|
||||
"directory",
|
||||
"document",
|
||||
"emphasis",
|
||||
"feed",
|
||||
"figure",
|
||||
"form",
|
||||
"generic",
|
||||
"grid",
|
||||
"gridcell",
|
||||
"group",
|
||||
"heading",
|
||||
"img",
|
||||
"insertion",
|
||||
"link",
|
||||
"list",
|
||||
"listbox",
|
||||
"listitem",
|
||||
"log",
|
||||
"main",
|
||||
"marquee",
|
||||
"math",
|
||||
"meter",
|
||||
"menu",
|
||||
"menubar",
|
||||
"menuitem",
|
||||
"menuitemcheckbox",
|
||||
"menuitemradio",
|
||||
"navigation",
|
||||
"none",
|
||||
"note",
|
||||
"option",
|
||||
"paragraph",
|
||||
"presentation",
|
||||
"progressbar",
|
||||
"radio",
|
||||
"radiogroup",
|
||||
"region",
|
||||
"row",
|
||||
"rowgroup",
|
||||
"rowheader",
|
||||
"scrollbar",
|
||||
"search",
|
||||
"searchbox",
|
||||
"separator",
|
||||
"slider",
|
||||
"spinbutton",
|
||||
"status",
|
||||
"strong",
|
||||
"subscript",
|
||||
"superscript",
|
||||
"switch",
|
||||
"tab",
|
||||
"table",
|
||||
"tablist",
|
||||
"tabpanel",
|
||||
"term",
|
||||
"textbox",
|
||||
"time",
|
||||
"timer",
|
||||
"toolbar",
|
||||
"tooltip",
|
||||
"tree",
|
||||
"treegrid",
|
||||
"treeitem",
|
||||
"alt_text",
|
||||
"label",
|
||||
"placeholder",
|
||||
]
|
||||
|
||||
MAX_VANILLA_STR_LENGTH = 1000
|
||||
|
||||
PLAYWRIGHT_LOCATORS = (
|
||||
"get_by_role",
|
||||
"get_by_text",
|
||||
"get_by_label",
|
||||
"get_by_placeholder",
|
||||
"get_by_alt_text",
|
||||
"get_by_title",
|
||||
"get_by_test_id",
|
||||
"filter",
|
||||
"frame_locator",
|
||||
"locator",
|
||||
)
|
||||
|
||||
PLAYWRIGHT_ACTIONS = (
|
||||
"fill",
|
||||
"check",
|
||||
"select_option",
|
||||
"click",
|
||||
"hover",
|
||||
"dclick",
|
||||
"type",
|
||||
"focus",
|
||||
"goto",
|
||||
"press",
|
||||
"scroll",
|
||||
)
|
||||
|
||||
IGNORED_ACTREE_PROPERTIES = (
|
||||
"focusable",
|
||||
"editable",
|
||||
"readonly",
|
||||
"level",
|
||||
"settable",
|
||||
"multiline",
|
||||
"invalid",
|
||||
)
|
51
browser_env/env_config.py
Normal file
51
browser_env/env_config.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
# websites domain
|
||||
import os
|
||||
|
||||
REDDIT = os.environ.get("REDDIT", "http://localhost:9999")
|
||||
SHOPPING = os.environ.get("SHOPPING", "http://localhost:7770")
|
||||
SHOPPING_ADMIN = os.environ.get("SHOPPING_ADMIN", "http://localhost:7780/admin")
|
||||
GITLAB = os.environ.get("GITLAB", "http://localhost:8023")
|
||||
WIKIPEDIA = os.environ.get("WIKIPEDIA", "http://localhost:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing")
|
||||
MAP = os.environ.get("MAP", "http://localhost:3000")
|
||||
HOMEPAGE = os.environ.get("HOMEPAGE", "http://localhost:4399")
|
||||
|
||||
assert (
|
||||
REDDIT
|
||||
and SHOPPING
|
||||
and SHOPPING_ADMIN
|
||||
and GITLAB
|
||||
and WIKIPEDIA
|
||||
and MAP
|
||||
and HOMEPAGE
|
||||
), (
|
||||
f"Please setup the URLs to each site. Current: \n"
|
||||
+ f"Reddit: {REDDIT}\n"
|
||||
+ f"Shopping: {SHOPPING}\n"
|
||||
+ f"Shopping Admin: {SHOPPING_ADMIN}\n"
|
||||
+ f"Gitlab: {GITLAB}\n"
|
||||
+ f"Wikipedia: {WIKIPEDIA}\n"
|
||||
+ f"Map: {MAP}\n"
|
||||
+ f"Homepage: {HOMEPAGE}\n"
|
||||
)
|
||||
|
||||
|
||||
ACCOUNTS = {
|
||||
"reddit": {"username": "MarvelsGrantMan136", "password": "test1234"},
|
||||
"gitlab": {"username": "byteblaze", "password": "hello1234"},
|
||||
"shopping": {
|
||||
"username": "emma.lopez@gmail.com",
|
||||
"password": "Password.123",
|
||||
},
|
||||
"shopping_admin": {"username": "admin", "password": "admin1234"},
|
||||
"shopping_site_admin": {"username": "admin", "password": "admin1234"},
|
||||
}
|
||||
|
||||
URL_MAPPINGS = {
|
||||
REDDIT: "http://reddit.com",
|
||||
SHOPPING: "http://onestopmarket.com",
|
||||
SHOPPING_ADMIN: "http://luma.com/admin",
|
||||
GITLAB: "http://gitlab.com",
|
||||
WIKIPEDIA: "http://wikipedia.org",
|
||||
MAP: "http://openstreetmap.org",
|
||||
HOMEPAGE: "http://homepage.com",
|
||||
}
|
334
browser_env/envs.py
Normal file
334
browser_env/envs.py
Normal file
|
@ -0,0 +1,334 @@
|
|||
import json
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from beartype import beartype
|
||||
from beartype.door import is_bearable
|
||||
from gymnasium import Env
|
||||
from gymnasium.spaces import Box, Text
|
||||
from playwright.sync_api import (
|
||||
CDPSession,
|
||||
Page,
|
||||
Playwright,
|
||||
ViewportSize,
|
||||
expect,
|
||||
sync_playwright,
|
||||
)
|
||||
|
||||
from .actions import Action, execute_action, get_action_space
|
||||
from .processors import ObservationHandler, ObservationMetadata
|
||||
from .utils import (
|
||||
AccessibilityTree,
|
||||
DetachedPage,
|
||||
Observation,
|
||||
png_bytes_to_numpy,
|
||||
)
|
||||
|
||||
import base64
|
||||
from .scripts import *
|
||||
|
||||
@dataclass
class PlaywrightScript:
    """One parsed Playwright command, produced by ``parse_action``."""

    function: str  # goto, get_by_role
    destination: str  # https://www.google.com/, combobox
    name: str | None = None  # Search, Avatar 2009
    operation: str | None = None  # click, fill, press
    value: str | None = None  # avatar movie, Enter
|
||||
|
||||
|
||||
def parse_action(action: str) -> PlaywrightScript:
|
||||
splitted = action.strip().split(" ")
|
||||
assert len(splitted) >= 2
|
||||
match splitted[:2]:
|
||||
case ["goto", url]:
|
||||
assert len(splitted) == 2
|
||||
return PlaywrightScript("goto", url)
|
||||
case ["get_by_role", destination]:
|
||||
assert len(splitted) >= 4
|
||||
match splitted[2:]:
|
||||
case [name, operation]:
|
||||
return PlaywrightScript(
|
||||
"get_by_role", destination, name, operation
|
||||
)
|
||||
case [name, operation, value]:
|
||||
return PlaywrightScript(
|
||||
"get_by_role", destination, name, operation, value
|
||||
)
|
||||
case _:
|
||||
raise ValueError("Invalid action")
|
||||
case _:
|
||||
raise ValueError(f"Invalid action {action}")
|
||||
|
||||
|
||||
class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
|
||||
"""
|
||||
The goal of this environment is to produce a prototype of a browser environment.
|
||||
In the end, we want to support a fully configurable browser environment with wide
|
||||
range of action spaces and observation spaces, both structured and unstructured.
|
||||
But in this prototype, we just support action space specified by Playwright script,
|
||||
and observation space is the html content of the page.
|
||||
"""
|
||||
|
||||
@beartype
|
||||
def __init__(
|
||||
self,
|
||||
max_page_length: int = 8192,
|
||||
headless: bool = True,
|
||||
slow_mo: int = 0,
|
||||
observation_type: str = "html",
|
||||
current_viewport_only: bool = False,
|
||||
viewport_size: ViewportSize = {"width": 1280, "height": 720},
|
||||
save_trace_enabled: bool = False,
|
||||
sleep_after_execution: float = 5.0,
|
||||
global_config = None,
|
||||
):
|
||||
# TODO: make Space[Action] = ActionSpace
|
||||
self.action_space = get_action_space() # type: ignore[assignment]
|
||||
self.headless = headless
|
||||
self.slow_mo = slow_mo
|
||||
self.current_viewport_only = current_viewport_only
|
||||
self.reset_finished = False
|
||||
self.viewport_size = viewport_size
|
||||
self.save_trace_enabled = save_trace_enabled
|
||||
self.sleep_after_execution = sleep_after_execution
|
||||
self.global_config = global_config
|
||||
|
||||
match observation_type:
|
||||
case "html" | "accessibility_tree":
|
||||
self.text_observation_type = observation_type
|
||||
self.image_observation_type = ""
|
||||
self.main_observation_type = "text"
|
||||
case "image":
|
||||
self.image_observation_type = observation_type
|
||||
self.text_observation_type = "" # type: ignore[assignment]
|
||||
self.main_observation_type = "image"
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unsupported observation type: {observation_type}"
|
||||
)
|
||||
|
||||
self.observation_handler = ObservationHandler(
|
||||
self.main_observation_type,
|
||||
self.text_observation_type,
|
||||
self.image_observation_type,
|
||||
self.current_viewport_only,
|
||||
self.viewport_size,
|
||||
)
|
||||
|
||||
self.observation_space = (
|
||||
self.observation_handler.get_observation_space()
|
||||
)
|
||||
|
||||
@beartype
|
||||
def setup(self, config_file: Path | None = None) -> None:
|
||||
def handle_dialog(dialog):
|
||||
self.page.dialog_message = dialog.message
|
||||
dialog.dismiss()
|
||||
self.context_manager = sync_playwright()
|
||||
self.playwright = self.context_manager.__enter__()
|
||||
self.browser = self.playwright.chromium.launch(
|
||||
headless=self.headless, slow_mo=self.slow_mo
|
||||
)
|
||||
|
||||
if config_file:
|
||||
with open(config_file, "r") as f:
|
||||
instance_config = json.load(f)
|
||||
else:
|
||||
instance_config = {}
|
||||
|
||||
storage_state = instance_config.get("storage_state", None)
|
||||
start_url = instance_config.get("start_url", None)
|
||||
geolocation = instance_config.get("geolocation", None)
|
||||
|
||||
self.context = self.browser.new_context(
|
||||
viewport=self.viewport_size,
|
||||
storage_state=storage_state,
|
||||
geolocation=geolocation,
|
||||
device_scale_factor=1,
|
||||
)
|
||||
if self.save_trace_enabled:
|
||||
self.context.tracing.start(screenshots=True, snapshots=True)
|
||||
if start_url:
|
||||
start_urls = start_url.split(" |AND| ")
|
||||
for url in start_urls:
|
||||
page = self.context.new_page()
|
||||
page.on("dialog", handle_dialog)
|
||||
client = page.context.new_cdp_session(
|
||||
page
|
||||
) # talk to chrome devtools
|
||||
if self.text_observation_type == "accessibility_tree":
|
||||
client.send("Accessibility.enable")
|
||||
page.client = client # type: ignore # TODO[shuyanzh], fix this hackey client
|
||||
page.goto(url)
|
||||
# set the first page as the current page
|
||||
self.page = self.context.pages[0]
|
||||
self.page.bring_to_front()
|
||||
else:
|
||||
self.page = self.context.new_page()
|
||||
page.on("dialog", handle_dialog)
|
||||
client = self.page.context.new_cdp_session(self.page)
|
||||
if self.text_observation_type == "accessibility_tree":
|
||||
client.send("Accessibility.enable")
|
||||
self.page.client = client # type: ignore
|
||||
|
||||
def get_page_client(self, page: Page) -> CDPSession:
|
||||
return page.client # type: ignore
|
||||
|
||||
def _get_obs(self) -> dict[str, Observation]:
|
||||
obs = self.observation_handler.get_observation(
|
||||
self.page, self.get_page_client(self.page)
|
||||
)
|
||||
return obs
|
||||
|
||||
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
|
||||
metadata = self.observation_handler.get_observation_metadata()
|
||||
return metadata
|
||||
|
||||
@beartype
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | None = None,
|
||||
options: dict[str, str] | None = None,
|
||||
) -> tuple[dict[str, Observation], dict[str, Any]]:
|
||||
"""
|
||||
Reset the environment.
|
||||
:param options: options for the environment. The current supported options are:
|
||||
- "storage_state": the storage state of the browser. It is a file path to a json file.
|
||||
"""
|
||||
super().reset(seed=seed, options=options)
|
||||
if self.reset_finished:
|
||||
self.context_manager.__exit__()
|
||||
|
||||
if options is not None and "config_file" in options:
|
||||
config_file = Path(options["config_file"])
|
||||
if config_file.exists():
|
||||
self.setup(config_file=config_file)
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file} does not exist.")
|
||||
else:
|
||||
self.setup()
|
||||
self.reset_finished = True
|
||||
|
||||
if self.sleep_after_execution > 0:
|
||||
time.sleep(self.sleep_after_execution)
|
||||
|
||||
images = self.modify_page()
|
||||
|
||||
observation = self._get_obs()
|
||||
observation_metadata = self._get_obs_metadata()
|
||||
info = {
|
||||
"page": DetachedPage(self.page.url, ""),
|
||||
"fail_error": "",
|
||||
"observation_metadata": observation_metadata,
|
||||
"images": images,
|
||||
}
|
||||
|
||||
return (observation, info)
|
||||
|
||||
def save_trace(self, trace_path: str | Path) -> None:
|
||||
if self.save_trace_enabled:
|
||||
self.context.tracing.stop(path=trace_path)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.reset_finished:
|
||||
self.context_manager.__exit__()
|
||||
|
||||
def step(
|
||||
self, action: Action
|
||||
) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
|
||||
if not self.reset_finished:
|
||||
raise RuntimeError("Call reset first before calling step.")
|
||||
|
||||
success = False
|
||||
fail_error = ""
|
||||
try:
|
||||
self.page = execute_action(
|
||||
action,
|
||||
self.page,
|
||||
self.context,
|
||||
self.observation_handler.action_processor,
|
||||
)
|
||||
success = True
|
||||
except Exception as e:
|
||||
fail_error = str(e)
|
||||
raise e
|
||||
|
||||
# hard sleep TODO[shuyanzh] suboptimal, may need to check network
|
||||
if self.sleep_after_execution > 0:
|
||||
time.sleep(self.sleep_after_execution)
|
||||
|
||||
images = self.modify_page()
|
||||
|
||||
observation = self._get_obs()
|
||||
observation_metadata = self._get_obs_metadata()
|
||||
|
||||
info = {
|
||||
"page": DetachedPage(self.page.url, self.page.content()),
|
||||
"fail_error": fail_error,
|
||||
"observation_metadata": observation_metadata,
|
||||
"images": images,
|
||||
}
|
||||
|
||||
msg = (
|
||||
observation,
|
||||
float(success), # reward
|
||||
False, # terminated
|
||||
False, # truncated
|
||||
info,
|
||||
)
|
||||
return msg
|
||||
|
||||
def modify_page(self):
|
||||
self.page.wait_for_timeout(500)
|
||||
try:
|
||||
self.page.evaluate(remove_id_script)
|
||||
except:
|
||||
pass
|
||||
|
||||
suffix = getattr(self.global_config, "logname", "")
|
||||
if suffix:
|
||||
img_bytes = self.page.screenshot(path=f"output/screenshot-{suffix}.png", full_page=True)
|
||||
else:
|
||||
img_bytes = self.page.screenshot(path="output/screenshot_raw.png")
|
||||
raw_image = base64.b64encode(img_bytes).decode()
|
||||
|
||||
self.page.evaluate(mix_marker_script)
|
||||
self.page.wait_for_timeout(100)
|
||||
|
||||
# get all clickable elements
|
||||
start_id = 0
|
||||
elem_items, start_id = self.page.evaluate(get_rect_script, {
|
||||
"selector": ".possible-clickable-element",
|
||||
"startIndex": start_id
|
||||
})
|
||||
|
||||
# get ocr items
|
||||
ocr_items = []
|
||||
# ocr_items = page.evaluate(canva_handler_script)
|
||||
# svg_items, _ = page.evaluate(get_rect_script, {"selector": "svg", "startIndex": -1})
|
||||
# ocr_items = ocr_items + svg_items
|
||||
# ocr_items, start_id = get_canva_images(ocr_items, img_bytes, start_id)
|
||||
|
||||
items = elem_items + ocr_items
|
||||
|
||||
# mark our own labels and get the images
|
||||
items = self.page.evaluate(label_marker_script, items)
|
||||
if suffix:
|
||||
img_bytes = self.page.screenshot(path=f"output/marked-{suffix}.png", full_page=True)
|
||||
else:
|
||||
img_bytes = self.page.screenshot(path="output/marked.png")
|
||||
marked_image = base64.b64encode(img_bytes).decode()
|
||||
|
||||
self.page.evaluate(remove_label_mark_script)
|
||||
|
||||
return {
|
||||
"raw_image": raw_image,
|
||||
"marked_image": marked_image,
|
||||
}
|
307
browser_env/helper_functions.py
Normal file
307
browser_env/helper_functions.py
Normal file
|
@ -0,0 +1,307 @@
|
|||
import base64
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from agent.prompts import *
|
||||
from browser_env import (
|
||||
Action,
|
||||
ActionTypes,
|
||||
ObservationMetadata,
|
||||
StateInfo,
|
||||
action2str,
|
||||
)
|
||||
|
||||
HTML_TEMPLATE = """
|
||||
<!DOCTYPE html>
|
||||
<head>
|
||||
<style>
|
||||
pre {{
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<html>
|
||||
<body>
|
||||
{body}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def get_render_action(
|
||||
action: Action,
|
||||
observation_metadata: dict[str, ObservationMetadata],
|
||||
action_set_tag: str,
|
||||
) -> str:
|
||||
"""Parse the predicted actions for rendering purpose. More comprehensive information"""
|
||||
match action_set_tag:
|
||||
case "id_html_tree":
|
||||
text_meta_data = observation_metadata["text"]
|
||||
if action["element_id"] in text_meta_data["obs_nodes_info"]:
|
||||
node_content = text_meta_data["obs_nodes_info"][
|
||||
action["element_id"]
|
||||
]["text"]
|
||||
else:
|
||||
node_content = "No match found"
|
||||
|
||||
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
|
||||
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
|
||||
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
|
||||
|
||||
case "id_html_nasc_tree":
|
||||
text_meta_data = observation_metadata["text"]
|
||||
if action["element_id"] in text_meta_data["obs_nodes_info"]:
|
||||
node_content = text_meta_data["obs_nodes_info"][
|
||||
action["element_id"]
|
||||
]["text"]
|
||||
else:
|
||||
node_content = "No match found"
|
||||
|
||||
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
|
||||
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
|
||||
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
|
||||
|
||||
case "id_accessibility_tree":
|
||||
text_meta_data = observation_metadata["text"]
|
||||
if action["element_id"] in text_meta_data["obs_nodes_info"]:
|
||||
node_content = text_meta_data["obs_nodes_info"][
|
||||
action["element_id"]
|
||||
]["text"]
|
||||
else:
|
||||
node_content = "No match found"
|
||||
|
||||
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
|
||||
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
|
||||
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
|
||||
|
||||
case "playwright":
|
||||
action_str = action["pw_code"]
|
||||
case _:
|
||||
raise ValueError(f"Unknown action type {action['action_type']}")
|
||||
return action_str
|
||||
|
||||
|
||||
def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted actions to store in action history for prompt use.
    May contain hint information to recover from the failures"""

    match action_set_tag:
        case "id_html_tree":
            # old_op_prompt = "Website: %s; Thinking process: %s; Html segment: %s; Operation: %s; Result: %s"
            op_prompt = "Html segment: %s; Operation: %s;"
            text_meta_data = observation_metadata["text"]
            # element-id -> node metadata for everything in the last observation
            node_info = text_meta_data["obs_nodes_info"]
            # NOTE(review): `result` collects failure hints below but is no
            # longer interpolated into the returned string (see the
            # commented-out op_prompt at the end of this branch).
            result = 'Operation Success'

            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                # e.g. ActionTypes.CLICK -> "click"
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in node_info:
                    node_content = node_info[action["element_id"]]["text"]
                    # drop the leading token (the element id) from the node text
                    node_content = " ".join(node_content.split()[1:])
                    action["label"] = node_info[action["element_id"]]["label"]
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = "None"
                    result = f"Cannot find the corresponding tag. Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    text = action["answer"]
                    # A NONE action is kept only when the model emitted a
                    # "#Record#" note; anything else counts as a format failure.
                    if text is not None and text.count("#Record#") > 0:
                        action_str = text
                    else:
                        action_str = "None"
                        result = f'Operation invalid. The format was incorrect. Ensure that the action is wrapped inside a pair of # and seperate arguments within spaces as follows: #action# arg1 arg2 ....'
                else:
                    action_str = action2str(action, action_set_tag, "")

            # action_str = op_prompt % (
            #     prompt_constructor.state["url"],
            #     prompt_constructor.state["intention"],
            #     prompt_constructor.state["segment"],
            #     action_str,
            #     result,
            # )

            action_str = op_prompt % (
                prompt_constructor.state["segment"],
                action_str,
            )
        case "id_html_nasc_tree":
            # Same flow as "id_html_tree", but: operation comes first and the
            # HTML segment second in the template, and the NONE-action marker
            # is lowercase "record" rather than "#Record#".
            op_prompt = "%s #HTML Segment: %s"
            text_meta_data = observation_metadata["text"]
            node_info = text_meta_data["obs_nodes_info"]
            # NOTE(review): `result` is likewise unused in the returned string.
            result = 'Operation Success'

            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in node_info:
                    node_content = node_info[action["element_id"]]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action["label"] = node_info[action["element_id"]]["label"]
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = "None"
                    result = f"Cannot find the corresponding tag. Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    text = action["answer"]
                    if text is not None and text.count("record") > 0:
                        action_str = text
                    else:
                        action_str = "None"
                        result = f'Operation invalid. The format was incorrect. Ensure that the action is wrapped inside a pair of # and seperate arguments within spaces as follows: #action# arg1 arg2 ....'
                else:
                    action_str = action2str(action, action_set_tag, "")

            action_str = op_prompt % (
                action_str,
                prompt_constructor.state["segment"],
            )

        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    # Unlike the html-tree branches, the hint itself becomes
                    # the recorded action string here.
                    action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "playwright":
            action_str = action["pw_code"]

        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")

    return action_str
|
||||
|
||||
|
||||
class RenderHelper(object):
    """Helper class to render text and image observations and meta data in the trajectory.

    Maintains a single ``render_{task_id}.html`` file; each :meth:`render`
    call re-reads the file, appends the new step inside ``<body>``, and
    rewrites the whole document.
    """

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        """Create/overwrite the render file and write the config header.

        Args:
            config_file: Path to a JSON task config; must contain "task_id".
            result_dir: Directory in which render_{task_id}.html is created.
            action_set_tag: Grammar tag forwarded to get_render_action.
        """
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]

        self.action_set_tag = action_set_tag

        # Explicit utf-8 so non-ASCII page text cannot raise
        # UnicodeEncodeError on platforms with a narrower default encoding.
        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+", encoding="utf-8"
        )
        self.render_file.truncate(0)
        # write init template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        # (a stray read() whose result was discarded has been removed)
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory: append one step (observation, previous
        action, predicted action, optional screenshot) to the HTML file."""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"

        if render_screenshot:
            # image observation: embed the screenshot inline as base64 PNG
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"

        # meta data: the action that led to this page
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"

        # action predicted for this page
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"

        # add new content: splice it into the existing <body> and rewrite
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content

        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        """Close the underlying render file handle."""
        self.render_file.close()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user