pass sample test web with custom orm
parent 8f168ecbef
commit ee08da12c0

5  .gitignore  (vendored)
@@ -1,2 +1,5 @@
 output/
-result/
+result/
+sample/sample_output/
+
+*__pycache__/

35  .vscode/launch.json  (vendored, Normal file)
@@ -0,0 +1,35 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "sample_test_web",
            "type": "debugpy",
            "request": "launch",
            "program": "/home/default/miniconda3/envs/swift/bin/swift",
            "cwd": "${workspaceFolder}/sample",
            "console": "integratedTerminal",
            "python": "/home/default/miniconda3/envs/swift/bin/python3.10",
            "justMyCode": false,
            "env": {
                "OPENAI_API_KEY": "sk-a6a00cb727ee4913ba22530ec3c2b30d",
            },
            "args": [
                "sample",
                "--model", "qwen3-8b",
                "--sampler_type", "distill",
                "--sampler_engine", "client",
                "--stream", "true",
                "--orm_model", "external_web_acc",
                "--dataset", "combine_output_file_3.jsonl",
                "--num_return_sequences", "1",
                "--temperature", "0.6",
                "--top_p", "0.95",
                "--external_plugins", "plugin.py",
                "--engine_kwargs", "{\"base_url\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\"}"
            ]
        }
    ]
}
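
For reference, a rough sketch of the command this debug configuration launches, written with Python's subprocess module (paths and arguments are copied from the JSON above; debugpy-specific bootstrapping is omitted and the API key is elided):

import os
import subprocess

# Mirror the "env" block above; the real key is configured in launch.json.
env = {**os.environ, 'OPENAI_API_KEY': '<OPENAI_API_KEY from the env block>'}
subprocess.run(
    [
        '/home/default/miniconda3/envs/swift/bin/python3.10',
        '/home/default/miniconda3/envs/swift/bin/swift', 'sample',
        '--model', 'qwen3-8b',
        '--sampler_type', 'distill',
        '--sampler_engine', 'client',
        '--stream', 'true',
        '--orm_model', 'external_web_acc',
        '--dataset', 'combine_output_file_3.jsonl',
        '--num_return_sequences', '1',
        '--temperature', '0.6',
        '--top_p', '0.95',
        '--external_plugins', 'plugin.py',
        '--engine_kwargs', '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}',
    ],
    cwd='sample',  # corresponds to "${workspaceFolder}/sample"
    env=env,
    check=True,
)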

597  plugin.py  (Normal file)
@@ -0,0 +1,597 @@
import asyncio
import re
import textwrap
from copy import deepcopy
from typing import Dict, List, Optional

import json
import torch

from swift.llm import PtEngine, RequestConfig, Template, to_device
from swift.llm.infer.protocol import ChatCompletionResponse
from swift.plugin import ORM, orms, rm_plugins
from swift.plugin.rm_plugin import DefaultRMPlugin
from swift.utils import get_logger

logger = get_logger()
"""
Step 1: Define a Reward Class
    Implement your custom reward calculation logic within the __call__ method.
    The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters.

Step 2: Register the Reward Class in orms
    For example:
    python orms['external_math_acc'] = MathAccuracy

Step 3: Configure the Arguments
    Use the following arguments when running the script:
    bash --plugin /path/to/plugin.py --reward_funcs external_math_acc
"""
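

# A minimal sketch of Steps 1 and 2 above. The class name `ExactMatch` and the
# registry key 'external_exact_match' are illustrative examples only and are not
# used anywhere else in this commit.
class ExactMatch(ORM):

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        # 1.0 when the stripped completion equals the dataset's `solution` column, else 0.0
        return [float(c.strip() == s.strip()) for c, s in zip(completions, solution)]


orms['external_exact_match'] = ExactMatch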


# Code borrowed from plugin/orm.py
class MathAccuracy(ORM):

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify
        rewards = []
        for content, sol in zip(completions, solution):
            gold_parsed = parse(sol, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) != 0:
                # We require the answer to be provided in correct latex (no malformed operators)
                answer_parsed = parse(
                    content,
                    extraction_config=[
                        LatexExtractionConfig(
                            normalization_config=NormalizationConfig(
                                nits=False,
                                malformed_operators=False,
                                basic_latex=True,
                                equations=True,
                                boxed=True,
                                units=True,
                            ),
                            # Ensures that boxed is tried first
                            boxed_match_priority=0,
                            try_extract_without_anchor=False,
                        )
                    ],
                    extraction_mode='first_match',
                )
                # Reward 1 if the content is the same as the ground truth, 0 otherwise
                reward = float(verify(answer_parsed, gold_parsed))
            else:
                # If the gold solution is not parseable, we reward 1 to skip this example
                reward = 1.0
            rewards.append(reward)
        return rewards


class MathFormat(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format."""
        pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
        return [1.0 if match else 0.0 for match in matches]
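

# Quick illustration of the format check above; `_demo_math_format` is a hypothetical
# helper that is not referenced elsewhere in this commit.
def _demo_math_format() -> List[float]:
    good = '<think>2 + 2 = 4</think> <answer>4</answer>'
    bad = 'The answer is 4.'
    # Expected result: [1.0, 0.0]
    return MathFormat()([good, bad])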


class CountdownORM(ORM):

    def __call__(self, completions, target, nums, **kwargs) -> List[float]:
        """
        Evaluates completions based on the mathematical correctness of the answer

        Args:
            completions (list[str]): Generated outputs
            target (list[str]): Expected answers
            nums (list[str]): Available numbers

        Returns:
            list[float]: Reward scores
        """
        rewards = []
        for completion, gt, numbers in zip(completions, target, nums):
            try:
                # Check if the format is correct
                match = re.search(r'<answer>(.*?)<\/answer>', completion)
                if match is None:
                    rewards.append(0.0)
                    continue
                # Extract the "answer" part from the completion
                equation = match.group(1).strip()
                if '=' in equation:
                    equation = equation.split('=')[0]
                # Extract all numbers from the equation
                used_numbers = [int(n) for n in re.findall(r'\d+', equation)]

                # Check if all numbers are used exactly once
                if sorted(used_numbers) != sorted(numbers):
                    rewards.append(0.0)
                    continue
                # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
                allowed_pattern = r'^[\d+\-*/().\s]+$'
                if not re.match(allowed_pattern, equation):
                    rewards.append(0.0)
                    continue

                # Evaluate the equation with restricted globals and locals
                result = eval(equation, {'__builtins__': None}, {})
                # Check if the equation is correct and matches the ground truth
                if abs(float(result) - float(gt)) < 1e-5:
                    rewards.append(1.0)
                else:
                    rewards.append(0.0)
            except Exception:
                # If evaluation fails, reward is 0
                rewards.append(0.0)
        return rewards
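

# Worked example for CountdownORM; `_demo_countdown` is a hypothetical helper that is
# not referenced elsewhere in this commit.
def _demo_countdown() -> List[float]:
    completion = '<think>(5 + 3) * 2 = 16</think> <answer>(5 + 3) * 2</answer>'
    # The equation uses each of 2, 3 and 5 exactly once and evaluates to the target,
    # so the expected reward is [1.0].
    return CountdownORM()([completion], target=['16'], nums=[[2, 3, 5]])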


class MultiModalAccuracyORM(ORM):

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """
        Reward function that checks if the completion is correct.
        Args:
            completions (list[str]): Generated outputs
            solution (list[str]): Ground Truths.

        Returns:
            list[float]: Reward scores
        """
        rewards = []
        from math_verify import parse, verify
        for content, sol in zip(completions, solution):
            reward = 0.0
            # Try symbolic verification first
            try:
                answer = parse(content)
                if float(verify(answer, parse(sol))) > 0:
                    reward = 1.0
            except Exception:
                pass  # Continue to next verification method if this fails

            # If symbolic verification failed, try string matching
            if reward == 0.0:
                try:
                    # Extract answer from solution if it has think/answer tags
                    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
                    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

                    # Extract answer from content if it has think/answer tags
                    content_match = re.search(r'<answer>(.*?)</answer>', content)
                    student_answer = content_match.group(1).strip() if content_match else content.strip()

                    # Compare the extracted answers
                    if student_answer == ground_truth:
                        reward = 1.0
                except Exception:
                    pass  # Keep reward as 0.0 if both methods fail
            rewards.append(reward)
        return rewards


# ref implementation: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
class CodeReward(ORM):

    def __init__(self):
        import importlib.util
        assert importlib.util.find_spec('e2b') is not None, (
            "The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
        )
        from dotenv import load_dotenv
        load_dotenv()

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    def run_async_from_sync(self, scripts: List[str], languages: List[str]) -> List[float]:
        """Function wrapping the `run_async` function."""
        # Create a new event loop and set it
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # Run the async function and get the result
            rewards = loop.run_until_complete(self.run_async(scripts, languages))
        finally:
            loop.close()

        return rewards

    async def run_async(self, scripts: List[str], languages: List[str]) -> List[float]:
        from e2b_code_interpreter import AsyncSandbox

        # Create the sandbox by hand, currently there's no context manager for this version
        try:
            sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return [0.0] * len(scripts)
        # Create a list of tasks for running scripts concurrently
        tasks = [self.run_script(sbx, script, language) for script, language in zip(scripts, languages)]

        # Wait for all tasks to complete and gather their results as they finish
        results = await asyncio.gather(*tasks)
        rewards = list(results)  # collect results

        # Kill the sandbox after all the tasks are complete
        await sbx.kill()

        return rewards

    async def run_script(self, sbx, script: str, language: str) -> float:
        try:
            execution = await sbx.run_code(script, language=language, timeout=30)
        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            return 0.0
        try:
            return float(execution.text)
        except (TypeError, ValueError):
            return 0.0

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that evaluates code snippets using the E2B code interpreter.

        Assumes the dataset contains a `verification_info` column with test cases.
        """
        evaluation_script_template = """
        import subprocess
        import json

        def evaluate_code(code, test_cases):
            passed = 0
            total = len(test_cases)
            exec_timeout = 5

            for case in test_cases:
                process = subprocess.run(
                    ["python3", "-c", code],
                    input=case["input"],
                    text=True,
                    capture_output=True,
                    timeout=exec_timeout
                )

                if process.returncode != 0:  # Error in execution
                    continue

                output = process.stdout.strip()
                if output.strip() == case["output"].strip():
                    passed += 1

            success_rate = (passed / total)
            return success_rate

        code_snippet = {code}
        test_cases = json.loads({test_cases})

        evaluate_code(code_snippet, test_cases)
        """
        verification_info = kwargs['verification_info']
        languages = [info['language'] for info in verification_info]
        code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info['test_cases'])))
            for code, info in zip(code_snippets, verification_info)
        ]
        try:
            rewards = self.run_async_from_sync(scripts, languages)

        except Exception as e:
            logger.warning(f'Error from E2B executor: {e}')
            rewards = [0.0] * len(completions)

        return rewards
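

# Sketch of one element of the `verification_info` list that CodeReward (and
# CodeRewardByJudge0 below) read from the dataset; the field names come from the code
# above, while the concrete values are made-up placeholders.
_EXAMPLE_VERIFICATION_INFO = {
    'language': 'python',
    'test_cases': [
        {'input': '1 2\n', 'output': '3\n'},
    ],
}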


class CodeFormat(ORM):

    def __call__(self, completions, **kwargs) -> List[float]:
        verification_info = kwargs['verification_info']
        rewards = []
        for content, info in zip(completions, verification_info):
            pattern = r'^<think>.*?</think>\s*<answer>.*?```{}.*?```.*?</answer>(?![\s\S])'.format(info['language'])
            match = re.match(pattern, content, re.DOTALL | re.MULTILINE)
            reward = 1.0 if match else 0.0
            rewards.append(reward)
        return rewards


class CodeRewardByJudge0(ORM):
    LANGUAGE_ID_MAP = {
        'assembly': 45,
        'bash': 46,
        'basic': 47,
        'c': 50,
        'c++': 54,
        'clojure': 86,
        'c#': 51,
        'cobol': 77,
        'common lisp': 55,
        'd': 56,
        'elixir': 57,
        'erlang': 58,
        'executable': 44,
        'f#': 87,
        'fortran': 59,
        'go': 60,
        'groovy': 88,
        'haskell': 61,
        'java': 62,
        'javascript': 63,
        'kotlin': 78,
        'lua': 64,
        'multi-file program': 89,
        'objective-c': 79,
        'ocaml': 65,
        'octave': 66,
        'pascal': 67,
        'perl': 85,
        'php': 68,
        'plain text': 43,
        'prolog': 69,
        'python': 71,
        'python2': 70,
        'python3': 71,
        'r': 80,
        'ruby': 72,
        'rust': 73,
        'scala': 81,
        'sql': 82,
        'swift': 83,
        'typescript': 74,
        'visual basic.net': 84
    }
    PYTHON_ID = 71

    def __init__(self):
        import os
        self.endpoint = os.getenv('JUDGE0_ENDPOINT')
        assert self.endpoint is not None, (
            'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
        x_auth_token = os.getenv('JUDGE0_X_AUTH_TOKEN')
        self.headers = {'Content-Type': 'application/json'}
        if x_auth_token is not None:
            self.headers['X-Auth-Token'] = x_auth_token

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    @classmethod
    def get_language_id(cls, language):
        if language is None:
            return cls.PYTHON_ID
        return cls.LANGUAGE_ID_MAP.get(language.lower().strip(), cls.PYTHON_ID)

    async def _evaluate_code(self, code, test_cases, language_id):
        import aiohttp
        try:
            passed = 0
            total = len(test_cases)

            for case in test_cases:
                if code is not None and code != '':
                    async with aiohttp.ClientSession() as session:
                        payload = {
                            'source_code': code,
                            'language_id': language_id,
                            'stdin': case['input'],
                            'expected_output': case['output']
                        }
                        logger.debug(f'Payload: {payload}')
                        async with session.post(
                                self.endpoint + '/submissions/?wait=true', json=payload,
                                headers=self.headers) as response:
                            response_json = await response.json()
                            logger.debug(f'Response: {response_json}')
                            if response_json['status']['description'] == 'Accepted':
                                passed += 1

            success_rate = (passed / total)
            return success_rate
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            return 0.0

    def run_async_from_sync(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async())
        finally:
            loop.close()
        return rewards

    async def run_async(self):
        tasks = [
            self._evaluate_code(code, info['test_cases'], CodeRewardByJudge0.get_language_id(info['language']))
            for code, info in zip(self.code_snippets, self.verification_info)
        ]
        results = await asyncio.gather(*tasks)
        rewards = list(results)
        return rewards

    def __call__(self, completions, **kwargs) -> List[float]:
        self.verification_info = kwargs['verification_info']

        languages = [info['language'] for info in self.verification_info]
        self.code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]

        try:
            rewards = self.run_async_from_sync()
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards
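

# CodeRewardByJudge0 is configured entirely through environment variables; a small
# hypothetical helper showing the variables it reads (the endpoint value is whatever
# Judge0 deployment you point it at).
def _configure_judge0(endpoint: str, x_auth_token: Optional[str] = None) -> None:
    import os
    os.environ['JUDGE0_ENDPOINT'] = endpoint  # e.g. the base URL of a self-hosted Judge0
    if x_auth_token is not None:
        os.environ['JUDGE0_X_AUTH_TOKEN'] = x_auth_token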


orms['external_math_acc'] = MathAccuracy
orms['external_math_format'] = MathFormat
orms['external_countdown'] = CountdownORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
orms['external_code_reward'] = CodeReward
orms['external_code_format'] = CodeFormat
orms['external_code_reward_by_judge0'] = CodeRewardByJudge0
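

# Sketch of how a registered key resolves to a reward function at runtime;
# `_demo_registry` is a hypothetical helper and is not referenced elsewhere.
def _demo_registry() -> List[float]:
    reward_fn = orms['external_math_format']()  # instantiate by registry key
    # Expected result: [1.0]
    return reward_fn(['<think>ok</think><answer>42</answer>'])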


# For GenRM, you can refer to swift/llm/plugin/rm_plugin/GenRMPlugin
class CustomizedRMPlugin:
    """
    Customized Reward Model Plugin, same as DefaultRMPlugin.

    It assumes that `self.model` is a classification model with a value head (output dimension 1).
    The first logits value from the model's output is used as the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        reward_inputs.pop('labels')

        with torch.inference_mode():
            return self.model(**reward_inputs).logits[:, 0]


class QwenLongPlugin(DefaultRMPlugin):
    # https://arxiv.org/abs/2505.17667
    # NOTE: you should customize the verified reward function; you can refer to
    # https://github.com/Tongyi-Zhiwen/QwenLong-L1/tree/main/verl/verl/utils/reward_score
    # hf_dataset: https://huggingface.co/datasets/Tongyi-Zhiwen/DocQA-RL-1.6K/viewer/default/train
    # ms_dataset: https://modelscope.cn/datasets/iic/DocQA-RL-1.6K
    def __init__(self, model, template, accuracy_orm=None):
        super().__init__(model, template)
        # initialize PtEngine to run inference
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)  # 0: no limit
        self.request_config = RequestConfig(temperature=0)  # customise your request config here
        self.system = textwrap.dedent("""
            You are an expert in verifying if two answers are the same.

            Your input consists of a problem and two answers: Answer 1 and Answer 2.
            You need to check if they are equivalent.

            Your task is to determine if the two answers are equivalent, without attempting to solve the original problem.
            Compare the answers to verify they represent identical values or meanings,
            even when expressed in different forms or notations.

            Your output must follow this format:
            1) Provide an explanation for why the answers are equivalent or not.
            2) Then provide your final answer in the form of: [[YES]] or [[NO]]

            Problem: {problem_placeholder}
            Answer 1: {answer1_placeholder}
            Answer 2: {answer2_placeholder}
        """)  # noqa
        self.accuracy_orm = accuracy_orm

    def __call__(self, inputs):
        completions = [example['messages'][-1]['content'] for example in inputs]
        ground_truths = [example['reward_model']['ground_truth'] for example in inputs]
        rm_inputs = self.prepare_rm_inputs(inputs, completions, ground_truths)

        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        llm_rewards = self.compute_rewards(results)

        if self.accuracy_orm:
            verified_rewards = self.accuracy_orm(completions, ground_truths)
        else:
            verified_rewards = [0.0] * len(llm_rewards)

        rewards = [max(r1, r2) for r1, r2 in zip(llm_rewards, verified_rewards)]
        return torch.tensor(rewards, dtype=torch.float32)

    def prepare_rm_inputs(self, inputs: List[Dict], completions, ground_truths) -> List[Dict]:
        rm_inputs = []
        for infer_request, completion, ground_truth in zip(inputs, completions, ground_truths):
            # Deep copy to prevent modification of original input
            rm_infer_request = deepcopy(infer_request)
            problem = infer_request['messages'][0]['content']
            start_index = problem.index('</text>')
            end_index = problem.index('Format your response as follows:')
            question = problem[start_index:end_index].replace('</text>', '').strip()
            prompt = self.system.format(
                problem_placeholder=question, answer1_placeholder=completion, answer2_placeholder=ground_truth)

            # Construct new messages tailored for the reward model
            rm_messages = [{'role': 'user', 'content': prompt}]

            # Update the messages in the reward infer request
            rm_infer_request['messages'] = rm_messages
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> float:
        match = re.search(r'\[([A-Z]+)\]', model_output)
        if match:
            answer = match.group(1)
            if answer == 'YES':
                return 1.0
            elif answer == 'NO':
                return 0.0
            else:
                logger.warning("Unexpected answer, expected 'YES' or 'NO'.")
                return 0.0
        else:
            logger.warning("Unable to extract reward score from the model's output, setting reward to 0")
            return 0.0  # Or raise ValueError("Format incorrect")

    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Compute average reward scores from the reward model's outputs.

        Args:
            results (List[ChatCompletionResponse]): A list of results from the reward model.

        Returns:
            List[float]: A list of average reward scores.
        """
        rewards = []
        for idx, output in enumerate(results):
            try:
                cur_rewards = []
                for choice in output.choices:
                    response = choice.message.content
                    reward = self.extract_reward(response)
                    cur_rewards.append(reward)
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    average_reward = sum(cur_rewards) / len(cur_rewards)
                else:
                    average_reward = 0.0
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')

                rewards.append(average_reward)
            except Exception as e:
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)  # Assign default reward score on failure
        return rewards


rm_plugins['my_rmplugin'] = CustomizedRMPlugin
rm_plugins['qwenlong'] = QwenLongPlugin

1  sample/combine_output_file_3.jsonl  (Normal file)
File diff suppressed because one or more lines are too long

633  sample/plugin.py  (Normal file)
@@ -0,0 +1,633 @@
 | 
			
		||||
import asyncio
 | 
			
		||||
import re
 | 
			
		||||
import textwrap
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from typing import Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from swift.llm import PtEngine, RequestConfig, Template, to_device
 | 
			
		||||
from swift.llm.infer.protocol import ChatCompletionResponse
 | 
			
		||||
from swift.plugin import ORM, orms, rm_plugins
 | 
			
		||||
from swift.plugin.rm_plugin import DefaultRMPlugin
 | 
			
		||||
from swift.utils import get_logger
 | 
			
		||||
 | 
			
		||||
logger = get_logger()
 | 
			
		||||
"""
 | 
			
		||||
Step 1: Define a Reward Class
 | 
			
		||||
    Implement your custom reward calculation logic within the __call__ method.
 | 
			
		||||
    The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters.
 | 
			
		||||
 | 
			
		||||
Step 2: Register the Reward Class in orms
 | 
			
		||||
    For example:
 | 
			
		||||
    python orms['external_math_acc'] = MathAccuracy
 | 
			
		||||
 | 
			
		||||
Step 3: Configure the Arguments
 | 
			
		||||
    Use the following arguments when running the script:
 | 
			
		||||
    bash --plugin /path/to/plugin.py --reward_funcs external_math_acc
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Code borrowed from plugin/orm.py
 | 
			
		||||
class MathAccuracy(ORM):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        import importlib.util
 | 
			
		||||
        assert importlib.util.find_spec('math_verify') is not None, (
 | 
			
		||||
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")
 | 
			
		||||
 | 
			
		||||
    def __call__(self, completions, solution, **kwargs) -> List[float]:
 | 
			
		||||
        from latex2sympy2_extended import NormalizationConfig
 | 
			
		||||
        from math_verify import LatexExtractionConfig, parse, verify
 | 
			
		||||
        rewards = []
 | 
			
		||||
        for content, sol in zip(completions, solution):
 | 
			
		||||
            gold_parsed = parse(sol, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
 | 
			
		||||
            if len(gold_parsed) != 0:
 | 
			
		||||
                # We require the answer to be provided in correct latex (no malformed operators)
 | 
			
		||||
                answer_parsed = parse(
 | 
			
		||||
                    content,
 | 
			
		||||
                    extraction_config=[
 | 
			
		||||
                        LatexExtractionConfig(
 | 
			
		||||
                            normalization_config=NormalizationConfig(
 | 
			
		||||
                                nits=False,
 | 
			
		||||
                                malformed_operators=False,
 | 
			
		||||
                                basic_latex=True,
 | 
			
		||||
                                equations=True,
 | 
			
		||||
                                boxed=True,
 | 
			
		||||
                                units=True,
 | 
			
		||||
                            ),
 | 
			
		||||
                            # Ensures that boxed is tried first
 | 
			
		||||
                            boxed_match_priority=0,
 | 
			
		||||
                            try_extract_without_anchor=False,
 | 
			
		||||
                        )
 | 
			
		||||
                    ],
 | 
			
		||||
                    extraction_mode='first_match',
 | 
			
		||||
                )
 | 
			
		||||
                # Reward 1 if the content is the same as the ground truth, 0 otherwise
 | 
			
		||||
                reward = float(verify(answer_parsed, gold_parsed))
 | 
			
		||||
            else:
 | 
			
		||||
                # If the gold solution is not parseable, we reward 1 to skip this example
 | 
			
		||||
                reward = 1.0
 | 
			
		||||
            rewards.append(reward)
 | 
			
		||||
        return rewards
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MathFormat(ORM):
 | 
			
		||||
 | 
			
		||||
    def __call__(self, completions, **kwargs) -> List[float]:
 | 
			
		||||
        """Reward function that checks if the completion has a specific format."""
 | 
			
		||||
        pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])'
 | 
			
		||||
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
 | 
			
		||||
        return [1.0 if match else 0.0 for match in matches]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CountdownORM(ORM):
 | 
			
		||||
 | 
			
		||||
    def __call__(self, completions, target, nums, **kwargs) -> List[float]:
 | 
			
		||||
        """
 | 
			
		||||
        Evaluates completions based on Mathematical correctness of the answer
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            completions (list[str]): Generated outputs
 | 
			
		||||
            target (list[str]): Expected answers
 | 
			
		||||
            nums (list[str]): Available numbers
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list[float]: Reward scores
 | 
			
		||||
        """
 | 
			
		||||
        rewards = []
 | 
			
		||||
        for completion, gt, numbers in zip(completions, target, nums):
 | 
			
		||||
            try:
 | 
			
		||||
                # Check if the format is correct
 | 
			
		||||
                match = re.search(r'<answer>(.*?)<\/answer>', completion)
 | 
			
		||||
                if match is None:
 | 
			
		||||
                    rewards.append(0.0)
 | 
			
		||||
                    continue
 | 
			
		||||
                # Extract the "answer" part from the completion
 | 
			
		||||
                equation = match.group(1).strip()
 | 
			
		||||
                if '=' in equation:
 | 
			
		||||
                    equation = equation.split('=')[0]
 | 
			
		||||
                # Extract all numbers from the equation
 | 
			
		||||
                used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
 | 
			
		||||
 | 
			
		||||
                # Check if all numbers are used exactly once
 | 
			
		||||
                if sorted(used_numbers) != sorted(numbers):
 | 
			
		||||
                    rewards.append(0.0)
 | 
			
		||||
                    continue
 | 
			
		||||
                # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
 | 
			
		||||
                allowed_pattern = r'^[\d+\-*/().\s]+$'
 | 
			
		||||
                if not re.match(allowed_pattern, equation):
 | 
			
		||||
                    rewards.append(0.0)
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                # Evaluate the equation with restricted globals and locals
 | 
			
		||||
                result = eval(equation, {"__builti'ns__": None}, {})
 | 
			
		||||
                # Check if the equation is correct and matches the ground truth
 | 
			
		||||
                if abs(float(result) - float(gt)) < 1e-5:
 | 
			
		||||
                    rewards.append(1.0)
 | 
			
		||||
                else:
 | 
			
		||||
                    rewards.append(0.0)
 | 
			
		||||
            except Exception:
 | 
			
		||||
                # If evaluation fails, reward is 0
 | 
			
		||||
                rewards.append(0.0)
 | 
			
		||||
        return rewards
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MultiModalAccuracyORM(ORM):
 | 
			
		||||
 | 
			
		||||
    def __call__(self, completions, solution, **kwargs) -> List[float]:
 | 
			
		||||
        """
 | 
			
		||||
        Reward function that checks if the completion is correct.
 | 
			
		||||
        Args:
 | 
			
		||||
            completions (list[str]): Generated outputs
 | 
			
		||||
            solution (list[str]): Ground Truths.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list[float]: Reward scores
 | 
			
		||||
        """
 | 
			
		||||
        rewards = []
 | 
			
		||||
        from math_verify import parse, verify
 | 
			
		||||
        for content, sol in zip(completions, solution):
 | 
			
		||||
            reward = 0.0
 | 
			
		||||
            # Try symbolic verification first
 | 
			
		||||
            try:
 | 
			
		||||
                answer = parse(content)
 | 
			
		||||
                if float(verify(answer, parse(sol))) > 0:
 | 
			
		||||
                    reward = 1.0
 | 
			
		||||
            except Exception:
 | 
			
		||||
                pass  # Continue to next verification method if this fails
 | 
			
		||||
 | 
			
		||||
            # If symbolic verification failed, try string matching
 | 
			
		||||
            if reward == 0.0:
 | 
			
		||||
                try:
 | 
			
		||||
                    # Extract answer from solution if it has think/answer tags
 | 
			
		||||
                    sol_match = re.search(r'<answer>(.*?)</answer>', sol)
 | 
			
		||||
                    ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
 | 
			
		||||
 | 
			
		||||
                    # Extract answer from content if it has think/answer tags
 | 
			
		||||
                    content_match = re.search(r'<answer>(.*?)</answer>', content)
 | 
			
		||||
                    student_answer = content_match.group(1).strip() if content_match else content.strip()
 | 
			
		||||
 | 
			
		||||
                    # Compare the extracted answers
 | 
			
		||||
                    if student_answer == ground_truth:
 | 
			
		||||
                        reward = 1.0
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    pass  # Keep reward as 0.0 if both methods fail
 | 
			
		||||
            rewards.append(reward)
 | 
			
		||||
        return rewards
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# ref implementation: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py
 | 
			
		||||
class CodeReward(ORM):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        import importlib.util
 | 
			
		||||
        assert importlib.util.find_spec('e2b') is not None, (
 | 
			
		||||
            "The e2b package is required but not installed. Please install it using 'pip install e2b-code-interpreter'."
 | 
			
		||||
        )
 | 
			
		||||
        from dotenv import load_dotenv
 | 
			
		||||
        load_dotenv()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def extract_code(completion: str, language: str) -> str:
 | 
			
		||||
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
 | 
			
		||||
        matches = pattern.findall(completion)
 | 
			
		||||
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
 | 
			
		||||
        return extracted_answer
 | 
			
		||||
 | 
			
		||||
    def run_async_from_sync(self, scripts: List[str], languages: List[str]) -> List[float]:
 | 
			
		||||
        """Function wrapping the `run_async` function."""
 | 
			
		||||
        # Create a new event loop and set it
 | 
			
		||||
        loop = asyncio.new_event_loop()
 | 
			
		||||
        asyncio.set_event_loop(loop)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            # Run the async function and get the result
 | 
			
		||||
            rewards = loop.run_until_complete(self.run_async(scripts, languages))
 | 
			
		||||
        finally:
 | 
			
		||||
            loop.close()
 | 
			
		||||
 | 
			
		||||
        return rewards
 | 
			
		||||
 | 
			
		||||
    async def run_async(self, scripts: List[str], languages: List[str]) -> List[float]:
 | 
			
		||||
        from e2b_code_interpreter import AsyncSandbox
 | 
			
		||||
 | 
			
		||||
        # Create the sandbox by hand, currently there's no context manager for this version
 | 
			
		||||
        try:
 | 
			
		||||
            sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.warning(f'Error from E2B executor: {e}')
 | 
			
		||||
            return [0.0] * len(scripts)
 | 
			
		||||
        # Create a list of tasks for running scripts concurrently
 | 
			
		||||
        tasks = [self.run_script(sbx, script, language) for script, language in zip(scripts, languages)]
 | 
			
		||||
 | 
			
		||||
        # Wait for all tasks to complete and gather their results as they finish
 | 
			
		||||
        results = await asyncio.gather(*tasks)
 | 
			
		||||
        rewards = list(results)  # collect results
 | 
			
		||||
 | 
			
		||||
        # Kill the sandbox after all the tasks are complete
 | 
			
		||||
        await sbx.kill()
 | 
			
		||||
 | 
			
		||||
        return rewards
 | 
			
		||||
 | 
			
		||||
    async def run_script(self, sbx, script: str, language: str) -> float:
 | 
			
		||||
        try:
 | 
			
		||||
            execution = await sbx.run_code(script, language=language, timeout=30)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.warning(f'Error from E2B executor: {e}')
 | 
			
		||||
            return 0.0
 | 
			
		||||
        try:
 | 
			
		||||
            return float(execution.text)
 | 
			
		||||
        except (TypeError, ValueError):
 | 
			
		||||
            return 0.0
 | 
			
		||||
 | 
			
		||||
    def __call__(self, completions, **kwargs) -> List[float]:
 | 
			
		||||
        """Reward function that evaluates code snippets using the E2B code interpreter.
 | 
			
		||||
 | 
			
		||||
        Assumes the dataset contains a `verification_info` column with test cases.
 | 
			
		||||
        """
 | 
			
		||||
        evaluation_script_template = """
 | 
			
		||||
        import subprocess
 | 
			
		||||
        import json
 | 
			
		||||
 | 
			
		||||
        def evaluate_code(code, test_cases):
 | 
			
		||||
            passed = 0
 | 
			
		||||
            total = len(test_cases)
 | 
			
		||||
            exec_timeout = 5
 | 
			
		||||
 | 
			
		||||
            for case in test_cases:
 | 
			
		||||
                process = subprocess.run(
 | 
			
		||||
                    ["python3", "-c", code],
 | 
			
		||||
                    input=case["input"],
 | 
			
		||||
                    text=True,
 | 
			
		||||
                    capture_output=True,
 | 
			
		||||
                    timeout=exec_timeout
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if process.returncode != 0:  # Error in execution
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                output = process.stdout.strip()
 | 
			
		||||
                if output.strip() == case["output"].strip():
 | 
			
		||||
                    passed += 1
 | 
			
		||||
 | 
			
		||||
            success_rate = (passed / total)
 | 
			
		||||
            return success_rate
 | 
			
		||||
 | 
			
		||||
        code_snippet = {code}
 | 
			
		||||
        test_cases = json.loads({test_cases})
 | 
			
		||||
 | 
			
		||||
        evaluate_code(code_snippet, test_cases)
 | 
			
		||||
        """
 | 
			
		||||
        verification_info = kwargs['verification_info']
 | 
			
		||||
        languages = [info['language'] for info in verification_info]
 | 
			
		||||
        code_snippets = [
 | 
			
		||||
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
 | 
			
		||||
        ]
 | 
			
		||||
        scripts = [
 | 
			
		||||
            evaluation_script_template.format(
 | 
			
		||||
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info['test_cases'])))
 | 
			
		||||
            for code, info in zip(code_snippets, verification_info)
 | 
			
		||||
        ]
 | 
			
		||||
        try:
 | 
			
		||||
            rewards = self.run_async_from_sync(scripts, languages)
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.warning(f'Error from E2B executor: {e}')
 | 
			
		||||
            rewards = [0.0] * len(completions)
 | 
			
		||||
 | 
			
		||||
        return rewards
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CodeFormat(ORM):
 | 
			
		||||
 | 
			
		||||
    def __call__(self, completions, **kwargs) -> List[float]:
 | 
			
		||||
        verification_info = kwargs['verification_info']
 | 
			
		||||
        rewards = []
 | 
			
		||||
        for content, info in zip(completions, verification_info):
 | 
			
		||||
            pattern = r'^<think>.*?</think>\s*<answer>.*?```{}.*?```.*?</answer>(?![\s\S])'.format(info['language'])
 | 
			
		||||
            match = re.match(pattern, content, re.DOTALL | re.MULTILINE)
 | 
			
		||||
            reward = 1.0 if match else 0.0
 | 
			
		||||
            rewards.append(reward)
 | 
			
		||||
        return rewards
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CodeRewardByJudge0(ORM):
    LANGUAGE_ID_MAP = {
        'assembly': 45,
        'bash': 46,
        'basic': 47,
        'c': 50,
        'c++': 54,
        'clojure': 86,
        'c#': 51,
        'cobol': 77,
        'common lisp': 55,
        'd': 56,
        'elixir': 57,
        'erlang': 58,
        'executable': 44,
        'f#': 87,
        'fortran': 59,
        'go': 60,
        'groovy': 88,
        'haskell': 61,
        'java': 62,
        'javascript': 63,
        'kotlin': 78,
        'lua': 64,
        'multi-file program': 89,
        'objective-c': 79,
        'ocaml': 65,
        'octave': 66,
        'pascal': 67,
        'perl': 85,
        'php': 68,
        'plain text': 43,
        'prolog': 69,
        'python': 71,
        'python2': 70,
        'python3': 71,
        'r': 80,
        'ruby': 72,
        'rust': 73,
        'scala': 81,
        'sql': 82,
        'swift': 83,
        'typescript': 74,
        'visual basic.net': 84
    }
    PYTHON_ID = 71

    def __init__(self):
        import os
        self.endpoint = os.getenv('JUDGE0_ENDPOINT')
        assert self.endpoint is not None, (
            'Judge0 endpoint is not set. Please set the JUDGE0_ENDPOINT environment variable.')
        x_auth_token = os.getenv('JUDGE0_X_AUTH_TOKEN')
        self.headers = {'Content-Type': 'application/json'}
        if x_auth_token is not None:
            self.headers['X-Auth-Token'] = x_auth_token

    @staticmethod
    def extract_code(completion: str, language: str) -> str:
        pattern = re.compile(rf'```{language}\n(.*?)```', re.DOTALL)
        matches = pattern.findall(completion)
        extracted_answer = matches[-1] if len(matches) >= 1 else ''
        return extracted_answer

    @classmethod
    def get_language_id(cls, language):
        if language is None:
            return cls.PYTHON_ID
        return cls.LANGUAGE_ID_MAP.get(language.lower().strip(), cls.PYTHON_ID)

    async def _evaluate_code(self, code, test_cases, language_id):
        import aiohttp
        try:
            passed = 0
            total = len(test_cases)

            for case in test_cases:
                if code is not None and code != '':
                    async with aiohttp.ClientSession() as session:
                        payload = {
                            'source_code': code,
                            'language_id': language_id,
                            'stdin': case['input'],
                            'expected_output': case['output']
                        }
                        logger.debug(f'Payload: {payload}')
                        async with session.post(
                                self.endpoint + '/submissions/?wait=true', json=payload,
                                headers=self.headers) as response:
                            response_json = await response.json()
                            logger.debug(f'Response: {response_json}')
                            if response_json['status']['description'] == 'Accepted':
                                passed += 1

            success_rate = (passed / total)
            return success_rate
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            return 0.0

    def run_async_from_sync(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            rewards = loop.run_until_complete(self.run_async())
        finally:
            loop.close()
        return rewards

    async def run_async(self):
        tasks = [
            self._evaluate_code(code, info['test_cases'], CodeRewardByJudge0.get_language_id(info['language']))
            for code, info in zip(self.code_snippets, self.verification_info)
        ]
        results = await asyncio.gather(*tasks)
        rewards = list(results)
        return rewards

    def __call__(self, completions, **kwargs) -> List[float]:
        self.verification_info = kwargs['verification_info']

        languages = [info['language'] for info in self.verification_info]
        self.code_snippets = [
            self.extract_code(completion, language) for completion, language in zip(completions, languages)
        ]

        try:
            rewards = self.run_async_from_sync()
        except Exception as e:
            logger.warning(f'Error from Judge0 executor: {e}')
            rewards = [0.0] * len(completions)
        return rewards


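# A minimal, illustrative sanity check for CodeRewardByJudge0 (not part of the plugin
# itself). It assumes a Judge0 instance reachable via JUDGE0_ENDPOINT and a completion
# that wraps its code in a ```python fenced block; the endpoint URL and the test case
# below are placeholders.
#
#     import os
#     os.environ.setdefault('JUDGE0_ENDPOINT', 'http://localhost:2358')
#     judge = CodeRewardByJudge0()
#     completion = 'Here is my solution:\n```python\nprint(input())\n```'
#     info = [{'language': 'python', 'test_cases': [{'input': 'hi\n', 'output': 'hi\n'}]}]
#     print(judge([completion], verification_info=info))  # -> [1.0] if the case passes

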
orms['external_math_acc'] = MathAccuracy
orms['external_math_format'] = MathFormat
orms['external_countdown'] = CountdownORM
orms['external_r1v_acc'] = MultiModalAccuracyORM
orms['external_code_reward'] = CodeReward
orms['external_code_format'] = CodeFormat
orms['external_code_reward_by_judge0'] = CodeRewardByJudge0


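# Each key registered in `orms` becomes a valid value for `--orm_model` once this file
# is loaded via `--external_plugins plugin.py` (see sample_test_web.sh below, which
# selects the `external_web_acc` reward defined at the end of this file).

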
# For a generative reward model (GenRM), refer to swift/llm/plugin/rm_plugin/GenRMPlugin
class CustomizedRMPlugin:
    """
    Customized reward model plugin; behaves the same as DefaultRMPlugin.

    It assumes that `self.model` is a classification model with a value head (output dimension 1).
    The first logit value from the model's output is used as the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        reward_inputs.pop('labels')

        with torch.inference_mode():
            return self.model(**reward_inputs).logits[:, 0]


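# Illustrative call for CustomizedRMPlugin (the shapes are assumptions inferred from the
# encode/data_collator flow above, and `reward_model` / `reward_template` are
# placeholders, not a fixed API): each infer request carries the prompt/response pair
# to be scored in its 'messages'.
#
#     plugin = CustomizedRMPlugin(reward_model, reward_template)
#     inputs = [{'messages': [{'role': 'user', 'content': 'Question?'},
#                             {'role': 'assistant', 'content': 'Answer.'}]}]
#     scores = plugin(inputs)  # 1-D tensor with one scalar reward per request

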
class QwenLongPlugin(DefaultRMPlugin):
    # https://arxiv.org/abs/2505.17667
    # NOTE: customize the verified reward function for your task; for reference, see
    # https://github.com/Tongyi-Zhiwen/QwenLong-L1/tree/main/verl/verl/utils/reward_score
    # hf_dataset: https://huggingface.co/datasets/Tongyi-Zhiwen/DocQA-RL-1.6K/viewer/default/train
    # ms_dataset: https://modelscope.cn/datasets/iic/DocQA-RL-1.6K
    def __init__(self, model, template, accuracy_orm=None):
        super().__init__(model, template)
        # initialize a PtEngine instance for inference
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)  # 0: no limit
        self.request_config = RequestConfig(temperature=0)  # customize your request config here
        self.system = textwrap.dedent("""
            You are an expert in verifying if two answers are the same.

            Your input consists of a problem and two answers: Answer 1 and Answer 2.
            You need to check if they are equivalent.

            Your task is to determine if the two answers are equivalent, without attempting to solve the original problem.
            Compare the answers to verify they represent identical values or meanings,
            even when expressed in different forms or notations.

            Your output must follow this format:
            1) Provide an explanation for why the answers are equivalent or not.
            2) Then provide your final answer in the form of: [[YES]] or [[NO]]

            Problem: {problem_placeholder}
            Answer 1: {answer1_placeholder}
            Answer 2: {answer2_placeholder}
        """)  # noqa
        self.accuracy_orm = accuracy_orm

    def __call__(self, inputs):
        completions = [example['messages'][-1]['content'] for example in inputs]
        ground_truths = [example['reward_model']['ground_truth'] for example in inputs]
        rm_inputs = self.prepare_rm_inputs(inputs, completions, ground_truths)

        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        llm_rewards = self.compute_rewards(results)

        if self.accuracy_orm:
            verified_rewards = self.accuracy_orm(completions, ground_truths)
        else:
            verified_rewards = [0.0] * len(llm_rewards)

        rewards = [max(r1, r2) for r1, r2 in zip(llm_rewards, verified_rewards)]
        return torch.tensor(rewards, dtype=torch.float32)

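    # Design note: the final reward is max(LLM-judge score, rule-verified score), so a
    # completion is credited when either the generative judge or the deterministic
    # verifier accepts it; a stricter setup could use min() or a weighted combination.
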
    def prepare_rm_inputs(self, inputs: List[Dict], completions, ground_truths) -> List[Dict]:
        rm_inputs = []
        for infer_request, completion, ground_truth in zip(inputs, completions, ground_truths):
            # Deep copy to prevent modification of original input
            rm_infer_request = deepcopy(infer_request)
            problem = infer_request['messages'][0]['content']
            start_index = problem.index('</text>')
            end_index = problem.index('Format your response as follows:')
            question = problem[start_index:end_index].replace('</text>', '').strip()
            prompt = self.system.format(
                problem_placeholder=question, answer1_placeholder=completion, answer2_placeholder=ground_truth)

            # Construct new messages tailored for the reward model
            rm_messages = [{'role': 'user', 'content': prompt}]

            # Update the messages in the reward infer request
            rm_infer_request['messages'] = rm_messages
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> float:
        match = re.search(r'\[([A-Z]+)\]', model_output)
        if match:
            answer = match.group(1)
            if answer == 'YES':
                return 1.0
            elif answer == 'NO':
                return 0.0
            else:
                logger.warning("Unexpected answer, expected 'YES' or 'NO'.")
                return 0.0
        else:
            logger.warning("Unable to extract reward score from the model's output, setting reward to 0")
            return 0.0  # Or raise ValueError("Format incorrect")

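    # Illustrative behaviour of extract_reward: the judge prompt asks for [[YES]]/[[NO]],
    # and the regex also tolerates the single-bracket form.
    #
    #     QwenLongPlugin.extract_reward('The values match. [[YES]]')  # -> 1.0
    #     QwenLongPlugin.extract_reward('Different units. [NO]')      # -> 0.0
    #     QwenLongPlugin.extract_reward('no verdict given')           # -> 0.0 (warning logged)
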
    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Compute average reward scores from the reward model's outputs.

        Args:
            results (List[ChatCompletionResponse]): A list of results from the reward model.

        Returns:
            List[float]: A list of average reward scores.
        """
        rewards = []
        for idx, output in enumerate(results):
            try:
                cur_rewards = []
                for choice in output.choices:
                    response = choice.message.content
                    reward = self.extract_reward(response)
                    cur_rewards.append(reward)
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    average_reward = sum(cur_rewards) / len(cur_rewards)
                else:
                    average_reward = 0.0
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')

                rewards.append(average_reward)
            except Exception as e:
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)  # Assign default reward score on failure
        return rewards


rm_plugins['my_rmplugin'] = CustomizedRMPlugin
rm_plugins['qwenlong'] = QwenLongPlugin


class WebAccuracy(ORM):

    def __call__(self, completions, ground_truths, **kwargs) -> List[float]:
        solution = ground_truths
        rewards = []
        solution_pattern = r'\b(do|exit|go_backward|go_forward)\(.*\)'
        answer_pattern = r'<answer>(.*?)</answer>'

        for content, sol in zip(completions, solution):
            reward = 0.0

            # 1. Extract the reference action from the ground-truth solution
            sol_match = re.search(solution_pattern, sol, re.DOTALL)
            if sol_match:
                reference = sol_match.group(0).strip()

                answer_text = content.messages[-1]['content']
                # 2. Extract the answer from the completion
                answer_match = re.search(answer_pattern, answer_text, re.DOTALL)
                if answer_match:
                    answer_text = answer_match.group(1).strip()

                    # 3. Extract the predicted action call and compare it with the reference
                    solution_match = re.search(solution_pattern, answer_text, re.DOTALL)
                    if solution_match:
                        solution_text = solution_match.group(0).strip()
                        if reference in solution_text:
                            reward = 1.0

            rewards.append(reward)
        return rewards


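# Illustrative data shapes for WebAccuracy (assumptions inferred from the code above):
# each ground truth is a string containing a web action call such as
# "do(action='click', element='3')", and each completion is an object exposing
# `.messages`, whose last message wraps the predicted action in <answer>...</answer>
# tags. The reward is 1.0 when the reference action string appears inside the predicted
# one, otherwise 0.0.

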
orms['external_web_acc'] = WebAccuracy

14 sample/sample_test.sh Normal file
@@ -0,0 +1,14 @@
MODEL="qwen3-8b"

OPENAI_API_KEY="sk-a6a00cb727ee4913ba22530ec3c2b30d" \
swift sample \
    --sampler_type distill \
    --sampler_engine client \
    --model $MODEL \
    --stream true \
    --orm_model math \
    --dataset AI-MO/NuminaMath-TIR#5 \
    --num_return_sequences 1 \
    --temperature 0.6 \
    --top_p 0.95 \
    --engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'

15 sample/sample_test_web.sh Normal file
@@ -0,0 +1,15 @@
MODEL="qwen3-8b"

OPENAI_API_KEY="sk-a6a00cb727ee4913ba22530ec3c2b30d" \
swift sample \
    --sampler_type distill \
    --sampler_engine client \
    --model $MODEL \
    --stream true \
    --orm_model external_web_acc \
    --dataset combine_output_file_3.jsonl \
    --num_return_sequences 1 \
    --temperature 0.6 \
    --top_p 0.95 \
    --external_plugins plugin.py \
    --engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'