import time
import pandas as pd
import polars as pl
import torch
import logging
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import paraphrase_mining
from typing import Optional
logger = logging.getLogger(__name__)
def mining(modelname: str, path: str, score: float) -> Optional[pl.DataFrame]:
    """
    Perform paraphrase mining on the input data.

    Args:
        modelname: Name of the SentenceTransformer model to use
        path: Path to the tab-separated input file with a "text" column
        score: Minimum similarity score threshold

    Returns:
        Optional[pl.DataFrame]: DataFrame with mining results, or None if an error occurs
    """
    try:
        st = time.time()

        # Read and validate the input data (tab-separated; malformed lines are skipped)
        original_df = pd.read_csv(path, on_bad_lines="skip", header=0, sep="\t")
        data = Dataset.from_pandas(original_df)
        if data.num_rows == 0:
            logger.error("No data found in input file")
            return None

        # Initialize the embedding model, on GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        model = SentenceTransformer(
            modelname,
            device=device,
            trust_remote_code=True,
        )
        # Perform paraphrase mining over the "text" column
        logger.info("Starting paraphrase mining...")
        paraphrases = paraphrase_mining(
            model,
            data["text"],
            corpus_chunk_size=len(data),
            show_progress_bar=True,
            batch_size=1024,
            max_pairs=len(data) ** 2,
        )
        # Convert results to Polars; each row holds (score, corpus index 1, corpus index 2)
        df = pl.from_pandas(pd.DataFrame(paraphrases, columns=["score", "sentence_1", "sentence_2"]))
        union_df = pl.from_pandas(original_df)
        original_columns = original_df.columns.tolist()

        # Carry over any additional columns, mapped through the corpus indices
        additional_cols = []
        for col in original_columns:
            if col != "text":
                col_series = union_df.get_column(col)
                additional_cols.extend([
                    col_series.gather(df["sentence_1"].cast(pl.Int32)).alias(f"{col}_1"),
                    col_series.gather(df["sentence_2"].cast(pl.Int32)).alias(f"{col}_2"),
                ])
        # Replace the index columns with the actual sentences, then filter and sort by score
        text_series = union_df.get_column("text")
        df = df.with_columns([
            pl.col("score").round(3).cast(pl.Float32),
            text_series.gather(df["sentence_1"].cast(pl.Int32)).alias("sentence_1"),
            text_series.gather(df["sentence_2"].cast(pl.Int32)).alias("sentence_2"),
            *additional_cols,
        ]).filter(pl.col("score") > score).sort(["score"], descending=True)
        elapsed_time = time.time() - st
        logger.info(f'Execution time: {time.strftime("%H:%M:%S", time.gmtime(elapsed_time))}')
        logger.info(f"Found {len(df)} paraphrases above score threshold {score}")
        return df
    except Exception as e:
        logger.error(f"Error in mining process: {e}")
        return None
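

# Minimal usage sketch (assumptions): the model name, input path, and threshold below
# are illustrative placeholders, not values from this repository. The input file is
# expected to be tab-separated with a "text" column, as mining() above requires.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    result = mining(
        modelname="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model name
        path="data.tsv",  # placeholder path to a TSV file with a "text" column
        score=0.9,  # keep only pairs scoring above 0.9
    )
    if result is not None:
        print(result.head())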