From 290f906591368839e74450a7ce667c7dc9a114d2 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 13 Nov 2025 08:13:25 -0800 Subject: [PATCH 1/2] first --- .../tutorials/1_tau2bench_overview.md | 314 ++++++++++++++++++ .../tutorials/2_fundamentals.md | 235 +++++++++++++ .../tutorials/3_forge_current_state.md | 271 +++++++++++++++ 3 files changed, 820 insertions(+) create mode 100644 brainstorming_forge_tau/tutorials/1_tau2bench_overview.md create mode 100644 brainstorming_forge_tau/tutorials/2_fundamentals.md create mode 100644 brainstorming_forge_tau/tutorials/3_forge_current_state.md diff --git a/brainstorming_forge_tau/tutorials/1_tau2bench_overview.md b/brainstorming_forge_tau/tutorials/1_tau2bench_overview.md new file mode 100644 index 000000000..8fa665a90 --- /dev/null +++ b/brainstorming_forge_tau/tutorials/1_tau2bench_overview.md @@ -0,0 +1,314 @@ +# Part 1: Tau2Bench Overview - What Are We Building For? + +## 1.1 What is Tau2Bench? + +**Reference**: `tau2-bench/README.md`, `tau2-bench/src/tau2/evaluator/evaluator.py` + +Tau2Bench is a benchmark for evaluating conversational agents in customer service scenarios. It tests whether your RL-trained model can: +- Follow domain policies correctly +- Use tools appropriately (search databases, update records, etc.) +- Communicate effectively with users + +Example task: "Create a task called 'Important Meeting' for user_1 with description 'Quarterly planning' and deadline tomorrow." + +The agent must call `create_task(user_id="user_1", title="Important Meeting", ...)` with the right parameters, then confirm to the user. + +## 1.2 Tau2 Modes + +**Reference**: `tau2-bench/src/tau2/orchestrator.py:67-174` + +**Solo Mode** (Recommended for training): +- Agent works alone on tickets/tasks +- No user interaction +- Simpler, deterministic +- Use this for initial training + +**Normal Mode**: +- Agent + User Simulator (LLM playing customer) +- More realistic but harder + +## 1.3 Tau2 Task Structure + +**Reference**: Task files at `tau2-bench/data/tau2/domains/{domain}/tasks.json`, data model at `tau2-bench/src/tau2/data_model/tasks.py` + +Tasks are defined in JSON format: + +```json +{ + "id": "create_task_1", + "ticket": "User wants to create a task titled 'Important Meeting' for user_1", + "evaluation_criteria": { + "actions": [ + { + "action_id": "create_1", + "name": "create_task", + "arguments": { + "user_id": "user_1", + "title": "Important Meeting" + } + } + ], + "reward_basis": ["ACTION", "COMMUNICATE"] + } +} +``` + +Key fields: +- `ticket`: Initial task description +- `evaluation_criteria.actions`: Expected tool calls +- `reward_basis`: What to score (ACTION, ENV, COMMUNICATE, NL_ASSERTIONS) + +**NOTE ON EVAL**: In this case, evaluation is checking if the tool was called. In other cases, it may be having another LLM verify if the task was completed correctly. 
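+
+To make this concrete, the sketch below shows what an ACTION-style check conceptually does: each expected action must match some tool call by name, with the expected arguments compared as a subset and order ignored (see 1.6). This is an illustration only, not the actual evaluator code in `tau2-bench/src/tau2/evaluator/evaluator.py`.
+
+```python
+# Simplified sketch of an ACTION-style check (illustrative, not the tau2 implementation)
+def action_score(expected_actions: list[dict], tool_calls: list[dict]) -> float:
+    def matches(expected: dict, call: dict) -> bool:
+        if expected["name"] != call["name"]:
+            return False
+        # Expected arguments may be a subset of the arguments actually passed
+        return all(call["arguments"].get(k) == v for k, v in expected["arguments"].items())
+
+    # Order does not matter: each expected action just needs some matching call
+    ok = all(any(matches(exp, call) for call in tool_calls) for exp in expected_actions)
+    return 1.0 if ok else 0.0
+
+# Example with the create_task action from the task above
+expected = [{"name": "create_task",
+             "arguments": {"user_id": "user_1", "title": "Important Meeting"}}]
+calls = [{"name": "create_task",
+          "arguments": {"user_id": "user_1", "title": "Important Meeting",
+                        "description": "Quarterly planning"}}]
+assert action_score(expected, calls) == 1.0
+```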
+ +## 1.4 Tau2 Available Tools (Mock Domain) + +```python +# Mock domain tools for demonstration +tools = [ + { + "name": "create_task", + "description": "Create a new task", + "parameters": { + "user_id": "string", + "title": "string", + "description": "string (optional)", + "deadline": "string (optional)" + } + }, + { + "name": "update_task", + "description": "Update an existing task", + "parameters": { + "task_id": "string", + "status": "string (pending|completed|cancelled)" + } + }, + { + "name": "done", + "description": "Signal task completion", + "parameters": {} + } +] +``` + +**Production Domains**: Tau2Bench includes three main production domains with domain-specific tools, policies, and databases: +- **Airline**: Flight booking, modifications, cancellations (`tau2-bench/src/tau2/domains/airline/`) +- **Retail**: Product orders, returns, exchanges (`tau2-bench/src/tau2/domains/retail/`) +- **Telecom**: Technical support, bill payments, line management (`tau2-bench/src/tau2/domains/telecom/`) + +## 1.5 Example Multi-turn Interaction on Tau2 + +**Solo Mode Example:** + +``` +Turn 1: +Agent: Let me create that task for you. + create_task(user_id="user_1", title="Important Meeting", + description="Quarterly planning", deadline="2024-01-16") +Env: Task created with ID: task_123 + +Turn 2: +Agent: Task created successfully. Is there anything else you need? + done() +Env: Episode complete. +``` + +**Note**: `done()` signals episode end. In Normal Mode, users can also end with keywords like "bye", "thanks" (see `tau2-bench/src/tau2/orchestrator.py:171-174` for stop conditions) + +## 1.6 How Tau2 Scores Episodes + +**Reference**: Evaluation logic in `tau2-bench/src/tau2/evaluator/evaluator.py`, metrics in `tau2-bench/src/tau2/metrics/agent_metrics.py` + +Tau2Bench computes rewards based on multiple criteria: + +**1. ACTION Score** (0.0 or 1.0): +- Did agent call the right tools? +- With the right arguments (or subset via `compare_args`)? +- Order doesn't matter + +**2. ENV Score** (0.0 or 1.0): +- Is environment state correct? +- Database checks (e.g., task_id="task_2" has status="pending") + +**3. COMMUNICATE Score** (0.0 or 1.0): +- Did agent communicate required information to user? + +**4. NL_ASSERTIONS Score** (0.0 or 1.0): +- LLM-based evaluation of conversation quality (experimental) + +**Final Reward:** +```python +final_reward = ACTION_score * ENV_score * COMMUNICATE_score * NL_ASSERTIONS_score +``` + +**CRITICAL**: Episode must end with either: +- `AGENT_STOP`: Agent calls `done()` tool +- `USER_STOP`: User says stop keywords + +Otherwise: `reward = 0.0` regardless of actions! + +**Sparse Rewards**: You only get the final reward at episode end. Intermediate tool calls get `reward=0.0`. + +--- + +## 1.7 Tau2Bench Production Domains + +Tau2Bench includes three production-ready customer service domains. Each domain has its own policy, tools, database, and evaluation tasks. + +### Airline Domain + +**Location**: `tau2-bench/data/tau2/domains/airline/` +- **Tasks**: 50 tasks in `tasks.json` +- **Policy**: `policy.md` +- **Code**: `tau2-bench/src/tau2/domains/airline/tools.py` + +**What agents do**: Book, modify, and cancel flight reservations, handle refunds and compensation, manage baggage and travel insurance. 
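+
+As a rough illustration, an airline task in the Section 1.3 format could look like the sketch below. This is hypothetical: the id, ticket text, and the `reservation_id` argument name are invented for illustration; only the schema fields, the `cancel_reservation` tool, and the 24-hour cancellation rule come from the domain description.
+
+```json
+{
+  "id": "airline_cancel_example",
+  "ticket": "User wants to cancel reservation ABC123, which was booked 3 hours ago",
+  "evaluation_criteria": {
+    "actions": [
+      {
+        "action_id": "cancel_1",
+        "name": "cancel_reservation",
+        "arguments": {"reservation_id": "ABC123"}
+      }
+    ],
+    "reward_basis": ["ACTION", "ENV"]
+  }
+}
+```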
+ +**Example tasks**: +- Cancellation policy testing (refuse invalid cancellations) +- Membership verification for baggage allowance +- Compensation fraud detection +- Complex modifications (multiple changes at once) +- Multi-reservation management + +**Available tools**: +- `get_user_details()`, `get_reservation_details()` +- `search_flights()`, `book_flight()`, `modify_flight()`, `cancel_reservation()` +- `add_baggage()`, `get_compensation()` +- `transfer_to_human_agents()` + +**Key policy rules**: +- Basic economy flights cannot be modified after booking +- Cancellations only allowed if: within 24hrs of booking, airline cancelled, business flight, or insurance covers reason +- Max 24 hours confirmation required before database-modifying actions +- Travel insurance: $30/passenger, enables full refund for covered reasons + +**Rewards**: DB checks, ENV_ASSERTION, ACTION-based evaluation + +### Retail Domain + +**Location**: `tau2-bench/data/tau2/domains/retail/` +- **Tasks**: 114 tasks in `tasks.json` +- **Policy**: `policy.md` +- **Code**: `tau2-bench/src/tau2/domains/retail/tools.py` + +**What agents do**: Help customers return/exchange delivered orders, cancel/modify pending orders, manage payment methods and addresses, provide product information. + +**Example tasks**: +- Multi-item exchanges with specific options +- Conditional exchanges (fallback options if unavailable) +- Product information queries + multiple returns +- Pending order modifications (change color, material, etc.) +- Cross-order refunds (complex refunds across multiple orders) +- Selective returns (specific items from orders) +- Address modifications for pending orders + +**Available tools**: +- `find_user_id_by_name_zip()`, `find_user_id_by_email()` +- `get_order_details()`, `get_product_details()` +- `cancel_pending_order()`, `modify_pending_order_items()` +- `return_delivered_order_items()`, `exchange_delivered_order_items()` +- `modify_pending_order_payment()`, `modify_user_default_address()` +- `transfer_to_human_agents()` + +**Key policy rules**: +- User authentication required via email OR name+zip before any action +- Pending orders can only be cancelled/modified once +- Delivered orders can be returned or exchanged +- Product IDs ≠ Item IDs (must distinguish between catalog and specific variants) +- One order modification max - collect all changes before calling tool +- Product variants: Different options (color, size, material) = different item_ids +- Refunds: Gift card refunds immediate, others 5-7 business days + +**Rewards**: DB checks, ACTION-based, COMMUNICATE evaluation + +### Telecom Domain + +**Location**: `tau2-bench/data/tau2/domains/telecom/` +- **Tasks**: 2,285 tasks in `tasks.json` (many auto-generated variants) +- **Policy**: `main_policy.md` +- **Code**: `tau2-bench/src/tau2/domains/telecom/tools.py` (agent) and `user_tools.py` (simulator) + +**What agents do**: Provide technical support for mobile devices and connectivity issues, handle overdue bill payments, manage line suspensions, help with data refueling and plan changes. 
+ +**Example task categories**: +- **Mobile data issues** (~1000+ tasks): Roaming problems, data mode issues, network preference problems, VPN connectivity, airplane mode interference, data usage exceeded, multiple combined issues +- **MMS issues**: MMS sending failures with various device states +- **Service issues**: Line suspension problems, network outages, connection problems + +**Example task IDs**: +- `[mobile_data_issue]user_abroad_roaming_enabled_off[PERSONA:None]` - User abroad with roaming disabled +- `[mobile_data_issue]data_usage_exceeded[PERSONA:Easy]` - User exceeded data limit +- `[mobile_data_issue]airplane_mode_on|data_saver_mode_on[PERSONA:Easy]` - Multiple issues combined + +**Available agent tools**: +- `get_customer_by_phone()`, `get_customer_by_id()`, `get_customer_by_name()` +- `get_line()`, `get_line_by_phone()`, `get_bill()`, `get_bills_by_customer()` +- `send_payment_request()`, `make_payment()` +- `refuel_data()` (max 2GB), `change_plan()` +- `suspend_line()`, `resume_line()` +- `transfer_to_human_agents()` + +**Unique user tools** (simulates user controlling device): +- `set_user_location()`, `toggle_roaming()`, `toggle_airplane_mode()`, `toggle_mobile_data()` +- `toggle_data_saver_mode()`, `set_network_preference()`, `toggle_vpn()`, `toggle_eSIM()` +- `perform_speed_test()`, `get_status_bar()`, `can_send_mms()` + +**Key policy rules**: +- Try to resolve before escalating to human agents +- Overdue bills: Check status → send payment request → customer checks request → make payment +- Line suspension: Only lift after all overdue bills paid (cannot lift for expired contracts) +- Data refueling: Max 2GB per refuel, price varies by plan +- Customer lookup: By phone, ID, or name+DOB +- Bill status types: Draft, Issued, Paid, Overdue, Awaiting Payment, Disputed +- Line status types: Active, Suspended, Pending Activation, Closed + +**Rewards**: ENV_ASSERTION (checks device state), ACTION (correct tool calls), COMMUNICATE + +**Example telecom evaluation**: +```json +{ + "actions": [{"name": "toggle_roaming", "requestor": "user"}], + "env_assertions": [ + {"func_name": "assert_mobile_data_status", "expected_status": true}, + {"func_name": "assert_internet_speed", "expected_desc": "excellent"} + ], + "reward_basis": ["ENV_ASSERTION"] +} +``` + +Success = Agent correctly diagnoses problem + user performs correct fix + environment reaches target state + +--- + +## 1.8 Key Tau2Bench References + +**Task definitions**: +- Mock domain: `tau2-bench/data/tau2/domains/mock/tasks.json` +- Airline: `tau2-bench/data/tau2/domains/airline/tasks.json` (50 tasks) +- Retail: `tau2-bench/data/tau2/domains/retail/tasks.json` (114 tasks) +- Telecom: `tau2-bench/data/tau2/domains/telecom/tasks.json` (2,285 tasks) + +**Policies**: +- Airline: `tau2-bench/data/tau2/domains/airline/policy.md` +- Retail: `tau2-bench/data/tau2/domains/retail/policy.md` +- Telecom: `tau2-bench/data/tau2/domains/telecom/main_policy.md` + +**Tool implementations**: +- Airline tools: `tau2-bench/src/tau2/domains/airline/tools.py` +- Retail tools: `tau2-bench/src/tau2/domains/retail/tools.py` +- Telecom agent tools: `tau2-bench/src/tau2/domains/telecom/tools.py` +- Telecom user tools: `tau2-bench/src/tau2/domains/telecom/user_tools.py` + +**Evaluation code**: +- Main evaluator: `tau2-bench/src/tau2/evaluator/evaluator.py` +- Metrics (pass^k): `tau2-bench/src/tau2/metrics/agent_metrics.py` +- Orchestrator (runs episodes): `tau2-bench/src/tau2/orchestrator.py` + +**Data models**: +- Task structure: 
`tau2-bench/src/tau2/data_model/tasks.py`
+- Airline models: `tau2-bench/src/tau2/domains/airline/data_model.py`
+- Retail models: `tau2-bench/src/tau2/domains/retail/data_model.py`
+- Telecom models: `tau2-bench/src/tau2/domains/telecom/data_model.py`
+
+---
diff --git a/brainstorming_forge_tau/tutorials/2_fundamentals.md b/brainstorming_forge_tau/tutorials/2_fundamentals.md
new file mode 100644
index 000000000..fd1b2d4d9
--- /dev/null
+++ b/brainstorming_forge_tau/tutorials/2_fundamentals.md
@@ -0,0 +1,235 @@
+# Part 2: The Fundamentals
+
+## 2.1 What is Tool Calling?
+
+Tool calling allows the LLM to invoke functions instead of just generating text.
+
+**Example:**
+```python
+# Without tools:
+User: "What's the weather in NYC?"
+Model: "I don't have access to real-time weather data."
+
+# With tools:
+User: "What's the weather in NYC?"
+Model: get_weather(city="NYC")
+Tool: {"temperature": 72, "conditions": "sunny"}
+Model: "It's 72°F and sunny in NYC."
+```
+
+## 2.2 How Tool Calling Works
+
+**Core concept:** Models are trained to output special formats (tokens or text tags), which we then parse to extract structured tool calls.
+
+**Two parsing approaches exist in practice:**
+
+### Token-Based Parsing (vLLM Native)
+Some models use **special token IDs** (e.g., token 12971 = `<|python_tag|>`). vLLM can parse these directly:
+
+```yaml
+# vLLM config
+enable_auto_tool_choice: true
+tool_call_parser: "hermes"  # Model-specific: "mistral", "llama", "internlm"
+```
+
+### Text-Based Parsing (Manual)
+Most libraries parse text tags such as `<tool_call>...</tool_call>` with regex (seen in Tinker, TRL, Verifiers):
+
+```python
+# Example from tinker-cookbook/tinker_cookbook/renderers.py
+def parse_response(self, response_tokens):
+    text = self.tokenizer.decode(response_tokens)
+    match = re.search(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL)
+    if match:
+        return Message(role="assistant", tool_calls=[json.loads(match.group(1))])
+    return Message(role="assistant", content=text)
+```
+
+**Reference:** [Tinker renderers.py](../../tinker-cookbook/tinker_cookbook/renderers.py)
+
+**NOTE**: Every model has its own format. We shouldn't use arbitrary tags with arbitrary models.
+
+## 2.3 What is Multi-turn?
+
+Multi-turn = multiple back-and-forth exchanges in a single episode.
+
+**Single-turn:**
+```
+User: "What's 2+2?"
+Model: "4"
+[Done]
+```
+
+**Multi-turn:**
+```
+User: "What's 2+2?"
+Model: "4"
+User: "What's 4+2?"
+Model: "6"
+User: "What's 6+2?"
+Model: "8"
+[Done]
+```
+
+For tool calling, multi-turn enables:
+1. Call tool
+2. Get result
+3. Use result to decide next action
+4. Repeat until task complete
+
+## 2.4 Multi-turn Loop: A Simple Python Example
+
+```python
+# Conceptual multi-turn loop
+env = create_env(task="Book a flight to NYC")
+messages = [{"role": "user", "content": "Book me a flight to NYC"}]
+done = False
+
+while not done:
+    # 1. Build prompt from message history
+    prompt = build_prompt(messages)
+
+    # 2. Generate response
+    # On the first iteration the model calls the tool and gets the result.
+    # On following iterations it acts based on that result,
+    # repeating until the model says it is done.
+    # Another option is to have another LLM here acting as a user.
+    response = model.generate(prompt)
+
+    # 3. 
Check if tool call + if has_tool_call(response): + # Parse and execute tool + tool_call = parse_tool_call(response) + tool_result = env.execute_tool(tool_call) + + # Add to history + messages.append({"role": "assistant", "tool_calls": [tool_call]}) + messages.append({"role": "tool", "content": tool_result}) + else: + # Final answer + messages.append({"role": "assistant", "content": response}) + done = True + +# Get final reward +reward = env.get_reward() +``` + +Key points: +- **Loop** until done +- **Accumulate** messages (conversation history) +- **Tools** execute via environment +- **Reward** computed at end (sparse) + +## 2.5 What is an Environment? + +An **environment** manages: +1. **Tool execution**: Runs tools, returns results +2. **State management**: Tracks what's been done +3. **Reward computation**: Scores the episode + +**Standard API** (gym-like): + +```python +# Initialize +env = Environment(task=task_data) +state = env.reset() # Returns initial state/observation + +# Step +result = env.step(action) # Execute tool or message +# result contains: +# - observation: New state (tool result, env feedback) +# - reward: Immediate reward (often 0.0 for intermediate steps) +# - done: Is episode complete? +# - info: Extra metadata + +# Final reward +if result.done: + final_reward = result.reward +``` + +**Relationship to tools:** +- Environment **owns** the tools +- `env.step(tool_call)` executes the tool +- Returns tool result as observation +- Updates internal state (databases, etc.) + +## 2.6 Message Format (OpenAI Standard) + +Take the example: +``` +"Assistant: I'll search for flights and check the weather for you. +{"name": "search_flights", "arguments": {"destination": "NYC"}} + + +{"name": "get_weather", "arguments": {"city": "NYC"}} +" +``` + +**After parsing, this becomes the structured message** with separate `content` and `tool_calls` fields. Most libraries use OpenAI's chat format: + +```python +messages = [ + # System message (optional) + { + "role": "system", + "content": "You are a helpful assistant with access to tools..." 
+ }, + + # User message + { + "role": "user", + "content": "Book me a flight to NYC and check the weather there" + }, + + # Assistant message (with content AND tool calls in ONE message) + { + "role": "assistant", + "content": "I'll search for flights and check the weather for you.", + "tool_calls": [ + { + "id": "call_123", + "function": { + "name": "search_flights", + "arguments": '{"destination": "NYC"}' + } + }, + { + "id": "call_124", + "function": { + "name": "get_weather", + "arguments": '{"city": "NYC"}' + } + } + ] + }, + + # Tool results (one per tool call) + { + "role": "tool", + "content": '[{"flight": "AA100", "price": "$200"}]', + "tool_call_id": "call_123" + }, + { + "role": "tool", + "content": '{"temperature": 72, "conditions": "sunny"}', + "tool_call_id": "call_124" + } +] +``` + +**Key fields:** +- `role`: "system", "user", "assistant", or "tool" +- `content`: Text content +- `tool_calls`: List of tool invocations (assistant only) +- `tool_call_id`: Links tool result to invocation + +**Chat template** converts messages to model input: +```python +# Using tokenizer +prompt = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False +) +# Returns formatted string ready for model +``` diff --git a/brainstorming_forge_tau/tutorials/3_forge_current_state.md b/brainstorming_forge_tau/tutorials/3_forge_current_state.md new file mode 100644 index 000000000..ba7d3f762 --- /dev/null +++ b/brainstorming_forge_tau/tutorials/3_forge_current_state.md @@ -0,0 +1,271 @@ +# Part 3: How Forge Currently Works + +## 3.1 Current Forge GRPO Flow (GSM8K Example) + +Forge currently implements GRPO (Group Relative Policy Optimization) for single-turn tasks like math problems. + +**Architecture:** +```python +# apps/grpo/main.py + +# 1. Setup services (distributed actors via Monarch) +policy = Generator(...) # vLLM-based generation +trainer = TitanTrainer(...) # Training service +replay_buffer = ReplayBuffer(...) # Store episodes +ref_model = ReferenceModel(...) # Reference for KL +reward_actor = RewardActor(...) # Score responses + +# 2. Rollout loop (continuous_rollouts) +async def continuous_rollouts(): + while True: + # Sample prompt from dataset + sample = await dataloader.sample.call_one() + prompt, target = sample["prompt"], sample["target"] + + # Generate G responses (group) + responses = await policy.generate.route( + prompt, + n=group_size # e.g., 8 responses + ) + + # Score each response + episodes = [] + for response in responses: + episode = Episode(...) + episode.reward = await reward_actor.evaluate_response.route( + prompt=prompt, + response=response.text, + target=target + ) + episodes.append(episode) + + # Get reference logprobs + ref_logprobs = await ref_model.forward.route(...) + + # Compute advantages (group-relative) + advantages = compute_advantages(episodes) + + # Add to replay buffer + for episode in episodes: + await replay_buffer.add.call_one(episode) + +# 3. 
Training loop (continuous_training)
+async def continuous_training():
+    while True:
+        batch = await replay_buffer.sample(batch_size)
+
+        # Train on batch
+        await trainer.train_step(
+            inputs=batch["inputs"],
+            targets=batch["targets"],
+            advantages=batch["advantages"]
+        )
+
+        # Update policy weights
+        version = await trainer.push_weights()
+        await policy.update_weights(version)
+```
+
+**Key features:**
+- **Async distributed**: Actors communicate via Monarch
+- **Parallel rollouts**: Multiple `continuous_rollouts()` tasks
+- **Decoupled**: Rollout and training loops run independently
+- **Replay buffer**: Stores episodes for training
+
+## 3.2 What Forge is Missing for Tool Calling
+
+**Current GSM8K flow:**
+```
+Sample prompt → Generate response → Score → Train
+```
+
+**Needed for tool calling:**
+```
+Sample task → Multi-turn loop → Train
+                  ↓
+    Generate → Parse → Execute tool → Update state → Repeat → Score
+```
+
+**Missing components:**
+
+### 1. Multi-turn Loop
+**Current**: Single `policy.generate.route(prompt)` call
+**Needed**: Loop with multiple generation calls
+
+```python
+# Need to add:
+while not done:
+    response = await policy.generate.route(prompt)
+    if has_tool_call(response):
+        tool_result = execute_tool(...)
+        # Continue loop
+    else:
+        done = True
+```
+
+### 2. Tool Call Detection & Parsing
+**Current**: No parsing
+**Needed**: Extract tool calls from model output
+
+```python
+# Need to add:
+def parse_tool_call(response_text):
+    if "<tool_call>" in response_text:
+        # Parse the JSON between the tags
+        return tool_call
+    return None
+```
+
+### 3. Message History Management
+**Current**: Single prompt
+**Needed**: Accumulate multi-turn conversation
+
+```python
+# Need to add:
+messages = [
+    {"role": "user", "content": task},
+    {"role": "assistant", "tool_calls": [...]},
+    {"role": "tool", "content": result},
+    # ... more turns
+]
+```
+
+### 4. Tool Execution
+**Current**: No tool support
+**Needed**: Environment to execute tools
+
+```python
+# Need to add:
+env = Environment(task=task)
+result = env.step(tool_call)
+```
+
+### 5. Response Masking
+**Current**: Naively splits into prompt/answer and trains on the whole answer. For multi-turn, this would train on all tokens, including tool results.
+**Needed**: Mask so the loss ignores tool-result tokens
+
+```python
+# Need to add:
+response_mask = [
+    1, 1, 1,  # LLM output - TRAIN
+    0, 0, 0,  # Tool result - IGNORE
+    1, 1, 1,  # LLM output - TRAIN
+]
+```
+
+### 6. 
Episode Structure +**Current** (from `apps/grpo/main.py:44-74`): +```python +@dataclass +class Episode: + episode_id: str + pad_id: int + request_len: int + response_len: int + target: Any | None = None + # Processed data + completion: Completion | None = None # Contains prompt_ids, token_ids, logprobs + ref_logprobs: torch.Tensor | None = None + reward: float | None = None + advantage: float | None = None +``` + +**Multi turn**: + +**References**: +**Tinker** `tinker-cookbook/tinker_cookbook/rl/types.py`, +**VERL** `verl/experimental/agent_loop/tool_agent_loop.py`, +**TRL** `trl/examples/scripts/openenv/catch.py` +**NeMo-RL** `RL/nemo_rl/experience/rollouts.py` + +- Store all turns (transition) in single Episode (trajectory) +- Concatenate turns during rollout or when converting to training data +- Build response_mask to exclude tool results from training + +**Tinker's approach** (`tinker-cookbook/tinker_cookbook/rl/types.py`): +```python +Observation: TypeAlias = tinker.ModelInput + +@dataclass +class Transition: + ob: Observation + ac: TokensWithLogprobs + reward: float + episode_done: bool + metrics: Metrics = field(default_factory=dict) + +@dataclass(frozen=True) +class Trajectory: + transitions: list[Transition] + final_ob: Observation + +@dataclass +class TrajectoryGroup: + trajectories_G: list[Trajectory] + final_rewards_G: list[float] # computed by the EnvGroupBuilder, looking at whole group + metrics_G: list[Metrics] + + def get_total_rewards(self) -> list[float]: + return [ + sum(transition.reward for transition in trajectory.transitions) + final_reward + for trajectory, final_reward in safezip(self.trajectories_G, self.final_rewards_G) + ] +``` + +### 7. Prompt Formatting with Tools +**Current**: Simple prompt. +**Needed**: Our tokenizer jinja template already supports tools, but need to investigate how to use it +and write `format_tool_schemas` + +```python +# Need to add: +system_prompt = f""" +You have access to these tools: + +{format_tool_schemas(tools)} + +Call tools using this format: +{{"name": "tool_name", "args": {{}}}} +""" +``` + +### 8. Reward Computation +**Current** (from `apps/grpo/main.py:385-398`): Immediate reward from `RewardActor` +```python +# For each response in the group +for i, response in enumerate(responses): + episode.reward = await reward_actor.evaluate_response.route( + prompt=prompt, + response=response.text, + target=target + ) + # reward_actor compares response to target immediately +``` + +**Needed for multi-turn**: Sparse reward from environment after episode completes, i.e. the input to the reward calculator is the **full trajectory**. + +```python +for i, response in enumerate(responses): + ... 
+
+# add this
+final_reward = sum(previous_rewards_if_any) + env.get_rewards(responses)
+# or just:
+final_reward = env.get_rewards(responses)
+```
+
+---
+
+**Summary Table:**
+
+| Component | GSM8K (Current) | Tool Calling (Needed) |
+|-----------|-----------------|-----------------------|
+| **Loop** | Single generate | Multi-turn while loop |
+| **Tools** | None | Parse & execute |
+| **Reward** | Per-response | Sparse at end |
+| **Loss** | All tokens | Masked (exclude tool results) |
+| **Episode** | Single turn | Multi-turn |

From 8c874953564cca8b37ffbdbb7c3100128cdf5558 Mon Sep 17 00:00:00 2001
From: Felipe Mello
Date: Thu, 13 Nov 2025 09:11:04 -0800
Subject: [PATCH 2/2] add what the loop should look like

---
 .../tutorials/4_forge_ideal_state.md | 293 ++++++++++++++++++
 1 file changed, 293 insertions(+)
 create mode 100644 brainstorming_forge_tau/tutorials/4_forge_ideal_state.md

diff --git a/brainstorming_forge_tau/tutorials/4_forge_ideal_state.md b/brainstorming_forge_tau/tutorials/4_forge_ideal_state.md
new file mode 100644
index 000000000..2117aba01
--- /dev/null
+++ b/brainstorming_forge_tau/tutorials/4_forge_ideal_state.md
@@ -0,0 +1,293 @@
+
+** WORK IN PROGRESS -- NEEDS CHANGES / CLEANUP / DETAILS **
+
+# Part 4.0: What Multi-Turn Tool Calling with Forge + vLLM + OpenEnv Would Look Like
+
+For tool calling, we extend Forge's GRPO pattern to handle **multi-turn interactions** where:
+- One task → multiple LLM generations + tool executions → one Episode
+- Episode contains **concatenated tokens** from all turns
+- Training and replay buffer logic remain unchanged
+
+**Key Principle:** Multi-turn only changes the **rollout phase**. Training stays the same.
+
+---
+
+## Setup: Services + Multi-Environment Support
+
+Notice that an Env in OpenEnv is a **tool execution environment**. It doesn't know about tasks. It only knows about tools.
+Other Envs may have more responsibilities, such as holding the conversation history and providing the data.
+
+```python
+# 1. Setup services (same as single-turn, plus environments)
+policy = Generator(...)
+trainer = TitanTrainer(...)
+replay_buffer = ReplayBuffer(...)
+ref_model = ReferenceModel(...)
+
+# Dataloader provides tasks (prompts + metadata)
+dataloader = DataLoader(Tau2BenchDataset(...))
+
+# Task-based routing
+# Different environments = different tools, max_turns, rewards
+env_map = {
+    "websearch": WebSearchEnv.from_docker_image("tau2bench/websearch:latest"),
+    "coding": CodingEnv.from_docker_image("tau2bench/coding:latest"),
+    "airline": AirlineEnv.from_docker_image("tau2bench/airline:latest"),
+}
+
+# Environment-specific configuration
+max_turns_config = {
+    "websearch": 10,
+    "coding": 15,
+    "airline": 8,
+}
+```
+
+**References:**
+- Verifiers: `verifiers/envs/env_group.py`
+- Tinker: `tinker-cookbook/distillation/datasets.py:45-83`
+
+---
+
+## Rollout Loop: Multi-Turn with Environment Routing
+
+```python
+# 2. 
Rollout loop (continuous_rollouts with multi-turn) +async def continuous_rollouts(): + while True: + # Sample task from dataloader + task = await dataloader.sample.call_one() + # task.prompt: "Book a flight from SF to NYC on March 15th" + # task.task_type: "websearch" | "coding" | "airline" + # task.metadata: Additional task-specific info + + # Route to correct environment based on task type + env_client = env_map[task.task_type] + max_turns = max_turns_config[task.task_type] + + # Reset environment to get tools (env doesn't know the task) + # Reference: OpenEnv/src/core/http_env_client.py:142-154 + env_state = env_client.reset() + tool_schemas = env_state.observation.tools # Available tools for this env + + # Generate G samples for this task + # TODO: Investigate parallelizing with asyncio.gather() instead of sequential + episodes = [] + for _ in range(group_size): # G samples per task + episode = await play_task( + policy=policy, + task_prompt=task.prompt, # From dataloader + tool_schemas=tool_schemas, # From environment + env=env_client, + max_turns=max_turns + ) + episodes.append(episode) + + # Add to replay buffer (same as single-turn) + for episode in episodes: + await replay_buffer.add.call_one(episode) +``` + +**Critical insight:** Dataset provides tasks, environment provides tools. They are separate. + +--- + +## Multi-Turn Rollout: play_task() + +This replaces the single `policy.generate()` call in single-turn GRPO. + +```python +# Reference: OpenEnv/src/core/client_types.py (StepResult) +from openenv.core.client_types import StepResult +from openenv.core.env_server import ToolCallAction + +async def play_task( + policy: Generator, + task_prompt: str, # From dataloader + tool_schemas: list[dict], # From env.reset() + env: OpenEnvClient, + max_turns: int = 10 +) -> Episode: + """ + Play one task to completion, return single Episode. + + Args: + policy: Generator actor for LLM generation + task_prompt: Task from dataloader (e.g., "Book flight SF->NYC") + tool_schemas: Available tools from env.reset() + env: Environment client for tool execution + max_turns: Maximum conversation turns + + Returns: + Episode with all turns concatenated + """ + + # Initialize conversation with task + # System prompt handled by tokenizer.apply_chat_template() with tools= + # Or dataset can provide task.system_prompt if needed + messages = [{"role": "user", "content": task_prompt}] + + # Storage: concatenate all turns into single sequence + all_tokens = [] + all_logprobs = [] + response_mask = [] # 1=train on LLM output, 0=skip tool results + metadata = {} # Track episode stats + + done = False + turn = 0 + + while not done and turn < max_turns: + # 1. Format prompt with conversation history + tools + # Tokenizer injects system prompt with tool definitions when tools= is passed + prompt = tokenizer.apply_chat_template( + messages, + tools=tool_schemas, # From env.reset() + add_generation_prompt=True, + tokenize=False + ) + + # 2. Generate response + response = await policy.generate.route(prompt, n=1) + + # 3. Parse tool call from response + # Using Tinker pattern: XML tags ... 
+ # Alternative: vLLM native parsing with tool_call_parser="hermes" (see Appendix) + tool_calls = parse_tool_calls(response.text) # Returns list of tool calls + + if tool_calls: + # Tool execution path + # Add assistant message with tool calls + messages.append({ + "role": "assistant", + "content": response.text, + "tool_calls": tool_calls # Structured tool call data + }) + + # Collect LLM output tokens - TRAIN on these + all_tokens.extend(response.token_ids) + all_logprobs.extend(response.logprobs) + response_mask.extend([1] * len(response.token_ids)) + + # Execute tools (parallel if multiple calls) + # TODO: Confirm environment can handle parallel requests + try: + tool_tasks = [ + env.execute_tool(tc["name"], tc["args"]) + for tc in tool_calls + ] + tool_results = await asyncio.gather(*tool_tasks) + except Exception as e: + # Handle tool execution errors + tool_results = [{"content": f"Error: {str(e)}"}] + + # Add tool results to messages and tokens + for tool_result in tool_results: + tool_content = tool_result.content + + # Truncate long tool responses to avoid context overflow + tool_tokens = tokenizer.encode(tool_content, add_special_tokens=False) + tool_tokens = truncate(tool_tokens, max_length=256) + # TODO: Decide where truncate() lives (env vs rollout loop vs utility) + tool_content = tokenizer.decode(tool_tokens) + + # Add tool result to messages + messages.append({ + "role": "tool", + "content": tool_content + }) + + # Collect tool result tokens - DON'T TRAIN on these + all_tokens.extend(tool_tokens) + all_logprobs.extend([0.0] * len(tool_tokens)) + response_mask.extend([0] * len(tool_tokens)) + + # Check if environment signals done + done = tool_results[-1].get("done", False) if tool_results else False + + else: + # Final answer (no tool call) + messages.append({ + "role": "assistant", + "content": response.text + }) + + # Collect final response tokens - TRAIN on these + all_tokens.extend(response.token_ids) + all_logprobs.extend(response.logprobs) + response_mask.extend([1] * len(response.token_ids)) + + done = True + + turn += 1 + + # Populate episode metadata + metadata = { + "num_turns": turn, + "truncated": turn >= max_turns, + # other stats... + } + + # Get final reward from environment + final_reward = env.get_reward(messages) #TODO: confirm messages as input + + # Create Episode + # TODO: this abstraction will have to change. It was created for single-turn. + completion = Completion( + prompt_ids=None, # Not stored (can reconstruct from messages) + token_ids=torch.tensor(all_tokens), + logprobs=torch.tensor(all_logprobs), + text=tokenizer.decode(all_tokens), + generator_version=0 + ) + + episode = Episode( + episode_id=str(uuid.uuid4()), + pad_id=tokenizer.pad_token_id or tokenizer.eos_token_id, + request_len=0, # Varies per turn, not fixed + response_len=len(all_tokens), + target=None, # Tau2Bench doesn't expose ground truth during training + completion=completion, + response_mask=torch.tensor(response_mask), # NEW: Mask for training + ref_logprobs=None, # Computed later by ref_model + reward=final_reward, + advantage=None, # Computed later with group + metadata=metadata # NEW: Episode statistics + ) + + return episode +``` +## Training Loop + +Stays the same, but we add `response_mask` + +```python +# Reference: apps/grpo/main.py + +# 3. 
Training loop (minimal changes - just add response_mask)
+async def continuous_training():
+    while True:
+        # Sample batch from replay buffer
+        batch = await replay_buffer.sample(batch_size)
+
+        # Get reference logprobs
+        ref_logprobs = await ref_model.forward.route(
+            prompt_ids=batch["prompt_ids"],
+            response_ids=batch["response_ids"]
+        )
+
+        # Compute advantages (group-relative)
+        advantages = compute_group_advantages(batch["rewards"])
+
+        # Train on batch with response mask
+        await trainer.train_step(
+            inputs=batch["prompt_ids"],
+            targets=batch["response_ids"],
+            advantages=advantages,
+            ref_logprobs=ref_logprobs,
+            response_mask=batch["response_mask"],  # NEW: Mask tool results
+        )
+
+        # Update policy weights
+        version = await trainer.push_weights()
+        await policy.update_weights(version)
+```
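+
+To make these concrete, here is a minimal sketch of `compute_group_advantages` (referenced above) and of one way `response_mask` could enter a token-level loss. This is an illustration under simplifying assumptions (mean/std normalization within the group, a plain REINFORCE-style objective with no KL term or clipping), not the actual TitanTrainer implementation.
+
+```python
+import torch
+
+def compute_group_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    """Group-relative advantages: normalize each episode's reward within its group."""
+    return (rewards - rewards.mean()) / (rewards.std() + eps)
+
+def masked_policy_loss(logprobs: torch.Tensor,
+                       advantages: torch.Tensor,
+                       response_mask: torch.Tensor) -> torch.Tensor:
+    """
+    logprobs:      [batch, seq] log-probs of the sampled tokens under the current policy
+    advantages:    [batch] one advantage per episode, broadcast over its tokens
+    response_mask: [batch, seq] 1 for LLM-generated tokens, 0 for tool-result tokens
+    """
+    per_token = -logprobs * advantages.unsqueeze(-1)  # REINFORCE-style objective
+    per_token = per_token * response_mask              # tool-result tokens contribute nothing
+    # Average only over trained tokens so long tool outputs don't dilute the loss
+    return per_token.sum() / response_mask.sum().clamp(min=1)
+```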