s4um1l committed
Commit ba10a58 · 1 Parent(s): 4b9a663

introducing parallel processing to make chunking and embedding quicker

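Taken together, the change threads two new knobs through the pipeline: max_workers for parallel chunking in CharacterTextSplitter and batch_size for batched embedding in VectorDatabase. A minimal sketch of exercising both after this commit (assuming the aimakerspace package is importable and an OpenAI API key is configured in the environment; the build_index driver is illustrative, not part of the commit):

    import asyncio

    from aimakerspace.openai_utils.embedding import EmbeddingModel
    from aimakerspace.text_utils import CharacterTextSplitter
    from aimakerspace.vectordatabase import VectorDatabase

    async def build_index(documents):
        # Chunk the documents across several threads (new max_workers parameter).
        splitter = CharacterTextSplitter(chunk_size=1500, chunk_overlap=150, max_workers=8)
        chunks = splitter.split_texts(documents)

        # Embed the chunks in groups (new batch_size parameter).
        vector_db = VectorDatabase(embedding_model=EmbeddingModel(), batch_size=20)
        await vector_db.abuild_from_list(chunks)
        return vector_db

    # Example driver; replace the placeholder string with real document text.
    db = asyncio.run(build_index(["long document text ..."]))
    print(f"Indexed {len(db.vectors)} chunks")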
aimakerspace/text_utils.py CHANGED
@@ -1,6 +1,13 @@
 import os
 from typing import List
 import PyPDF2
+import concurrent.futures
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
 
 class TextFileLoader:
@@ -42,6 +49,7 @@ class CharacterTextSplitter:
         self,
         chunk_size: int = 1000,
         chunk_overlap: int = 200,
+        max_workers: int = 4
     ):
         assert (
             chunk_size > chunk_overlap
@@ -49,6 +57,7 @@ class CharacterTextSplitter:
 
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap
+        self.max_workers = max_workers
 
     def split(self, text: str) -> List[str]:
         chunks = []
@@ -57,9 +66,29 @@ class CharacterTextSplitter:
         return chunks
 
     def split_texts(self, texts: List[str]) -> List[str]:
+        logger.info(f"Splitting {len(texts)} texts in parallel with {self.max_workers} workers")
         chunks = []
-        for text in texts:
-            chunks.extend(self.split(text))
+
+        # Use parallel processing if there are multiple texts or large single text
+        if len(texts) > 1 or (len(texts) == 1 and len(texts[0]) > 50000):
+            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                # Map the split function to the list of texts
+                future_to_text = {executor.submit(self.split, text): text for text in texts}
+
+                # Collect results as they complete
+                for future in concurrent.futures.as_completed(future_to_text):
+                    try:
+                        text_chunks = future.result()
+                        chunks.extend(text_chunks)
+                        logger.info(f"Processed text chunk batch: {len(text_chunks)} chunks")
+                    except Exception as e:
+                        logger.error(f"Error processing text chunk: {str(e)}")
+        else:
+            # For small amounts of text, process sequentially
+            for text in texts:
+                chunks.extend(self.split(text))
+
+        logger.info(f"Completed splitting texts into {len(chunks)} chunks")
         return chunks
 
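One behavioural detail of the new split_texts: concurrent.futures.as_completed yields futures in completion order, so chunks from different source texts can interleave in the returned list rather than following the input order. If downstream code cares about ordering, a sketch of an order-preserving variant using executor.map on the same thread pool (a standalone helper, not part of the commit) would be:

    import concurrent.futures
    from typing import List

    def split_texts_ordered(splitter, texts: List[str]) -> List[str]:
        # executor.map returns results in input order while still
        # running splitter.split on multiple threads.
        chunks: List[str] = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=splitter.max_workers) as executor:
            for text_chunks in executor.map(splitter.split, texts):
                chunks.extend(text_chunks)
        return chunks

Since split is pure-Python string work, threads mainly pay off when many texts are split at once; measuring on real documents is the safest check of the speedup the commit message claims.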
aimakerspace/vectordatabase.py CHANGED
@@ -1,8 +1,16 @@
 import numpy as np
 from collections import defaultdict
-from typing import List, Tuple, Callable
+from typing import List, Tuple, Callable, Dict
 from aimakerspace.openai_utils.embedding import EmbeddingModel
 import asyncio
+import logging
+import concurrent.futures
+import time
+
+# Configure logging
+logging.basicConfig(level=logging.INFO,
+                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
 
 def cosine_similarity(vector_a: np.array, vector_b: np.array) -> float:
@@ -14,9 +22,10 @@ def cosine_similarity(vector_a: np.array, vector_b: np.array) -> float:
 
 
 class VectorDatabase:
-    def __init__(self, embedding_model: EmbeddingModel = None):
+    def __init__(self, embedding_model: EmbeddingModel = None, batch_size: int = 25):
         self.vectors = defaultdict(np.array)
         self.embedding_model = embedding_model or EmbeddingModel()
+        self.batch_size = batch_size  # Process embeddings in batches for better performance
 
     def insert(self, key: str, vector: np.array) -> None:
         self.vectors[key] = vector
@@ -48,9 +57,43 @@ class VectorDatabase:
         return self.vectors.get(key, None)
 
     async def abuild_from_list(self, list_of_text: List[str]) -> "VectorDatabase":
-        embeddings = await self.embedding_model.async_get_embeddings(list_of_text)
-        for text, embedding in zip(list_of_text, embeddings):
-            self.insert(text, np.array(embedding))
+        start_time = time.time()
+
+        if not list_of_text:
+            logger.warning("Empty list provided to build vector database")
+            return self
+
+        logger.info(f"Building embeddings for {len(list_of_text)} text chunks in batches of {self.batch_size}")
+
+        # Process in batches to avoid overwhelming the API
+        batches = [list_of_text[i:i + self.batch_size] for i in range(0, len(list_of_text), self.batch_size)]
+        logger.info(f"Split into {len(batches)} batches")
+
+        for i, batch in enumerate(batches):
+            batch_start = time.time()
+            logger.info(f"Processing batch {i+1}/{len(batches)} with {len(batch)} text chunks")
+
+            try:
+                # Get embeddings for this batch
+                embeddings = await self.embedding_model.async_get_embeddings(batch)
+
+                # Insert into vector database
+                for text, embedding in zip(batch, embeddings):
+                    self.insert(text, np.array(embedding))
+
+                batch_duration = time.time() - batch_start
+                logger.info(f"Batch {i+1} completed in {batch_duration:.2f}s")
+
+                # Small delay between batches to avoid rate limiting
+                if i < len(batches) - 1:
+                    await asyncio.sleep(0.5)
+
+            except Exception as e:
+                logger.error(f"Error processing batch {i+1}: {str(e)}")
+                # Continue with next batch even if this one failed
+
+        total_duration = time.time() - start_time
+        logger.info(f"Vector database built with {len(self.vectors)} vectors in {total_duration:.2f}s")
         return self
 
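As written, abuild_from_list awaits each batch in turn with a 0.5 s pause, so the gain comes from sending up to batch_size texts per embedding request rather than from concurrent requests. If the embedding endpoint tolerates a few requests in flight (an assumption, not something this commit establishes), a sketch of overlapping batches behind a semaphore, using only the async_get_embeddings and insert methods shown above, might look like:

    import asyncio
    import numpy as np

    async def embed_batches_concurrently(vector_db, batches, max_in_flight: int = 3):
        # Hypothetical variant: keep several embedding batches in flight at once,
        # capped by a semaphore, instead of a fixed sleep between batches.
        semaphore = asyncio.Semaphore(max_in_flight)

        async def embed_one(batch):
            async with semaphore:
                return batch, await vector_db.embedding_model.async_get_embeddings(batch)

        for batch, embeddings in await asyncio.gather(*(embed_one(b) for b in batches)):
            for text, embedding in zip(batch, embeddings):
                vector_db.insert(text, np.array(embedding))
        return vector_db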
backend/rag.py CHANGED
@@ -92,7 +92,7 @@ class RetrievalAugmentedQAPipeline:
         }
 
 def process_file(file_path: str, file_name: str) -> List[str]:
-    """Process an uploaded file and convert it to text chunks"""
+    """Process an uploaded file and convert it to text chunks - optimized for speed"""
     logger.info(f"Processing file: {file_name} at path: {file_path}")
 
     try:
@@ -117,10 +117,20 @@ def process_file(file_path: str, file_name: str) -> List[str]:
             logger.warning("No document content loaded")
             return ["No content found in the document"]
 
-        # Split text into chunks
-        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+        # Split text into chunks - use parallel processing
+        logger.info("Splitting document with parallel processing")
+        chunk_size = 1500  # Increased from 1000 for fewer chunks
+        chunk_overlap = 150  # Increased from 100 for better context
+        # Use 8 workers for parallel processing
+        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, max_workers=8)
         text_chunks = text_splitter.split_texts(documents)
 
+        # Limit chunks to avoid processing too many for speed
+        max_chunks = 40  # Reduced from default
+        if len(text_chunks) > max_chunks:
+            logger.warning(f"Too many chunks ({len(text_chunks)}), limiting to {max_chunks} for faster processing")
+            text_chunks = text_chunks[:max_chunks]
+
         logger.info(f"Split document into {len(text_chunks)} chunks")
         return text_chunks
 
@@ -130,23 +140,50 @@ def process_file(file_path: str, file_name: str) -> List[str]:
         return [f"Error processing file: {str(e)}"]
 
 async def setup_vector_db(texts: List[str]) -> VectorDatabase:
-    """Create vector database from text chunks"""
+    """Create vector database from text chunks - optimized with parallel processing"""
     logger.info(f"Setting up vector database with {len(texts)} text chunks")
 
+    # Create embedding model to use with VectorDatabase
     embedding_model = EmbeddingModel()
-    vector_db = VectorDatabase(embedding_model=embedding_model)
+    # Use batch size of 20 for better parallelization
+    vector_db = VectorDatabase(embedding_model=embedding_model, batch_size=20)
 
     try:
+        # Limit number of chunks for faster processing
+        max_chunks = 40
+        if len(texts) > max_chunks:
+            logger.warning(f"Limiting {len(texts)} chunks to {max_chunks} for vector embedding")
+            texts = texts[:max_chunks]
+
+        # Build vector database with batch processing
+        logger.info("Building vector database with batch processing")
         await vector_db.abuild_from_list(texts)
 
+        # Add documents property for compatibility
         vector_db.documents = texts
 
         logger.info(f"Vector database built with {len(texts)} documents")
         return vector_db
+    except asyncio.TimeoutError:
+        logger.error(f"Vector database creation timed out after 300 seconds")
+        # Create minimal fallback DB with just a few documents
+        fallback_db = VectorDatabase(embedding_model=embedding_model)
+        if texts:
+            # Use just first few texts for minimal functionality
+            minimal_texts = texts[:3]
+            for text in minimal_texts:
+                fallback_db.insert(text, [0.0] * 1536)  # Use zero vectors for speed
+            fallback_db.documents = minimal_texts
+        else:
+            error_text = "I'm sorry, but there was a timeout during document processing."
+            fallback_db.insert(error_text, [0.0] * 1536)
+            fallback_db.documents = [error_text]
+        return fallback_db
     except Exception as e:
         logger.error(f"Error setting up vector database: {str(e)}")
         logger.error(traceback.format_exc())
 
+        # Create fallback DB for this error case
         fallback_db = VectorDatabase(embedding_model=embedding_model)
         error_text = "I'm sorry, but there was an error processing the document."
         fallback_db.insert(error_text, [0.0] * 1536)
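Note that the new except asyncio.TimeoutError branch in setup_vector_db mentions a 300-second limit, but as far as this diff shows abuild_from_list is awaited directly, so that branch only fires if the embedding client itself raises the error. If a hard budget is the intent, a small sketch using asyncio.wait_for (the 300-second figure is taken from the log message; the helper name is illustrative) would enforce it:

    import asyncio

    async def build_with_timeout(vector_db, texts, timeout_seconds: float = 300.0):
        # Wrap the build so a slow embedding backend raises asyncio.TimeoutError,
        # which the existing fallback branch in setup_vector_db then handles.
        return await asyncio.wait_for(vector_db.abuild_from_list(texts), timeout=timeout_seconds)

Calling await build_with_timeout(vector_db, texts) in place of the direct await would make the timeout fallback reachable.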
frontend/src/App.js CHANGED
@@ -148,6 +148,8 @@ function FileUploader({ onFileUpload }) {
   const [isUploading, setIsUploading] = useState(false);
   const [uploadProgress, setUploadProgress] = useState(0);
   const [processingStatus, setProcessingStatus] = useState(null);
+  const [processingProgress, setProcessingProgress] = useState(0);
+  const [processingSteps, setProcessingSteps] = useState(0);
 
   const { getRootProps, getInputProps } = useDropzone({
     maxFiles: 1,
@@ -294,13 +296,71 @@ function FileUploader({ onFileUpload }) {
     }
   });
 
+  // Move pollSessionStatus inside the component where it has access to the necessary variables
+  const pollSessionStatus = async (sessionId, file, retries = 40, interval = 5000) => {
+    // Increased retries from 30 to 40 for longer processing documents
+    let currentRetry = 0;
+
+    while (currentRetry < retries) {
+      try {
+        const statusUrl = `${API_URL}/session/${sessionId}/status`;
+        console.log(`Checking status (attempt ${currentRetry + 1}/${retries}):`, statusUrl);
+
+        const statusResponse = await axios.get(statusUrl, {
+          timeout: 30000 // 30 second timeout for status checks
+        });
+
+        console.log('Status response:', statusResponse.data);
+
+        if (statusResponse.data.status === 'ready') {
+          setProcessingStatus('complete');
+          setProcessingProgress(100);
+          onFileUpload(sessionId, file.name);
+          return;
+        } else if (statusResponse.data.status === 'failed') {
+          setProcessingStatus('failed');
+          throw new Error('Processing failed on server');
+        }
+
+        // Still processing, update progress based on attempt number
+        setProcessingStatus('processing');
+        // Calculate progress - more rapid at start, slower towards end
+        const progressIncrement = 75 / retries; // Max out at 75% during polling
+        setProcessingProgress(Math.min(5 + (currentRetry * progressIncrement), 75));
+
+        // Increment processing steps to show activity
+        setProcessingSteps(prev => prev + 1);
+
+        await new Promise(resolve => setTimeout(resolve, interval));
+        currentRetry++;
+
+        // Increase interval slightly for each retry to prevent overwhelming the server
+        interval = Math.min(interval * 1.1, 15000); // Cap at 15 seconds
+      } catch (error) {
+        console.error('Error checking status:', error);
+
+        // If we hit a timeout or network issue, wait a bit longer before retrying
+        await new Promise(resolve => setTimeout(resolve, interval * 2));
+        currentRetry++;
+      }
+    }
+
+    // If we've exhausted all retries and still don't have a ready status
+    throw new Error('Status polling timed out');
+  };
+
   // Status message based on current processing state
   const getStatusMessage = () => {
+    const steps = ['Analyzing text', 'Splitting document', 'Creating embeddings', 'Building vector database', 'Finalizing'];
+    const currentStep = steps[processingSteps % steps.length];
+
     switch(processingStatus) {
       case 'starting':
         return 'Initiating hyperspace jump...';
+      case 'uploading':
+        return 'Sending document to the Jedi Archives...';
       case 'processing':
-        return 'The Force is analyzing your document... This may take several minutes.';
+        return `${currentStep}... This may take several minutes.`;
       case 'timeout':
         return 'Document processing is taking longer than expected. Patience, young Padawan...';
       case 'failed':
@@ -335,7 +395,7 @@ function FileUploader({ onFileUpload }) {
         <>
           <Text color="brand.500">Uploading to the Jedi Archives...</Text>
           <Progress
-            value={uploadProgress}
+            value={processingStatus === 'uploading' ? uploadProgress : processingProgress}
            size="sm"
            colorScheme="yellow"
            width="100%"
@@ -370,37 +430,12 @@ function App() {
   const handleFileUpload = (newSessionId, name) => {
     setSessionId(newSessionId);
     setFileName(name);
-    setIsDocProcessing(true);
+    setIsDocProcessing(false);
     setMessages([
-      { text: `Processing ${name}. May the Force be with you...`, isUser: false }
+      { text: `"${name}" has been added to the Jedi Archives. What knowledge do you seek?`, isUser: false }
     ]);
 
-    // Poll for document processing status
-    const checkStatus = async () => {
-      try {
-        const response = await axios.get(`${API_URL}/session/${newSessionId}/status`);
-        console.log('Status response:', response.data);
-
-        if (response.data.status === 'ready') {
-          setIsDocProcessing(false);
-          setMessages([
-            { text: `"${name}" has been added to the Jedi Archives. What knowledge do you seek?`, isUser: false }
-          ]);
-          return;
-        }
-
-        // Continue polling if still processing
-        if (response.data.status === 'processing') {
-          setTimeout(checkStatus, 2000);
-        }
-      } catch (error) {
-        console.error('Error checking status:', error);
-        // Continue polling even if there's an error
-        setTimeout(checkStatus, 3000);
-      }
-    };
-
-    checkStatus();
+    // Don't poll again - already handled in FileUploader
   };
 
   const handleSendMessage = async () => {