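"""Semantic textual similarity (STS): score every sentence pair across two
tab-separated input files with a SentenceTransformer model."""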
import time
import pandas as pd
import polars as pl
import torch
import logging
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from typing import Optional

logger = logging.getLogger(__name__)

def sts(modelname: str, data1: str, data2: str, score: float) -> Optional[pl.DataFrame]:
    """
    Calculate semantic textual similarity between two sets of sentences.
    
    Args:
        modelname: Name of the model to use
        data1: Path to first input CSV file
        data2: Path to second input CSV file
        score: Minimum similarity score threshold
        
    Returns:
        Optional[pl.DataFrame]: DataFrame with similarity results or None if error occurs
    """
    try:
        st = time.time()
        
        # Initialize model
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        
        model = SentenceTransformer(
            modelname,
            device=device,
            trust_remote_code=True,
        )
        
        # Read and validate input data
        sentences1 = Dataset.from_pandas(pd.read_csv(data1, on_bad_lines='skip', header=0, sep="\t"))
        sentences2 = Dataset.from_pandas(pd.read_csv(data2, on_bad_lines='skip', header=0, sep="\t"))
        
        if sentences1.num_rows == 0 or sentences2.num_rows == 0:
            logger.error("Empty input data found")
            return None
            
        # Generate embeddings
        logger.info("Generating embeddings for first set...")
        embeddings1 = model.encode(
            sentences1["text"],
            normalize_embeddings=True,
            batch_size=1024,
            show_progress_bar=True
        )
        
        logger.info("Generating embeddings for second set...")
        embeddings2 = model.encode(
            sentences2["text"],
            normalize_embeddings=True,
            batch_size=1024,
            show_progress_bar=True
        )
        
        # Calculate similarity matrix
        logger.info("Calculating similarity matrix...")
        similarity_matrix = model.similarity(embeddings1, embeddings2)
        
        # Convert the similarity matrix to a Polars DataFrame. Going through
        # pandas stringifies the integer column names ("0", "1", ...), which
        # the column_index cast below relies on.
        df = pl.from_pandas(pd.DataFrame(similarity_matrix.cpu().numpy()))
        
        # Transform matrix to long format
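        # so each output row is one (row_index, column_index, score) triple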
        df_matrix_with_index = df.with_row_index(name="row_index").with_columns(
            pl.col("row_index").cast(pl.UInt64)
        )
        df_long = df_matrix_with_index.unpivot(
            index="row_index",
            variable_name="column_index",
            value_name="score"
        ).with_columns(pl.col("column_index").cast(pl.UInt64))
        
        # Attach row/column indices to the original sentences so the scores
        # can be joined back to their source text
        df_sentences1 = pl.from_pandas(sentences1.to_pandas()).with_row_index(name="row_index").with_columns(
            pl.col("row_index").cast(pl.UInt64)
        )
        df_sentences2 = pl.from_pandas(sentences2.to_pandas()).with_row_index(name="column_index").with_columns(
            pl.col("column_index").cast(pl.UInt64)
        )
        
        # Process final results
        df_long = (df_long
                  .with_columns([pl.col("score").round(4).cast(pl.Float32)])
                  .join(df_sentences1, on="row_index")
                  .join(df_sentences2, on="column_index"))
        
        df_long = df_long.rename({
            "text": "sentences1",
            "text_right": "sentences2",
        }).drop(["row_index", "column_index"])
        
        # Filter and sort results
        result_df = df_long.filter(pl.col("score") > score).sort(["score"], descending=True)
        
        elapsed_time = time.time() - st
        logger.info(f'Execution time: {time.strftime("%H:%M:%S", time.gmtime(elapsed_time))}')
        logger.info(f'Found {len(result_df)} pairs above score threshold {score}')
        
        return result_df
        
    except Exception:
        # logger.exception preserves the traceback, unlike logging str(e)
        logger.exception("Error in STS process")
        return None
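

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the model name
    # and file paths below are placeholders; point them at any
    # SentenceTransformer model and two tab-separated files that each
    # contain a "text" column.
    logging.basicConfig(level=logging.INFO)
    result = sts(
        modelname="sentence-transformers/all-MiniLM-L6-v2",  # example model
        data1="sentences_a.tsv",  # hypothetical input path
        data2="sentences_b.tsv",  # hypothetical input path
        score=0.8,
    )
    if result is not None:
        print(result.head(10))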