talmago committed
Commit 5d20fd4 · verified · 1 Parent(s): ed7845e

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +13 -286
README.md CHANGED
@@ -11,297 +11,24 @@ tags:
 
 ## Usage
 
-```python
-import json
-import os
-from typing import Iterable, Optional
-from huggingface_hub import snapshot_download
-import numpy as np
-import onnxruntime
-
-from tokenizers import Tokenizer
-from tokenizers.pre_tokenizers import Metaspace
-
-class CoreferenceResolver:
-    """
-    A lightweight coreference resolution model using ONNX runtime.
-
-    This class loads a pre-exported ONNX model and SentencePiece tokenizer
-    to perform span-based coreference resolution.
-    """
-
-    def __init__(self, model_dir: str, max_spans: int = 512, max_span_width: int = 5):
-        """
-        Initialize the coreference resolver.
-
-        Args:
-            model_dir (str): Path to the directory containing the ONNX model and tokenizer.
-            max_spans (int): Fixed number of candidate spans the span tensor is padded or truncated to.
-            max_span_width (int): Maximum width (in tokens) of candidate spans to consider.
-        """
-        model_path = os.path.join(model_dir, "model.onnx")
-        tokenizer_path = os.path.join(model_dir, "tokenizer.json")
-
-        self.session = onnxruntime.InferenceSession(model_path)
-        self.tokenizer = Tokenizer.from_file(tokenizer_path)
-        self.tokenizer.pre_tokenizer = Metaspace(replacement="▁", prepend_scheme="always", split=True)
-        self.max_spans = max_spans
-        self.max_span_width = max_span_width
-
-    @classmethod
-    def from_pretrained(cls, model_name_or_path: str):
-        """
-        Instantiate a CoreferenceResolver from a model name or local path.
-
-        Args:
-            model_name_or_path (str): Local path or HuggingFace model repo ID.
-
-        Returns:
-            CoreferenceResolver: An initialized instance.
-        """
-        # Check if local path exists
-        if os.path.isdir(model_name_or_path):
-            model_dir = model_name_or_path
-        else:
-            model_dir = snapshot_download(repo_id=model_name_or_path)
-
-        # Load configuration
-        config_path = os.path.join(model_dir, "config.json")
-        if not os.path.exists(config_path):
-            raise FileNotFoundError(f"Could not find model configuration in {model_dir}/config.json")
-
-        with open(config_path, "r") as f:
-            config = json.load(f)
-
-        max_span_width = config.get("max_span_width", 5)
-        max_spans = config.get("max_spans", 5)
-
-        return cls(
-            model_dir=model_dir,
-            max_spans=max_spans,
-            max_span_width=max_span_width
-        )
-
-    def __call__(
-        self,
-        text: list[str] | list[list[str]],
-        spans: Optional[list[tuple[int, int]]] = None,
-    ):
-        """
-        Resolve coreference clusters for the input text.
-
-        Args:
-            text (List[str] | List[List[str]]): A list of tokens or a list of lists of tokens (sentences).
-            spans (Optional[List[Tuple[int, int]]]): Predefined spans to consider. If not provided, spans will be enumerated.
-
-        Returns:
-            dict: {
-                "clusters": List of resolved clusters (as span index tuples),
-                "top_spans": Top spans considered by the model,
-                "antecedent_indices": Index of each span's antecedent candidates,
-                "predicted_antecedents": Index of selected antecedents (or -1 if none)
-            }
-        """
-        inputs = self._prepare_inputs(text, spans=spans)
-
-        top_spans, antecedent_indices, predicted_antecedents = self.session.run(
-            None, inputs
-        )
-
-        clusters = self._agg_clusters(
-            top_spans, antecedent_indices, predicted_antecedents
-        )
-
-        return {
-            "clusters": clusters,
-            "top_spans": top_spans,
-            "antecedent_indices": antecedent_indices,
-            "predicted_antecedents": predicted_antecedents,
-        }
-
-    def _prepare_inputs(
-        self,
-        text: list[str] | list[list[str]],
-        spans: Optional[list[tuple[int, int]]] = None,
-    ):
-        """
-        Tokenize and format input text and spans into ONNX input format.
-
-        Args:
-            text (List[str] | List[List[str]]): Input text tokens.
-            spans (Optional[List[Tuple[int, int]]]): Optional list of spans.
-
-        Returns:
-            Dict[str, np.ndarray]: Dictionary of input tensors for the ONNX model.
-        """
-        if isinstance(text, Iterable) and isinstance(text[0], list):
-            flat_text = [token for sent in text for token in sent]
-        else:
-            flat_text = text
-
-        encoding = self.tokenizer.encode(flat_text, is_pretokenized=True)
-        input_ids = np.array([encoding.ids], dtype=np.int64)
-        word_ids = encoding.word_ids
-
-        # Map original words to subword token indices
-        orig_to_subword = []
-        current_word = None
-        start = None
-        for i, word_id in enumerate(word_ids):
-            if word_id != current_word:
-                if current_word is not None:
-                    orig_to_subword.append([start, i - 1])
-                if word_id is not None:
-                    start = i
-                current_word = word_id
-        if current_word is not None and start is not None:
-            orig_to_subword.append([start, len(word_ids) - 1])
-
-        seq_len = len(orig_to_subword)
-        offsets = np.array([orig_to_subword], dtype=np.int64)  # shape: (1, seq_len, 2)
-        mask = np.ones((1, seq_len), dtype=bool)
-        segment_concat_mask = np.ones_like(input_ids, dtype=bool)
-
-        spans = spans or self._enumerate_spans(seq_len)
-        span_arr = np.array([spans], dtype=np.int64)
-        spans_tensor = self.pad_or_truncate_spans(span_arr, max_spans=self.max_spans)
-
-        return {
-            "token_ids": input_ids,
-            "mask": mask,
-            "segment_concat_mask": segment_concat_mask,
-            "offsets": offsets,
-            "spans": spans_tensor,
-        }
-
-    def pad_or_truncate_spans(self, spans, max_spans=128):
-        """
-        Pad or truncate the span tensor to match a fixed max length.
-
-        Args:
-            spans (np.ndarray): Span tensor of shape (1, num_spans, 2).
-            max_spans (int): Desired fixed length.
-
-        Returns:
-            np.ndarray: Padded/truncated span tensor of shape (1, max_spans, 2).
-        """
-        batch_size, num_spans, _ = spans.shape
-        if num_spans > max_spans:
-            spans = spans[:, :max_spans, :]
-        elif num_spans < max_spans:
-            pad_len = max_spans - num_spans
-            pad = np.full((batch_size, pad_len, 2), fill_value=-1, dtype=spans.dtype)
-            spans = np.concatenate([spans, pad], axis=1)
-        return spans
-
-    def _enumerate_spans(self, seq_len: int) -> list[list[int]]:
-        """
-        Generate all possible spans up to max_span_width.
-
-        Args:
-            seq_len (int): Number of tokens in the sequence.
-
-        Returns:
-            List[Tuple[int, int]]: Candidate spans.
-        """
-        spans = []
-        for start in range(seq_len):
-            for end in range(start, min(start + self.max_span_width, seq_len)):
-                spans.append((start, end))
-        return spans
-
-    def _agg_clusters(self, top_spans, antecedent_indices, predicted_antecedents):
-        """
-        Construct coreference clusters based on top spans and antecedents.
-
-        Args:
-            top_spans (np.ndarray): Top spans selected by the model.
-            antecedent_indices (np.ndarray): Candidate antecedents for each span.
-            predicted_antecedents (np.ndarray): Final predicted antecedents.
-
-        Returns:
-            List[List[Tuple[int, int]]]: Clustered span indices per document.
-        """
-        batch_clusters = []
-        batch_size = top_spans.shape[0]
-
-        for b in range(batch_size):
-            spans_to_cluster_id = {}
-            clusters = []
-
-            for i, predicted_antecedent in enumerate(predicted_antecedents[b]):
-                if predicted_antecedent < 0:
-                    continue  # No antecedent: skip
-
-                predicted_index = int(antecedent_indices[b, i, predicted_antecedent])
-                antecedent_span = tuple(top_spans[b, predicted_index].tolist())
-
-                if antecedent_span in spans_to_cluster_id:
-                    cluster_id = spans_to_cluster_id[antecedent_span]
-                else:
-                    cluster_id = len(clusters)
-                    clusters.append([antecedent_span])
-                    spans_to_cluster_id[antecedent_span] = cluster_id
-
-                current_span = tuple(top_spans[b, i].tolist())
-                clusters[cluster_id].append(current_span)
-                spans_to_cluster_id[current_span] = cluster_id
-
-            # Sort spans in each cluster by start index
-            for cluster in clusters:
-                cluster.sort(key=lambda span: span[0])
-
-            # Remove singleton clusters
-            clusters = [c for c in clusters if len(c) > 1]
-            batch_clusters.append(clusters)
-
-        return batch_clusters
-
-
-def decode_clusters(sentences, clusters):
-    flat_tokens = [token for sent in sentences for token in sent]
-    decoded_clusters = []
-    max_index = len(flat_tokens) - 1
-    for cluster in clusters:
-        decoded_spans = []
-        for span in cluster:
-            start, end = span
-            if not (0 <= start <= end <= max_index):
-                decoded_spans.append("<invalid span>")
-                continue
-            decoded_spans.append(" ".join(flat_tokens[start:end+1]))
-        decoded_clusters.append(decoded_spans)
-    return decoded_clusters
-
-
-if __name__ == "__main__":
-    resolver = CoreferenceResolver("models/minillm")
-
-    sentences = [
-        [
-            "Barack",
-            "Obama",
-            "was",
-            "the",
-            "44th",
-            "President",
-            "of",
-            "the",
-            "United",
-            "States",
-            ".",
-            "He",
-            "was",
-            "born",
-            "in",
-            "Hawaii",
-            ".",
-        ]
-    ]
-
-    outputs = resolver(sentences)
-
-    print("Clusters:", outputs["clusters"])
-    print("Decoded clusters:", decode_clusters(sentences, outputs["clusters"][0]))
+```sh
+$ pip install coref-onnx
+```
+
+```python
+from coref_onnx import CoreferenceResolver, decode_clusters
+
+resolver = CoreferenceResolver.from_pretrained("talmago/allennlp-coref-onnx-mMiniLMv2-L12-H384-distilled-from-XLMR-Large")
+
+sentences = [
+    ["Barack", "Obama", "was", "the", "44th", "President", "of", "the", "United", "States", "."],
+    ["He", "was", "born", "in", "Hawaii", "."]
+]
+
+pred = resolver(sentences)
+
+print("Clusters:", pred["clusters"][0])
+print("Decoded clusters:", decode_clusters(sentences, pred["clusters"][0]))
 ```
 
 Output is:
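
In both the old inline example and the new packaged API, `clusters` is a per-document list of clusters, where each cluster is a list of `(start, end)` token spans (end-inclusive) indexing into the flattened token sequence, and `decode_clusters` maps those spans back to surface text. Below is a minimal, self-contained sketch of that decoding convention; the helper is copied from the removed inline example (the version shipped in `coref-onnx` is assumed to behave the same), and `example_cluster` is a hypothetical cluster for illustration, not actual model output:

```python
# decode_clusters as it appeared in the removed inline example;
# the coref-onnx packaged version is assumed to be equivalent.
def decode_clusters(sentences, clusters):
    flat_tokens = [token for sent in sentences for token in sent]
    decoded_clusters = []
    max_index = len(flat_tokens) - 1
    for cluster in clusters:
        decoded_spans = []
        for span in cluster:
            start, end = span
            if not (0 <= start <= end <= max_index):
                # Guard against out-of-range spans (e.g. -1 padding)
                decoded_spans.append("<invalid span>")
                continue
            decoded_spans.append(" ".join(flat_tokens[start:end + 1]))
        decoded_clusters.append(decoded_spans)
    return decoded_clusters


sentences = [
    ["Barack", "Obama", "was", "the", "44th", "President", "of", "the", "United", "States", "."],
    ["He", "was", "born", "in", "Hawaii", "."],
]

# Hypothetical cluster: (start, end) indices into the flattened
# token list, end-inclusive. (0, 1) -> "Barack Obama", (11, 11) -> "He".
example_cluster = [(0, 1), (11, 11)]
print(decode_clusters(sentences, [example_cluster]))
# [['Barack Obama', 'He']]
```

The `<invalid span>` guard matters because `pad_or_truncate_spans` pads the span tensor with `-1`, so padded spans can surface in raw model outputs.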