---
library_name: transformers
license: apache-2.0
base_model: answerdotai/ModernBERT-base
model-index:
- name: x2bee/KoModernBERT-base-mlm
  results: []
language:
- ko
---

# KoModernBERT-base-v02

This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base).

* Flash-Attention 2
* StableAdamW
* Unpadding & Sequence Packing
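
To make the "unpadding" item above concrete, here is a minimal pure-Python sketch of the general idea: padding tokens are dropped, the remaining tokens are concatenated into one flat sequence, and cumulative lengths record where each original sequence starts and ends. The names `unpad`, `PAD`, and `cu_seqlens` and the toy token ids are illustrative assumptions, not this model's actual implementation.

```python
from itertools import accumulate

PAD = 0  # hypothetical padding token id, chosen for this sketch only


def unpad(batch, pad_id=PAD):
    """Flatten a padded batch into one token list plus cumulative boundaries."""
    # Drop padding from every row (assumes pad_id never occurs as a real token).
    sequences = [[tok for tok in row if tok != pad_id] for row in batch]
    # Concatenate all real tokens into a single flat sequence.
    flat = [tok for seq in sequences for tok in seq]
    # cu_seqlens[i]:cu_seqlens[i + 1] is the slice of `flat` holding sequence i.
    cu_seqlens = [0, *accumulate(len(seq) for seq in sequences)]
    return flat, cu_seqlens


padded = [
    [5, 6, 7, 0, 0],
    [8, 9, 0, 0, 0],
    [1, 2, 3, 4, 0],
]
flat, cu_seqlens = unpad(padded)
print(flat)        # [5, 6, 7, 8, 9, 1, 2, 3, 4]
print(cu_seqlens)  # [0, 3, 5, 9]
```

In this toy batch, 9 of the 15 slots hold real tokens, so the flat form avoids computing over the 6 padding positions.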

## Example Use

```python
import random

import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Log in to the Hugging Face Hub with a token stored in a local file.
with open('./api_key/HGF_TOKEN.txt', 'r') as hgf:
    login(token=hgf.read().strip())

model_id = "x2bee/KoModernBERT-base-mlm-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id).to("cuda")


def modern_bert_convert_with_multiple_masks(text: str, top_k: int = 1, select_method: str = "Logit") -> str:
    """Fill every [MASK] in `text`, one mask at a time from left to right."""
    if "[MASK]" not in text:
        raise ValueError("MLM Model should include '[MASK]' in the sentence")

    while "[MASK]" in text:
        inputs = tokenizer(text, return_tensors="pt").to("cuda")
        with torch.no_grad():  # inference only, no gradients needed
            outputs = model(**inputs)

        # Locate the first remaining mask token in the input.
        input_ids = inputs["input_ids"][0].tolist()
        mask_indices = [i for i, token_id in enumerate(input_ids) if token_id == tokenizer.mask_token_id]
        current_mask_index = mask_indices[0]

        # Top-k candidates for that position, sorted by descending logit.
        logits = outputs.logits[0, current_mask_index]
        top_k_logits, top_k_indices = logits.topk(top_k)
        top_k_tokens = top_k_indices.tolist()

        if select_method == "Logit":
            # Sample among the top-k candidates, weighted by a softmax over their logits.
            probabilities = torch.softmax(top_k_logits, dim=0).tolist()
            predicted_token_id = random.choices(top_k_tokens, weights=probabilities, k=1)[0]
        elif select_method == "Random":
            # Sample uniformly among the top-k candidates.
            predicted_token_id = random.choice(top_k_tokens)
        elif select_method == "Best":
            # Always take the highest-scoring candidate.
            predicted_token_id = top_k_tokens[0]
        else:
            raise ValueError("select_method should be one of ['Logit', 'Random', 'Best']")

        # Replace only the first [MASK], then re-run the model on the updated text.
        predicted_token = tokenizer.decode([predicted_token_id]).strip()
        text = text.replace("[MASK]", predicted_token, 1)
        print(f"Predicted: {predicted_token} | Current text: {text}")

    return text
```
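
The three `select_method` options can be tried without loading the model. The sketch below reproduces the selection step in plain Python on a made-up logit vector. `pick_token`, `softmax`, and `toy_logits` are hypothetical names introduced for illustration, and the numbers are not model outputs.

```python
import math
import random


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def pick_token(logits, top_k=3, select_method="Logit", rng=random):
    """Choose one index from the top_k highest logits."""
    # Candidate indices, sorted by descending logit.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    if select_method == "Best":
        return ranked[0]
    if select_method == "Random":
        return rng.choice(ranked)
    if select_method == "Logit":
        # Renormalize over the candidates only, then sample by probability.
        weights = softmax([logits[i] for i in ranked])
        return rng.choices(ranked, weights=weights, k=1)[0]
    raise ValueError("select_method should be one of ['Logit', 'Random', 'Best']")


toy_logits = [0.1, 2.5, -1.0, 1.7, 0.3]  # hypothetical scores for a 5-token vocabulary
rng = random.Random(0)  # seeded generator so the sketch is repeatable

best = pick_token(toy_logits, top_k=3, select_method="Best")
sampled = pick_token(toy_logits, top_k=3, select_method="Logit", rng=rng)
uniform = pick_token(toy_logits, top_k=3, select_method="Random", rng=rng)
print(best, sampled, uniform)
```

With `top_k=1` there is only one candidate, so all three methods return the same token. The methods only differ once `top_k` is greater than 1.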

```
# `mbm` is assumed to be the module in which the function above is saved.
# Input: a Korean news passage about the Jeju Air accident at Muan, with 13 masked tokens.
text = "30일 전남 무안국제[MASK] 활주로에 전날 발생한 제주항공 [MASK] 당시 기체가 [MASK]착륙하면서 강한 마찰로 생긴 흔적이 남아 있다. 이 참사로 [MASK]과 승무원 181명 중 179명이 숨지고 [MASK]는 형체를 알아볼 수 없이 [MASK]됐다. [MASK] 규모와 [MASK] 원인 등에 대해 다양한 [MASK]이 제기되고 있는 가운데 [MASK]에 설치된 [MASK](착륙 유도 안전시설)가 [MASK]를 키웠다는 [MASK]이 나오고 있다."
result = mbm.modern_bert_convert_with_multiple_masks(text, top_k=1)

'30일 전남 무안국제터미널 활주로에 전날 발생한 제주항공 사고 당시 기체가 무단착륙하면서 강한 마찰로 생긴 흔적이 남아 있다. 이 참사로 승객과 승무원 181명 중 179명이 숨지고 일부는 형체를 알아볼 수 없이 실종됐다. 사고 규모와 사고 원인 등에 대해 다양한 의혹이 제기되고 있는 가운데 기내에 설치된 ESC(착륙 유도 안전시설)가 사고를 키웠다는 주장이 나오고 있다.'
```

```
text = "중국의 수도는 [MASK]이다"  # "The capital of China is [MASK]."
result = mbm.modern_bert_convert_with_multiple_masks(text, top_k=1)
'중국의 수도는 베이징이다'  # Beijing

text = "일본의 수도는 [MASK]이다"  # "The capital of Japan is [MASK]."
result = mbm.modern_bert_convert_with_multiple_masks(text, top_k=1)
'일본의 수도는 도쿄이다'  # Tokyo

text = "대한민국의 가장 큰 도시는 [MASK]이다"  # "The largest city in South Korea is [MASK]."
result = mbm.modern_bert_convert_with_multiple_masks(text, top_k=1)
'대한민국의 가장 큰 도시는 서울이다'  # Seoul
```

### Framework versions

- Transformers 4.48.0
- PyTorch 2.5.1+cu124
- Datasets 3.2.0
- Tokenizers 0.21.0