pass sample test web with custom orm
parent 8f168ecbef
commit ee08da12c0
.gitignore (vendored) | 5
@@ -1,2 +1,5 @@
 output/
-result/
+result/
+sample/sample_output/
+
+*__pycache__/
.vscode/launch.json (vendored, new file) | 35
@@ -0,0 +1,35 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "sample_test_web",
            "type": "debugpy",
            "request": "launch",
            "program": "/home/default/miniconda3/envs/swift/bin/swift",
            "cwd": "${workspaceFolder}/sample",
            "console": "integratedTerminal",
            "python": "/home/default/miniconda3/envs/swift/bin/python3.10",
            "justMyCode": false,
            "env": {
                "OPENAI_API_KEY": "sk-a6a00cb727ee4913ba22530ec3c2b30d"
            },
            "args": [
                "sample",
                "--model", "qwen3-8b",
                "--sampler_type", "distill",
                "--sampler_engine", "client",
                "--stream", "true",
                "--orm_model", "external_web_acc",
                "--dataset", "combine_output_file_3.jsonl",
                "--num_return_sequences", "1",
                "--temperature", "0.6",
                "--top_p", "0.95",
                "--external_plugins", "plugin.py",
                "--engine_kwargs", "{\"base_url\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\"}"
            ]
        }
    ]
}
plugin.py (new file) | 597
@@ -0,0 +1,597 @@
import asyncio
import re
import textwrap
from copy import deepcopy
from typing import Dict, List, Optional

import json
import torch

from swift.llm import PtEngine, RequestConfig, Template, to_device
from swift.llm.infer.protocol import ChatCompletionResponse
from swift.plugin import ORM, orms, rm_plugins
from swift.plugin.rm_plugin import DefaultRMPlugin
from swift.utils import get_logger

logger = get_logger()
"""
Step 1: Define a Reward Class
    Implement your custom reward calculation logic within the __call__ method.
    The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters.

Step 2: Register the Reward Class in orms
    For example:
    ```python
    orms['external_math_acc'] = MathAccuracy
    ```

Step 3: Configure the Arguments
    Use the following arguments when running the script:
    ```bash
    --plugin /path/to/plugin.py --reward_funcs external_math_acc
    ```
"""


# Code borrowed from plugin/orm.py
class MathAccuracy(ORM):

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify
        rewards = []
        for content, sol in zip(completions, solution):
            gold_parsed = parse(sol, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) != 0:
                # We require the answer to be provided in correct latex (no malformed operators)
                answer_parsed = parse(
                    content,
                    extraction_config=[
                        LatexExtractionConfig(
                            normalization_config=NormalizationConfig(
                                nits=False,
                                malformed_operators=False,
                                basic_latex=True,
                                equations=True,
                                boxed=True,
                                units=True,
                            ),
                            # Ensures that boxed is tried first
                            boxed_match_priority=0,
                            try_extract_without_anchor=False,
                        )
                    ],
                    extraction_mode='first_match',
                )
                # Reward 1 if the content is the same as the ground truth, 0 otherwise
                reward = float(verify(answer_parsed, gold_parsed))
            else:
                # If the gold solution is not parseable, we reward 1 to skip this example
                reward = 1.0
            rewards.append(reward)
        return rewards


class MathFormat(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format."""
        pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
        return [1.0 if match else 0.0 for match in matches]


class CountdownORM(ORM):

    def __call__(self, completions, target, nums, **kwargs) -> List[float]:
        """
        Evaluates completions based on mathematical correctness of the answer

        Args:
            completions (list[str]): Generated outputs
            target (list[str]): Expected answers
            nums (list[str]): Available numbers

        Returns:
            list[float]: Reward scores
        """
        rewards = []
        for completion, gt, numbers in zip(completions, target, nums):
            try:
                # Check if the format is correct
                match = re.search(r'<answer>(.*?)<\/answer>', completion)
                if match is None:
                    rewards.append(0.0)
                    continue
                # Extract the "answer" part from the completion
                equation = match.group(1).strip()
                if '=' in equation:
                    equation = equation.split('=')[0]
                # Extract all numbers from the equation
                used_numbers = [int(n) for n in re.findall(r'\d+', equation)]

                # Check if all numbers are used exactly once
                if sorted(used_numbers) != sorted(numbers):
                    rewards.append(0.0)
                    continue
                # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
                allowed_pattern = r'^[\d+\-*/().\s]+$'
                if not re.match(allowed_pattern, equation):
                    rewards.append(0.0)
                    continue

                # Evaluate the equation with restricted globals and locals
                result = eval(equation, {'__builtins__': None}, {})
                # Check if the equation is correct and matches the ground truth
                if abs(float(result) - float(gt)) < 1e-5:
                    rewards.append(1.0)
                else:
                    rewards.append(0.0)
            except Exception:
                # If evaluation fails, reward is 0
                rewards.append(0.0)
        return rewards


class MultiModalAccuracyORM(ORM):

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Reward function that checks if the completion is correct.
        Args:
            completions (list[str]): Generated outputs
            solution (list[str]): Ground truths.

        Returns:
            list[float]: Reward scores
        """
        rewards = []
        from math_verify import parse, verify
        for content, sol in zip(completions, solution):
            reward = 0.0
            # Try symbolic verification first
            try:
                answer = parse(content)
                if float(verify(answer, parse(sol))) > 0:
                    reward = 1.0
            except Exception:
                pass  # Continue to next verification method if this fails

            # If symbolic verification failed, try string matching
            if reward == 0.0:
                try:
                    # Extract answer from solution if it has think/answer tags
                    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
                    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

                    # Extract answer from content if it has think/answer tags
                    content_match = re.search(r'<answer>(.*?)</answer>', content)
                    student_answer = content_match.group(1).strip() if content_match else content.strip()

                    # Compare the extracted answers
                    if student_answer == ground_truth:
                        reward = 1.0
                except Exception:
                    pass  # Keep reward as 0.0 if both methods fail
            rewards.append(reward)
        return rewards


# ref implementation: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
class CodeReward(ORM):

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('e2b') is not None, (
            "The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
        )
        from dotenv import load_dotenv
        load_dotenv()

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    def run_async_from_sync(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Function wrapping the `run_async` function."""
        # Create a new event loop and set it
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Run the async function and get the result
            rewards = loop.run_until_complete(self.run_async(scripts, languages))
        finally:
            loop.close()

        return rewards

    async def run_async(self, scripts: List[str], languages: List[str]) -> List[float]:
        from e2b_code_interpreter import AsyncSandbox

        # Create the sandbox by hand, currently there's no context manager for this version
        try:
            sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return [0.0] * len(scripts)
        # Create a list of tasks for running scripts concurrently
        tasks = [self.run_script(sbx, script, language) for script, language in zip(scripts, languages)]

        # Wait for all tasks to complete and gather their results as they finish
        results = await asyncio.gather(*tasks)
        rewards = list(results)  # collect results

        # Kill the sandbox after all the tasks are complete
        await sbx.kill()

        return rewards

    async def run_script(self, sbx, script: str, language: str) -> float:
        try:
            execution = await sbx.run_code(script, language=language, timeout=30)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return 0.0
        try:
            return float(execution.text)
        except (TypeError, ValueError):
            return 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that evaluates code snippets using the E2B code interpreter.

        Assumes the dataset contains a `verification_info` column with test cases.
        """
        evaluation_script_template = """
import subprocess
import json

def evaluate_code(code, test_cases):
    passed = 0
    total = len(test_cases)
    exec_timeout = 5

    for case in test_cases:
        process = subprocess.run(
            ["python3", "-c", code],
            input=case["input"],
            text=True,
            capture_output=True,
            timeout=exec_timeout
        )

        if process.returncode != 0:  # Error in execution
            continue

        output = process.stdout.strip()
        if output.strip() == case["output"].strip():
            passed += 1

    success_rate = (passed / total)
    return success_rate

code_snippet = {code}
test_cases = json.loads({test_cases})

evaluate_code(code_snippet, test_cases)
"""
        verification_info = kwargs['verification_info']
        languages = [info['language'] for info in verification_info]
        code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info['test_cases'])))
            for code, info in zip(code_snippets, verification_info)
        ]
        try:
            rewards = self.run_async_from_sync(scripts, languages)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            rewards = [0.0] * len(completions)

        return rewards


class CodeFormat(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        verification_info = kwargs['verification_info']
        rewards = []
        for content, info in zip(completions, verification_info):
            pattern = r'^<think>.*?</think>\s*<answer>.*?```{}.*?```.*?</answer>(?![\s\S])'.format(info['language'])
            match = re.match(pattern, content, re.DOTALL | re.MULTILINE)
            reward = 1.0 if match else 0.0
            rewards.append(reward)
        return rewards


class CodeRewardByJudge0(ORM):
    LANGUAGE_ID_MAP = {
        'assembly': 45,
        'bash': 46,
        'basic': 47,
        'c': 50,
        'c++': 54,
        'clojure': 86,
        'c#': 51,
        'cobol': 77,
        'common lisp': 55,
        'd': 56,
        'elixir': 57,
        'erlang': 58,
        'executable': 44,
        'f#': 87,
        'fortran': 59,
        'go': 60,
        'groovy': 88,
        'haskell': 61,
        'java': 62,
        'javascript': 63,
        'kotlin': 78,
        'lua': 64,
        'multi-file program': 89,
        'objective-c': 79,
        'ocaml': 65,
        'octave': 66,
        'pascal': 67,
        'perl': 85,
        'php': 68,
        'plain text': 43,
        'prolog': 69,
        'python': 71,
        'python2': 70,
        'python3': 71,
        'r': 80,
        'ruby': 72,
        'rust': 73,
        'scala': 81,
        'sql': 82,
        'swift': 83,
        'typescript': 74,
        'visual basic.net': 84
    }
    PYTHON_ID = 71

    def __init__(self):
        import os
        self.endpoint = os.getenv('JUDGE0_ENDPOINT')
        assert self.endpoint is not None, (
            'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
        x_auth_token = os.getenv('JUDGE0_X_AUTH_TOKEN')
        self.headers = {'Content-Type': 'application/json'}
        if x_auth_token is not None:
            self.headers['X-Auth-Token'] = x_auth_token

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    @classmethod
    def get_language_id(cls, language):
        if language is None:
            return cls.PYTHON_ID
        return cls.LANGUAGE_ID_MAP.get(language.lower().strip(), cls.PYTHON_ID)

    async def _evaluate_code(self, code, test_cases, language_id):
        import aiohttp
        try:
            passed = 0
            total = len(test_cases)

            for case in test_cases:
                if code is not None and code != '':
                    async with aiohttp.ClientSession() as session:
                        payload = {
                            'source_code': code,
                            'language_id': language_id,
                            'stdin': case['input'],
                            'expected_output': case['output']
                        }
                        logger.debug(f'Payload: {payload}')
                        async with session.post(
                                self.endpoint + '/submissions/?wait=true', json=payload,
                                headers=self.headers) as response:
                            response_json = await response.json()
                            logger.debug(f'Response: {response_json}')
                            if response_json['status']['description'] == 'Accepted':
                                passed += 1

            success_rate = (passed / total)
            return success_rate
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            return 0.0

    def run_async_from_sync(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async())
        finally:
            loop.close()
        return rewards

    async def run_async(self):
        tasks = [
            self._evaluate_code(code, info['test_cases'], CodeRewardByJudge0.get_language_id(info['language']))
            for code, info in zip(self.code_snippets, self.verification_info)
        ]
        results = await asyncio.gather(*tasks)
        rewards = list(results)
        return rewards

    def __call__(self, completions, **kwargs) -> List[float]:
        self.verification_info = kwargs['verification_info']

        languages = [info['language'] for info in self.verification_info]
        self.code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]

        try:
            rewards = self.run_async_from_sync()
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards


orms['external_math_acc'] = MathAccuracy
orms['external_math_format'] = MathFormat
orms['external_countdown'] = CountdownORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
orms['external_code_reward'] = CodeReward
orms['external_code_format'] = CodeFormat
orms['external_code_reward_by_judge0'] = CodeRewardByJudge0


# For GenRM, you can refer to swift/llm/plugin/rm_plugin/GenRMPlugin
class CustomizedRMPlugin:
    """
    Customized Reward Model Plugin, same as DefaultRMPlugin.

    It assumes that `self.model` is a classification model with a value head (output dimension 1).
    The first logits value from the model's output is used as the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        reward_inputs.pop('labels')

        with torch.inference_mode():
            return self.model(**reward_inputs).logits[:, 0]


class QwenLongPlugin(DefaultRMPlugin):
    # https://arxiv.org/abs/2505.17667
    # NOTE: you should customize the verified reward function, you can refer to
    # https://github.com/Tongyi-Zhiwen/QwenLong-L1/tree/main/verl/verl/utils/reward_score
    # hf_dataset: https://huggingface.co/datasets/Tongyi-Zhiwen/DocQA-RL-1.6K/viewer/default/train
    # ms_dataset: https://modelscope.cn/datasets/iic/DocQA-RL-1.6K
    def __init__(self, model, template, accuracy_orm=None):
        super().__init__(model, template)
        # initialize PtEngine to infer
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)  # 0: no limit
        self.request_config = RequestConfig(temperature=0)  # customise your request config here
        self.system = textwrap.dedent("""
            You are an expert in verifying if two answers are the same.

            Your input consists of a problem and two answers: Answer 1 and Answer 2.
            You need to check if they are equivalent.

            Your task is to determine if the two answers are equivalent, without attempting to solve the original problem.
            Compare the answers to verify they represent identical values or meanings,
            even when expressed in different forms or notations.

            Your output must follow this format:
            1) Provide an explanation for why the answers are equivalent or not.
            2) Then provide your final answer in the form of: [[YES]] or [[NO]]

            Problem: {problem_placeholder}
            Answer 1: {answer1_placeholder}
            Answer 2: {answer2_placeholder}
        """)  # noqa
        self.accuracy_orm = accuracy_orm

    def __call__(self, inputs):
        completions = [example['messages'][-1]['content'] for example in inputs]
        ground_truths = [example['reward_model']['ground_truth'] for example in inputs]
        rm_inputs = self.prepare_rm_inputs(inputs, completions, ground_truths)

        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        llm_rewards = self.compute_rewards(results)

        if self.accuracy_orm:
            verified_rewards = self.accuracy_orm(completions, ground_truths)
        else:
            verified_rewards = [0.0] * len(llm_rewards)

        rewards = [max(r1, r2) for r1, r2 in zip(llm_rewards, verified_rewards)]
        return torch.tensor(rewards, dtype=torch.float32)

    def prepare_rm_inputs(self, inputs: List[Dict], completions, ground_truths) -> List[Dict]:
        rm_inputs = []
        for infer_request, completion, ground_truth in zip(inputs, completions, ground_truths):
            # Deep copy to prevent modification of original input
            rm_infer_request = deepcopy(infer_request)
            problem = infer_request['messages'][0]['content']
            start_index = problem.index('</text>')
            end_index = problem.index('Format your response as follows:')
            question = problem[start_index:end_index].replace('</text>', '').strip()
            prompt = self.system.format(
                problem_placeholder=question, answer1_placeholder=completion, answer2_placeholder=ground_truth)

            # Construct new messages tailored for the reward model
            rm_messages = [{'role': 'user', 'content': prompt}]

            # Update the messages in the reward infer request
            rm_infer_request['messages'] = rm_messages
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> float:
        match = re.search(r'\[([A-Z]+)\]', model_output)
        if match:
            answer = match.group(1)
            if answer == 'YES':
                return 1.0
            elif answer == 'NO':
                return 0.0
            else:
                logger.warning("Unexpected answer, expected 'YES' or 'NO'.")
                return 0.0
        else:
            logger.warning("Unable to extract reward score from the model's output, setting reward to 0")
            return 0.0  # Or raise ValueError("Format incorrect")

    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Compute average reward scores from the reward model's outputs.

        Args:
            results (List[ChatCompletionResponse]): A list of results from the reward model.

        Returns:
            List[float]: A list of average reward scores.
        """
        rewards = []
        for idx, output in enumerate(results):
            try:
                cur_rewards = []
                for choice in output.choices:
                    response = choice.message.content
                    reward = self.extract_reward(response)
                    cur_rewards.append(reward)
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    average_reward = sum(cur_rewards) / len(cur_rewards)
                else:
                    average_reward = 0.0
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')

                rewards.append(average_reward)
            except Exception as e:
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)  # Assign default reward score on failure
        return rewards


rm_plugins['my_rmplugin'] = CustomizedRMPlugin
rm_plugins['qwenlong'] = QwenLongPlugin
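For reference, the contract described in the plugin.py docstring (completions plus dataset columns in, List[float] out) can be exercised directly; the following is a minimal sketch with made-up inputs, assuming plugin.py is importable, and is for illustration only:

# Hypothetical sanity check of the ORM __call__ contract with toy completions
from plugin import MathFormat

fmt = MathFormat()
demo = ['<think>2+2=4</think><answer>4</answer>', 'missing tags']
print(fmt(demo))  # -> [1.0, 0.0]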
sample/combine_output_file_3.jsonl (new file) | 1
File diff suppressed because one or more lines are too long
sample/plugin.py (new file) | 633
@@ -0,0 +1,633 @@
(Lines 1-597 are identical to plugin.py above; sample/plugin.py additionally appends the WebAccuracy ORM shown below.)
class WebAccuracy(ORM):

    def __call__(self, completions, ground_truths, **kwargs) -> List[float]:
        solution = ground_truths
        rewards = []
        solution_pattern = r'\b(do|exit|go_backward|go_forward)\(.*\)'
        answer_pattern = r'<answer>(.*?)</answer>'

        for content, sol in zip(completions, solution):
            reward = 0.0

            # 1. Extract reference from solution
            sol_match = re.search(solution_pattern, sol, re.DOTALL)
            if sol_match:
                reference = sol_match.group(0).strip()

                answer_text = content.messages[-1]['content']
                # 2. Extract answer from completion
                answer_match = re.search(answer_pattern, answer_text, re.DOTALL)
                if answer_match:
                    answer_text = answer_match.group(1).strip()

                # 3. Extract the action call from the answer
                solution_match = re.search(solution_pattern, answer_text, re.DOTALL)
                if solution_match:
                    solution_text = solution_match.group(0).strip()
                    if reference in solution_text:
                        reward = 1.0

            rewards.append(reward)
        return rewards


orms['external_web_acc'] = WebAccuracy
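As an illustration of the new reward (a sketch with made-up inputs, for illustration only): WebAccuracy returns 1.0 only when the action call extracted from the ground truth also appears inside the completion's <answer> block, assuming completions expose the sampler-style .messages attribute used above.

# Hypothetical check of WebAccuracy with a stand-in completion object
from types import SimpleNamespace

from plugin import WebAccuracy

pred = SimpleNamespace(messages=[{
    'role': 'assistant',
    'content': '<answer>do("click", element="3")</answer>'
}])
print(WebAccuracy()([pred], ['Action: do("click", element="3")']))  # -> [1.0]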
sample/sample_test.sh (new file) | 14
@@ -0,0 +1,14 @@
MODEL="qwen3-8b"

OPENAI_API_KEY="sk-a6a00cb727ee4913ba22530ec3c2b30d" \
swift sample \
    --sampler_type distill \
    --sampler_engine client \
    --model $MODEL \
    --stream true \
    --orm_model math \
    --dataset AI-MO/NuminaMath-TIR#5 \
    --num_return_sequences 1 \
    --temperature 0.6 \
    --top_p 0.95 \
    --engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
sample/sample_test_web.sh (new file) | 15
@@ -0,0 +1,15 @@
MODEL="qwen3-8b"

OPENAI_API_KEY="sk-a6a00cb727ee4913ba22530ec3c2b30d" \
swift sample \
    --sampler_type distill \
    --sampler_engine client \
    --model $MODEL \
    --stream true \
    --orm_model external_web_acc \
    --dataset combine_output_file_3.jsonl \
    --num_return_sequences 1 \
    --temperature 0.6 \
    --top_p 0.95 \
    --external_plugins plugin.py \
    --engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'