zuxin-llm committed on
Commit
0e64a7f
·
verified ·
1 Parent(s): c52e518

Upload xlam_tool_call_parser.py

Browse files
Files changed (1) hide show
  1. xlam_tool_call_parser.py +198 -0
xlam_tool_call_parser.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import Dict, List, Sequence, Union
4
+ import partial_json_parser
5
+ from partial_json_parser.core.options import Allow
6
+
7
+ from vllm.entrypoints.openai.protocol import (
8
+ ChatCompletionRequest, DeltaMessage, DeltaToolCall,
9
+ DeltaFunctionCall, ExtractedToolCallInformation, ToolCall, FunctionCall
10
+ )
11
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
12
+ from vllm.utils import random_uuid
13
+ from vllm.logger import init_logger
14
+ from transformers import PreTrainedTokenizerBase
15
+ from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
16
+ is_complete_json,
17
+ partial_json_loads)
18
+
19
+ logger = init_logger(__name__)
20
+
21
@ToolParserManager.register_module("xlam")
class xLAMToolParser(ToolParser):
    """Tool-call parser for models that emit a bare JSON array of calls.

    The expected model output is ``[{"name": ..., "arguments": {...}}, ...]``
    with no surrounding markers. Output that does not start with ``[`` is
    passed through as ordinary content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Initialise the per-request streaming state."""
        super().__init__(tokenizer)
        # Index of the tool call currently being streamed; -1 means "none
        # started yet" (the `< 0` / `+ 1` checks below rely on this).
        self.current_tool_id: int = -1
        # Tool-call array parsed from the previous chunk (diff baseline).
        self.prev_tool_calls: List[Dict] = []
        # Per tool: has the function name already been sent?
        self.current_tools_sent: List[bool] = []
        # Per tool: the argument JSON text already sent to the client.
        self.streamed_args: List[str] = []
        # NOTE(review): the serving layer may read differently named state
        # from the base class (e.g. for end-of-stream flushing) — confirm
        # whether these lists need to be mirrored there.

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into tool calls.

        Args:
            model_output: Full text generated by the model.
            request: The originating chat request (unused here).

        Returns:
            ``tools_called=True`` with the parsed calls when the output is a
            non-empty JSON array of ``{"name", "arguments"}`` objects;
            otherwise the raw output as plain content.
        """
        try:
            if not model_output.strip().startswith('['):
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output
                )

            tool_calls_data = json.loads(model_output)
            tool_calls: List[ToolCall] = [
                ToolCall(
                    id=f"call_{idx}_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        arguments=json.dumps(call["arguments"])
                    )
                )
                for idx, call in enumerate(tool_calls_data)
            ]

            # An empty array carries no calls; reporting tools_called=True
            # with content=None would leave the response with nothing in it.
            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=model_output
                )

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None
            )

        except Exception:
            # Best effort: anything unparsable is returned as plain content.
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=model_output
            )

    def _start_next_tool(self) -> None:
        """Advance to the next tool and keep the per-tool lists in step."""
        self.current_tool_id += 1
        while len(self.current_tools_sent) <= self.current_tool_id:
            self.current_tools_sent.append(False)
        while len(self.streamed_args) <= self.current_tool_id:
            self.streamed_args.append("")
        logger.debug("starting new tool %d", self.current_tool_id)

    def _arguments_delta(self, index: int, argument_diff: str) -> DeltaMessage:
        """Record ``argument_diff`` as sent for tool ``index`` and wrap it."""
        self.streamed_args[index] += argument_diff
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=index,
                function=DeltaFunctionCall(
                    arguments=argument_diff
                ).model_dump(exclude_none=True)
            )
        ])

    def _previous_arguments(self, index: int):
        """Return tool ``index``'s arguments from the previous chunk, if any."""
        if index < len(self.prev_tool_calls):
            previous = self.prev_tool_calls[index]
            if isinstance(previous, dict):
                return previous.get("arguments")
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn the text generated so far into an incremental tool-call delta.

        The accumulated text is re-parsed as a (possibly partial) JSON array
        on every call. For the current tool the name is sent first, then its
        arguments are streamed as the JSON text grows.

        Args:
            previous_text: Text generated before this chunk (unused).
            current_text: All text generated so far.
            delta_text: Text added by this chunk.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: The originating chat request (unused here).

        Returns:
            A content delta when the output is not a tool-call array, a
            tool-call delta when there is something new to send, or ``None``
            when this chunk yields nothing (including on parse errors).
        """
        if not current_text.strip().startswith('['):
            return DeltaMessage(content=delta_text)

        tool_id = self.current_tool_id
        name_sent = (0 <= tool_id < len(self.current_tools_sent)
                     and self.current_tools_sent[tool_id])
        # Until the current tool's name is out, disallow partial strings so
        # a half-generated name is never sent.
        flags = Allow.ALL if name_sent else Allow.ALL & ~Allow.STR

        try:
            text = current_text.strip()
            try:
                parsed, _ = partial_json_loads(text, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # The text starts with '[', so the parse result IS the array of
            # tool calls; it must not be wrapped in another list.
            tool_call_arr = parsed if isinstance(parsed, list) else [parsed]

            # Case 1: No tools parsed yet
            if not tool_call_arr:
                return None

            if tool_id < 0:
                self._start_next_tool()
                tool_id = self.current_tool_id
            if tool_id >= len(tool_call_arr):
                return None

            current_tool_call = tool_call_arr[tool_id]
            if not isinstance(current_tool_call, dict):
                current_tool_call = {}
            has_next = len(tool_call_arr) > tool_id + 1

            delta = None
            if not name_sent:
                # Case 2: Send tool name if not sent yet
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=tool_id,
                            type="function",
                            id=f"call_{tool_id}_{random_uuid()}",
                            function=DeltaFunctionCall(
                                name=function_name
                            ).model_dump(exclude_none=True)
                        )
                    ])
                    self.current_tools_sent[tool_id] = True
                elif has_next:
                    # A later element exists, so this one will never gain a
                    # name; skip it rather than stall the stream.
                    self._start_next_tool()
            elif has_next:
                # Case 3: A later tool has started, so this one is finished.
                # Flush whatever part of its arguments is still unsent.
                cur_arguments = current_tool_call.get("arguments")
                if cur_arguments is not None:
                    sent = len(self.streamed_args[tool_id])
                    argument_diff = json.dumps(cur_arguments)[sent:]
                    if argument_diff:
                        delta = self._arguments_delta(tool_id, argument_diff)
                self._start_next_tool()
            else:
                # Case 4: Stream arguments
                cur_arguments = current_tool_call.get("arguments")
                argument_diff = None
                if cur_arguments is not None:
                    sent = len(self.streamed_args[tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    if is_complete_json(text):
                        # Whole array closed: everything left is final.
                        argument_diff = cur_args_json[sent:]
                    elif cur_arguments:
                        prev_arguments = self._previous_arguments(tool_id)
                        if prev_arguments:
                            prev_args_json = json.dumps(prev_arguments)
                            if cur_args_json != prev_args_json:
                                # Only the prefix shared with the previous
                                # parse is stable; the tail may still change.
                                prefix = find_common_prefix(
                                    prev_args_json, cur_args_json)
                                argument_diff = prefix[sent:]
                if argument_diff is not None:
                    delta = self._arguments_delta(tool_id, argument_diff)

            # Refresh the baseline on every path so the next diff is taken
            # against this chunk, not an older one.
            self.prev_tool_calls = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error in streaming tool calls")
            logger.debug("Skipping chunk due to streaming error")
            return None
198
+