Upload README.md with huggingface_hub
Browse files
README.md
CHANGED
@@ -11,297 +11,24 @@ tags:
|
|
11 |
|
12 |
## Usage
|
13 |
|
14 |
-
```
|
15 |
-
|
16 |
-
|
17 |
-
from typing import Iterable, Optional
|
18 |
-
from huggingface_hub import snapshot_download
|
19 |
-
import numpy as np
|
20 |
-
import onnxruntime
|
21 |
-
|
22 |
-
from tokenizers import Tokenizer
|
23 |
-
from tokenizers.pre_tokenizers import Metaspace
|
24 |
-
|
25 |
-
class CoreferenceResolver:
|
26 |
-
"""
|
27 |
-
A lightweight coreference resolution model using ONNX runtime.
|
28 |
-
|
29 |
-
This class loads a pre-exported ONNX model and SentencePiece tokenizer
|
30 |
-
to perform span-based coreference resolution.
|
31 |
-
"""
|
32 |
-
|
33 |
-
def __init__(self, model_dir: str, max_spans: int=512, max_span_width: int = 5):
|
34 |
-
"""
|
35 |
-
Initialize the coreference resolver.
|
36 |
-
|
37 |
-
Args:
|
38 |
-
model_dir (str): Path to the directory containing the ONNX model and tokenizer.
|
39 |
-
max_span_width (int): Maximum width (in tokens) of candidate spans to consider.
|
40 |
-
"""
|
41 |
-
model_path = os.path.join(model_dir, "model.onnx")
|
42 |
-
tokenizer_path = os.path.join(model_dir, "tokenizer.json")
|
43 |
-
|
44 |
-
self.session = onnxruntime.InferenceSession(model_path)
|
45 |
-
self.tokenizer = Tokenizer.from_file(tokenizer_path)
|
46 |
-
self.tokenizer.pre_tokenizer = Metaspace(replacement="▁", prepend_scheme="always", split=True)
|
47 |
-
self.max_spans = max_spans
|
48 |
-
self.max_span_width = max_span_width
|
49 |
-
|
50 |
-
@classmethod
|
51 |
-
def from_pretrained(cls, model_name_or_path: str):
|
52 |
-
"""
|
53 |
-
Instantiate a CoreferenceResolver from a model name or local path.
|
54 |
-
|
55 |
-
Args:
|
56 |
-
model_name_or_path (str): Local path or HuggingFace model repo ID.
|
57 |
-
|
58 |
-
Returns:
|
59 |
-
CoreferenceResolver: An initialized instance.
|
60 |
-
"""
|
61 |
-
# Check if local path exists
|
62 |
-
if os.path.isdir(model_name_or_path):
|
63 |
-
model_dir = model_name_or_path
|
64 |
-
else:
|
65 |
-
model_dir = snapshot_download(repo_id=model_name_or_path)
|
66 |
-
|
67 |
-
# Load configuration
|
68 |
-
config_path = os.path.join(model_dir, "config.json")
|
69 |
-
if not os.path.exists(config_path):
|
70 |
-
raise FileNotFoundError(f"Could not find model configuration in {model_dir}/config.json")
|
71 |
-
|
72 |
-
with open(config_path, "r") as f:
|
73 |
-
config = json.load(f)
|
74 |
-
|
75 |
-
max_span_width = config.get("max_span_width", 5)
|
76 |
-
max_spans = config.get("max_spans", 5)
|
77 |
-
|
78 |
-
return cls(
|
79 |
-
model_dir=model_dir,
|
80 |
-
max_spans=max_spans,
|
81 |
-
max_span_width=max_span_width
|
82 |
-
)
|
83 |
-
|
84 |
-
def __call__(
|
85 |
-
self,
|
86 |
-
text: list[str] | list[list[str]],
|
87 |
-
spans: Optional[list[tuple[int, int]]] = None,
|
88 |
-
):
|
89 |
-
"""
|
90 |
-
Resolve coreference clusters for the input text.
|
91 |
-
|
92 |
-
Args:
|
93 |
-
text (List[str] | List[List[str]]): A list of tokens or list of list of tokens (sentences).
|
94 |
-
spans (Optional[List[Tuple[int, int]]]): Predefined spans to consider. If not provided, spans will be enumerated.
|
95 |
-
|
96 |
-
Returns:
|
97 |
-
dict: {
|
98 |
-
"clusters": List of resolved clusters (as span index tuples),
|
99 |
-
"top_spans": Top spans considered by the model,
|
100 |
-
"antecedent_indices": Index of each span's antecedent candidates,
|
101 |
-
"predicted_antecedents": Index of selected antecedents (or -1 if none)
|
102 |
-
}
|
103 |
-
"""
|
104 |
-
inputs = self._prepare_inputs(text, spans=spans)
|
105 |
-
|
106 |
-
top_spans, antecedent_indices, predicted_antecedents = self.session.run(
|
107 |
-
None, inputs
|
108 |
-
)
|
109 |
-
|
110 |
-
clusters = self._agg_clusters(
|
111 |
-
top_spans, antecedent_indices, predicted_antecedents
|
112 |
-
)
|
113 |
-
|
114 |
-
return {
|
115 |
-
"clusters": clusters,
|
116 |
-
"top_spans": top_spans,
|
117 |
-
"antecedent_indices": antecedent_indices,
|
118 |
-
"predicted_antecedents": predicted_antecedents,
|
119 |
-
}
|
120 |
-
|
121 |
-
def _prepare_inputs(
|
122 |
-
self,
|
123 |
-
text: list[str] | list[list[str]],
|
124 |
-
spans: Optional[list[tuple[int, int]]] = None,
|
125 |
-
):
|
126 |
-
"""
|
127 |
-
Tokenize and format input text and spans into ONNX input format.
|
128 |
-
|
129 |
-
Args:
|
130 |
-
text (List[str] | List[List[str]]): Input text tokens.
|
131 |
-
spans (Optional[List[Tuple[int, int]]]): Optional list of spans.
|
132 |
-
|
133 |
-
Returns:
|
134 |
-
Dict[str, np.ndarray]: Dictionary of input tensors for the ONNX model.
|
135 |
-
"""
|
136 |
-
if isinstance(text, Iterable) and isinstance(text[0], list):
|
137 |
-
flat_text = [token for sent in text for token in sent]
|
138 |
-
else:
|
139 |
-
flat_text = text
|
140 |
-
|
141 |
-
encoding = self.tokenizer.encode(flat_text, is_pretokenized=True)
|
142 |
-
input_ids = np.array([encoding.ids], dtype=np.int64)
|
143 |
-
word_ids = encoding.word_ids
|
144 |
-
|
145 |
-
# Map original words to subword token indices
|
146 |
-
orig_to_subword = []
|
147 |
-
current_word = None
|
148 |
-
start = None
|
149 |
-
for i, word_id in enumerate(word_ids):
|
150 |
-
if word_id != current_word:
|
151 |
-
if current_word is not None:
|
152 |
-
orig_to_subword.append([start, i - 1])
|
153 |
-
if word_id is not None:
|
154 |
-
start = i
|
155 |
-
current_word = word_id
|
156 |
-
if current_word is not None and start is not None:
|
157 |
-
orig_to_subword.append([start, len(word_ids) - 1])
|
158 |
-
|
159 |
-
seq_len = len(orig_to_subword)
|
160 |
-
offsets = np.array([orig_to_subword], dtype=np.int64) # shape: (1, seq_len, 2)
|
161 |
-
mask = np.ones((1, seq_len), dtype=bool)
|
162 |
-
segment_concat_mask = np.ones_like(input_ids, dtype=bool)
|
163 |
-
|
164 |
-
spans = spans or self._enumerate_spans(seq_len)
|
165 |
-
span_arr = np.array([spans], dtype=np.int64)
|
166 |
-
spans_tensor = self.pad_or_truncate_spans(span_arr, max_spans=self.max_spans)
|
167 |
-
|
168 |
-
return {
|
169 |
-
"token_ids": input_ids,
|
170 |
-
"mask": mask,
|
171 |
-
"segment_concat_mask": segment_concat_mask,
|
172 |
-
"offsets": offsets,
|
173 |
-
"spans": spans_tensor,
|
174 |
-
}
|
175 |
-
|
176 |
-
def pad_or_truncate_spans(self, spans, max_spans=128):
|
177 |
-
"""
|
178 |
-
Pad or truncate the span tensor to match a fixed max length.
|
179 |
-
|
180 |
-
Args:
|
181 |
-
spans (np.ndarray): Span tensor of shape (1, num_spans, 2).
|
182 |
-
max_spans (int): Desired fixed length.
|
183 |
-
|
184 |
-
Returns:
|
185 |
-
np.ndarray: Padded/truncated span tensor of shape (1, max_spans, 2).
|
186 |
-
"""
|
187 |
-
batch_size, num_spans, _ = spans.shape
|
188 |
-
if num_spans > max_spans:
|
189 |
-
spans = spans[:, :max_spans, :]
|
190 |
-
elif num_spans < max_spans:
|
191 |
-
pad_len = max_spans - num_spans
|
192 |
-
pad = np.full((batch_size, pad_len, 2), fill_value=-1, dtype=spans.dtype)
|
193 |
-
spans = np.concatenate([spans, pad], axis=1)
|
194 |
-
return spans
|
195 |
-
|
196 |
-
def _enumerate_spans(self, seq_len: int) -> list[list[int]]:
|
197 |
-
"""
|
198 |
-
Generate all possible spans up to max_span_width.
|
199 |
-
|
200 |
-
Args:
|
201 |
-
seq_len (int): Number of tokens in the sequence.
|
202 |
-
|
203 |
-
Returns:
|
204 |
-
List[Tuple[int, int]]: Candidate spans.
|
205 |
-
"""
|
206 |
-
spans = []
|
207 |
-
for start in range(seq_len):
|
208 |
-
for end in range(start, min(start + self.max_span_width, seq_len)):
|
209 |
-
spans.append((start, end))
|
210 |
-
return spans
|
211 |
-
|
212 |
-
def _agg_clusters(self, top_spans, antecedent_indices, predicted_antecedents):
|
213 |
-
"""
|
214 |
-
Construct coreference clusters based on top spans and antecedents.
|
215 |
-
|
216 |
-
Args:
|
217 |
-
top_spans (np.ndarray): Top spans selected by the model.
|
218 |
-
antecedent_indices (np.ndarray): Candidate antecedents for each span.
|
219 |
-
predicted_antecedents (np.ndarray): Final predicted antecedents.
|
220 |
-
|
221 |
-
Returns:
|
222 |
-
List[List[Tuple[int, int]]]: Clustered span indices per document.
|
223 |
-
"""
|
224 |
-
batch_clusters = []
|
225 |
-
batch_size = top_spans.shape[0]
|
226 |
-
|
227 |
-
for b in range(batch_size):
|
228 |
-
spans_to_cluster_id = {}
|
229 |
-
clusters = []
|
230 |
-
|
231 |
-
for i, predicted_antecedent in enumerate(predicted_antecedents[b]):
|
232 |
-
if predicted_antecedent < 0:
|
233 |
-
continue # No antecedent: skip
|
234 |
-
|
235 |
-
predicted_index = int(antecedent_indices[b, i, predicted_antecedent])
|
236 |
-
antecedent_span = tuple(top_spans[b, predicted_index].tolist())
|
237 |
-
|
238 |
-
if antecedent_span in spans_to_cluster_id:
|
239 |
-
cluster_id = spans_to_cluster_id[antecedent_span]
|
240 |
-
else:
|
241 |
-
cluster_id = len(clusters)
|
242 |
-
clusters.append([antecedent_span])
|
243 |
-
spans_to_cluster_id[antecedent_span] = cluster_id
|
244 |
-
|
245 |
-
current_span = tuple(top_spans[b, i].tolist())
|
246 |
-
clusters[cluster_id].append(current_span)
|
247 |
-
spans_to_cluster_id[current_span] = cluster_id
|
248 |
-
|
249 |
-
# Sort spans in each cluster by start index
|
250 |
-
for cluster in clusters:
|
251 |
-
cluster.sort(key=lambda span: span[0])
|
252 |
-
|
253 |
-
# Remove singleton clusters
|
254 |
-
clusters = [c for c in clusters if len(c) > 1]
|
255 |
-
batch_clusters.append(clusters)
|
256 |
-
|
257 |
-
return batch_clusters
|
258 |
-
|
259 |
-
|
260 |
-
def decode_clusters(sentences, clusters):
|
261 |
-
flat_tokens = [token for sent in sentences for token in sent]
|
262 |
-
decoded_clusters = []
|
263 |
-
max_index = len(flat_tokens) - 1
|
264 |
-
for cluster in clusters:
|
265 |
-
decoded_spans = []
|
266 |
-
for span in cluster:
|
267 |
-
start, end = span
|
268 |
-
if not (0 <= start <= end <= max_index):
|
269 |
-
decoded_spans.append("<invalid span>")
|
270 |
-
continue
|
271 |
-
decoded_spans.append(" ".join(flat_tokens[start:end+1]))
|
272 |
-
decoded_clusters.append(decoded_spans)
|
273 |
-
return decoded_clusters
|
274 |
|
|
|
|
|
275 |
|
276 |
-
|
277 |
-
resolver = CoreferenceResolver("models/minillm")
|
278 |
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
"was",
|
284 |
-
"the",
|
285 |
-
"44th",
|
286 |
-
"President",
|
287 |
-
"of",
|
288 |
-
"the",
|
289 |
-
"United",
|
290 |
-
"States",
|
291 |
-
".",
|
292 |
-
"He",
|
293 |
-
"was",
|
294 |
-
"born",
|
295 |
-
"in",
|
296 |
-
"Hawaii",
|
297 |
-
".",
|
298 |
-
]
|
299 |
-
]
|
300 |
|
301 |
-
|
302 |
|
303 |
-
|
304 |
-
|
305 |
```
|
306 |
|
307 |
Output is:
|
|
|
11 |
|
12 |
## Usage
|
13 |
|
14 |
+
```sh
|
15 |
+
$ pip install coref-onnx
|
16 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
```python
|
19 |
+
from coref_onnx import CoreferenceResolver, decode_clusters
|
20 |
|
21 |
+
resolver = CoreferenceResolver.from_pretrained("talmago/allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large")
|
|
|
22 |
|
23 |
+
sentences = [
|
24 |
+
["Barack", "Obama", "was", "the", "44th", "President", "of", "the", "United", "States", "."],
|
25 |
+
["He", "was", "born", "in", "Hawaii", "."]
|
26 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
+
pred = resolver(sentences)
|
29 |
|
30 |
+
print("Clusters:", pred["clusters"][0])
|
31 |
+
print("Decoded clusters:", decode_clusters(sentences, pred["clusters"][0]))
|
32 |
```
|
33 |
|
34 |
Output is:
|