“Transcendental-Programmer” commited on
Commit
41d470a
·
1 Parent(s): 754afec

FEAT: Added testing modules

Browse files
Files changed (2) hide show
  1. tests/test_rag.py +93 -0
  2. tests/test_server.py +57 -0
tests/test_rag.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """test_rag.py module."""
2
+
3
+ import pytest
4
+ from src.rag.retriever import FinancialDataRetriever
5
+ from src.rag.generator import RAGGenerator
6
+ import yaml
7
+
8
+ @pytest.fixture
9
+ def rag_config():
10
+ with open('config/server_config.yaml', 'r') as f:
11
+ config = yaml.safe_load(f)
12
+ config['rag'] = {
13
+ 'retriever': 'faiss',
14
+ 'max_documents': 5,
15
+ 'similarity_threshold': 0.7
16
+ }
17
+ return config
18
+
19
+ @pytest.fixture
20
+ def retriever(rag_config):
21
+ return FinancialDataRetriever(rag_config)
22
+
23
+ @pytest.fixture
24
+ def generator(rag_config):
25
+ return RAGGenerator(rag_config)
26
+
27
+ def test_retriever_initialization(retriever, rag_config):
28
+ assert retriever.retriever_type == rag_config['rag']['retriever']
29
+ assert retriever.max_documents == rag_config['rag']['max_documents']
30
+
31
+ def test_document_indexing(retriever):
32
+ test_documents = [
33
+ {'text': 'Financial report 2023', 'id': 1},
34
+ {'text': 'Market analysis Q4', 'id': 2},
35
+ {'text': 'Investment strategy', 'id': 3}
36
+ ]
37
+
38
+ retriever.index_documents(test_documents)
39
+ assert retriever.index.ntotal == len(test_documents)
40
+
41
+ def test_document_retrieval(retriever):
42
+ # Index test documents
43
+ test_documents = [
44
+ {'text': 'Financial report 2023', 'id': 1},
45
+ {'text': 'Market analysis Q4', 'id': 2},
46
+ {'text': 'Investment strategy', 'id': 3}
47
+ ]
48
+ retriever.index_documents(test_documents)
49
+
50
+ # Test retrieval
51
+ query = "financial report"
52
+ results = retriever.retrieve(query)
53
+ assert len(results) > 0
54
+ assert all('document' in result for result in results)
55
+ assert all('score' in result for result in results)
56
+
57
+ def test_generator_initialization(generator):
58
+ assert hasattr(generator, 'model')
59
+ assert hasattr(generator, 'tokenizer')
60
+
61
+ def test_text_generation(generator):
62
+ retrieved_docs = [
63
+ {
64
+ 'document': {'text': 'Financial market analysis shows positive trends'},
65
+ 'score': 0.9
66
+ }
67
+ ]
68
+
69
+ generated_text = generator.generate(
70
+ query="Summarize market trends",
71
+ retrieved_docs=retrieved_docs
72
+ )
73
+
74
+ assert isinstance(generated_text, str)
75
+ assert len(generated_text) > 0
76
+
77
+ def test_context_preparation(generator):
78
+ retrieved_docs = [
79
+ {
80
+ 'document': {'text': 'Doc 1 content'},
81
+ 'score': 0.9
82
+ },
83
+ {
84
+ 'document': {'text': 'Doc 2 content'},
85
+ 'score': 0.8
86
+ }
87
+ ]
88
+
89
+ context = generator.prepare_context(retrieved_docs)
90
+ assert isinstance(context, str)
91
+ assert 'Doc 1 content' in context
92
+ assert 'Doc 2 content' in context
93
+
tests/test_server.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """test_server.py module."""
2
+
3
+ import pytest
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from src.server.coordinator import FederatedCoordinator
7
+ from src.server.aggregator import FederatedAggregator
8
+ import yaml
9
+
10
+ @pytest.fixture
11
+ def server_config():
12
+ with open('config/server_config.yaml', 'r') as f:
13
+ return yaml.safe_load(f)['server']
14
+
15
+ @pytest.fixture
16
+ def coordinator(server_config):
17
+ return FederatedCoordinator(server_config)
18
+
19
+ @pytest.fixture
20
+ def aggregator(server_config):
21
+ return FederatedAggregator(server_config)
22
+
23
+ def test_coordinator_initialization(coordinator, server_config):
24
+ assert coordinator.min_clients == server_config['federated']['min_clients']
25
+ assert coordinator.rounds == server_config['federated']['rounds']
26
+ assert coordinator.sample_fraction == server_config['federated']['sample_fraction']
27
+
28
+ def test_client_registration(coordinator):
29
+ client_id = 1
30
+ client_size = 1000
31
+ coordinator.register_client(client_id, client_size)
32
+ assert client_id in coordinator.clients
33
+ assert coordinator.clients[client_id]['size'] == client_size
34
+
35
+ def test_client_selection(coordinator):
36
+ # Register multiple clients
37
+ for i in range(5):
38
+ coordinator.register_client(i, 1000)
39
+
40
+ selected_clients = coordinator.select_clients()
41
+ assert len(selected_clients) >= coordinator.min_clients
42
+ assert all(client_id in coordinator.clients for client_id in selected_clients)
43
+
44
+ def test_weight_aggregation(aggregator):
45
+ # Create mock client updates
46
+ client_updates = [
47
+ {
48
+ 'client_id': i,
49
+ 'weights': [np.random.randn(10, 10) for _ in range(3)],
50
+ 'metrics': {'loss': 0.5}
51
+ }
52
+ for i in range(3)
53
+ ]
54
+
55
+ aggregated_weights = aggregator.compute_metrics(client_updates)
56
+ assert isinstance(aggregated_weights, dict)
57
+