# KURE-v1 / doc_embedding.py
import gc
import uuid

import chromadb
import torch
from sentence_transformers import SentenceTransformer
class is_docs:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # KURE-v1 Korean embedding model, cached locally under weights/
        self.model = SentenceTransformer(
            "nlpai-lab/KURE-v1",
            cache_folder="/Users/jaewook/PycharmProjects/DS_security_API/weights",
            trust_remote_code=True,
        ).eval().to(self.device)
        # Persistent Chroma store; the collection holds document vectors with cosine distance.
        self.client_docs = chromadb.PersistentClient(path="../db/docs")
        self.collection_docs = self.client_docs.get_or_create_collection(
            name="image_embedding",
            metadata={"hnsw:space": "cosine"},
        )
        self.cos_sim = torch.nn.CosineSimilarity(dim=0)
    @torch.inference_mode()
    async def making_embedding_vector(self, docs: str, category: int = 1, infer_mode: bool = False):
        # Encode the document into a single embedding vector.
        embedding = self.model.encode(docs).tolist()
        test_metadata = {"category": category}
        if not infer_mode:
            # Persist the vector only when not running in pure inference mode.
            self.add_doc_vectors(embedding, test_metadata)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return embedding
    def add_doc_vectors(self, vectors, metadatas):
        self.collection_docs.add(
            embeddings=[vectors],
            metadatas=[metadatas],
            ids=[str(uuid.uuid4())],  # unique ID per stored vector
        )
if __name__ == "__main__":
    import os
    print(os.getcwd())
    # model = SentenceTransformer("nlpai-lab/KURE-v1",
    #                             cache_folder="/Users/jaewook/PycharmProjects/DS_security_API/weights",
    #                             trust_remote_code=True).eval()
    # model.save_pretrained('./')
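
    # Minimal usage sketch (not part of the original file): drive the async embedding method
    # with asyncio, assuming the KURE-v1 weights and the ../db/docs Chroma path are available.
    import asyncio

    embedder = is_docs()
    sample_doc = "example document for embedding"  # hypothetical sample text
    vector = asyncio.run(
        embedder.making_embedding_vector(sample_doc, category=1, infer_mode=True)
    )
    print(f"embedding dimension: {len(vector)}")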