|
import gc
import uuid

import chromadb
import torch
from sentence_transformers import SentenceTransformer
|
|
|
|
|
class is_docs:
    """Embed text with a SentenceTransformer model and store the vectors in Chroma.

    Attributes set by ``__init__``:
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        model: the loaded SentenceTransformer, in eval mode, on ``device``.
        client_docs: persistent Chroma client rooted at ``db_path``.
        collection_docs: the Chroma collection vectors are written to.
        cos_sim: ``torch.nn.CosineSimilarity(dim=0)`` helper (not used inside
            this class).
    """

    def __init__(
        self,
        model_name: str = "nlpai-lab/KURE-v1",
        cache_folder: str = "/Users/jaewook/PycharmProjects/DS_security_API/weights",
        db_path: str = "../db/docs",
        collection_name: str = "image_embedding",
    ):
        """Load the embedding model and open (or create) the Chroma collection.

        All parameters default to the values that used to be hard-coded, so
        ``is_docs()`` behaves as before apart from the metadata key fix below.

        Args:
            model_name: model identifier handed to ``SentenceTransformer``.
            cache_folder: directory where model weights are cached.
            db_path: directory of the persistent Chroma database. A relative
                path is resolved against the current working directory.
            collection_name: name of the collection to get or create.
                NOTE(review): the default says "image_embedding" although this
                class stores document embeddings - confirm that is intended.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE: trust_remote_code=True allows the model repository to execute
        # its own Python code at load time; only point this at trusted models.
        self.model = (
            SentenceTransformer(
                model_name,
                cache_folder=cache_folder,
                trust_remote_code=True,
            )
            .eval()
            .to(self.device)
        )

        self.client_docs = chromadb.PersistentClient(path=db_path)
        # The distance function is selected with the "hnsw:space" metadata key;
        # the previous key "hnsw" was stored as plain metadata and ignored.
        # NOTE(review): a collection that already exists keeps the distance it
        # was created with - recreate it if it was built with the old key.
        self.collection_docs = self.client_docs.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        self.cos_sim = torch.nn.CosineSimilarity(dim=0)

    async def making_embedding_vector(
        self,
        docs: "str | list[str]",
        category: int = 1,
        infer_mode: bool = False,
    ):
        """Embed ``docs`` and, unless ``infer_mode`` is set, store the vectors.

        Args:
            docs: one text or a sequence of texts.
            category: value stored under the ``"category"`` metadata key of
                every vector written by this call.
            infer_mode: when true, only compute the embeddings and skip the
                write to the collection.

        Returns:
            For a single ``str``: one embedding as a flat list of floats.
            For a sequence: a list with one embedding per text (``[]`` for an
            empty sequence).

        NOTE(review): the encode call is synchronous, so awaiting this
        coroutine blocks the event loop while the model runs.
        """
        is_single = isinstance(docs, str)
        texts = [docs] if is_single else list(docs)
        if not texts:
            return []

        # A decorator on an ``async def`` would only cover creating the
        # coroutine object, so inference mode is entered inside the body.
        with torch.inference_mode():
            encoded = self.model.encode(texts)
        # Always encode a batch so the result is one row per text.
        vectors = encoded.tolist()

        if not infer_mode:
            self.add_doc_vectors(
                vectors,
                [{"category": category} for _ in vectors],
            )

        # Release Python garbage and cached CUDA blocks after each call.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return vectors[0] if is_single else vectors

    def add_doc_vectors(self, vectors, metadatas, ids=None):
        """Add one or more embedding vectors to the collection.

        Args:
            vectors: a single flat vector or a sequence of vectors.
            metadatas: one metadata dict (applied to every vector), a sequence
                with one dict per vector, or ``None``.
            ids: optional id or sequence of ids, one per vector. When omitted
                a fresh UUID4 hex string is generated for each vector, so
                repeated calls no longer collide on a constant id.

        Returns:
            The list of ids that were written (empty when ``vectors`` is empty).

        Raises:
            ValueError: if ``metadatas`` or ``ids`` does not match the number
                of vectors.
        """
        rows = list(vectors)
        if not rows:
            return []
        # A scalar first element means the caller passed one flat vector.
        if not hasattr(rows[0], "__len__"):
            rows = [rows]

        if isinstance(metadatas, dict):
            metadatas = [dict(metadatas) for _ in rows]
        elif metadatas is not None:
            metadatas = list(metadatas)
            if len(metadatas) != len(rows):
                raise ValueError(
                    f"got {len(metadatas)} metadatas for {len(rows)} vectors"
                )

        if ids is None:
            ids = [uuid.uuid4().hex for _ in rows]
        elif isinstance(ids, str):
            ids = [ids]
        else:
            ids = list(ids)
        if len(ids) != len(rows):
            raise ValueError(f"got {len(ids)} ids for {len(rows)} vectors")

        self.collection_docs.add(
            embeddings=rows,
            metadatas=metadatas,
            ids=ids,
        )
        return ids
|
|
|
|
|
if __name__ == "__main__":
    # Manual debug aid: show the directory the script is run from.
    # NOTE(review): presumably used to check what the relative database path
    # "../db/docs" resolves against - confirm before removing.
    import os

    working_dir = os.getcwd()
    print(working_dir)
|
|
|
|
|
|
|
|