|
import time |
|
import pandas as pd |
|
import polars as pl |
|
import torch |
|
import logging |
|
from datasets import Dataset |
|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers.util import paraphrase_mining |
|
from typing import Optional |
|
|
|
# Module-level logger, named after this module so callers can tune it by name.
logger = logging.getLogger(name=__name__)
|
|
|
def _pairs_to_frame(
    corpus: pd.DataFrame, paraphrases: list, threshold: float
) -> pl.DataFrame:
    """Expand ``[score, i, j]`` triples into a polars frame of row pairs.

    Args:
        corpus: Input rows with a ``text`` column. Position ``k`` (as used by
            ``iloc``) must correspond to index ``k`` in ``paraphrases``.
        paraphrases: ``[score, i, j]`` lists as returned by ``paraphrase_mining``.
        threshold: Pairs whose raw score is not strictly greater are dropped.

    Returns:
        Frame with ``score`` (Float32, rounded to 3 decimals), ``sentence_1``,
        ``sentence_2`` and, for every other input column ``c``, ``c_1`` and
        ``c_2``, sorted by descending score. Empty (but with the same columns)
        when no pair qualifies.
    """
    # Explicit names and dtypes so an empty result still has a usable schema.
    pairs = pd.DataFrame(paraphrases, columns=["score", "idx_1", "idx_2"]).astype(
        {"score": "float64", "idx_1": "int64", "idx_2": "int64"}
    )
    # Filter on the raw score: rounding (and the Float32 cast) first could pull a
    # pair that is just above the threshold down onto it and drop it.
    pairs = pairs[pairs["score"] > threshold].sort_values(
        "score", ascending=False, kind="stable"
    )

    # Positional gather of both sides of every pair; reset the index so all
    # columns assembled below line up row-by-row.
    left = corpus.iloc[pairs["idx_1"].to_numpy()].reset_index(drop=True)
    right = corpus.iloc[pairs["idx_2"].to_numpy()].reset_index(drop=True)

    columns = {
        "score": pairs["score"].round(3).astype("float32").to_numpy(),
        "sentence_1": left["text"],
        "sentence_2": right["text"],
    }
    for col in corpus.columns:
        if col != "text":
            columns[f"{col}_1"] = left[col]
            columns[f"{col}_2"] = right[col]
    return pl.from_pandas(pd.DataFrame(columns))


def mining(modelname: str, path: str, score: float) -> Optional[pl.DataFrame]:
    """
    Perform paraphrase mining on the input data.

    Reads a tab-separated file (header row required, malformed lines skipped)
    that must contain a ``text`` column. It embeds every text with
    ``modelname`` and returns the pairs whose similarity score, as computed by
    sentence-transformers' ``paraphrase_mining``, is above ``score``.

    Args:
        modelname: Name of the SentenceTransformer model to load.
        path: Path to the input TSV file.
        score: Similarity threshold; a pair is kept only if its score is
            strictly greater than this value.

    Returns:
        Optional[pl.DataFrame]: One row per pair with ``score`` (Float32,
        rounded to 3 decimals), ``sentence_1``, ``sentence_2`` and ``<col>_1`` /
        ``<col>_2`` for every other input column, sorted by descending score.
        The frame may be empty. None if the input has no usable rows or any
        error occurs (the error is logged with its traceback).
    """
    try:
        start = time.perf_counter()

        # Read the file once; the original read it twice.
        corpus = pd.read_csv(path, on_bad_lines="skip", header=0, sep="\t")
        if "text" not in corpus.columns:
            logger.error("Input file has no 'text' column")
            return None

        # Rows without text cannot be embedded. Drop them and renumber so the
        # positional indices returned by paraphrase_mining map back via iloc.
        corpus = corpus.dropna(subset=["text"]).reset_index(drop=True)
        if len(corpus) == 0:
            logger.error("No data found in input file")
            return None
        corpus["text"] = corpus["text"].astype(str)
        texts = corpus["text"].tolist()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", device)

        # SECURITY NOTE(review): trust_remote_code=True executes Python code
        # shipped with the model repository. Only pass trusted model names.
        model = SentenceTransformer(
            modelname,
            device=device,
            trust_remote_code=True,
        )

        logger.info("Starting paraphrase mining...")
        # max_pairs = n**2 exceeds the n*(n-1)/2 possible pairs, so it never
        # truncates. NOTE(review): paraphrase_mining's own top_k default still
        # limits neighbours per sentence; confirm against the installed version.
        paraphrases = paraphrase_mining(
            model,
            texts,
            corpus_chunk_size=len(texts),
            show_progress_bar=True,
            batch_size=1024,
            max_pairs=len(texts) ** 2,
        )

        df = _pairs_to_frame(corpus, paraphrases, score)

        elapsed_time = time.perf_counter() - start
        logger.info(
            "Execution time: %s",
            time.strftime("%H:%M:%S", time.gmtime(elapsed_time)),
        )
        logger.info("Found %d paraphrases above score threshold %s", len(df), score)
        return df

    except Exception as e:
        # Top-level boundary: callers expect None on failure, so log the full
        # traceback instead of only the message.
        logger.exception("Error in mining process: %s", e)
        return None