“Transcendental-Programmer”
commited on
Commit
·
41d470a
1
Parent(s):
754afec
FEAT: Added testing modules
Browse files- tests/test_rag.py +93 -0
- tests/test_server.py +57 -0
tests/test_rag.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""test_rag.py module."""
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
from src.rag.retriever import FinancialDataRetriever
|
5 |
+
from src.rag.generator import RAGGenerator
|
6 |
+
import yaml
|
7 |
+
|
8 |
+
@pytest.fixture
|
9 |
+
def rag_config():
|
10 |
+
with open('config/server_config.yaml', 'r') as f:
|
11 |
+
config = yaml.safe_load(f)
|
12 |
+
config['rag'] = {
|
13 |
+
'retriever': 'faiss',
|
14 |
+
'max_documents': 5,
|
15 |
+
'similarity_threshold': 0.7
|
16 |
+
}
|
17 |
+
return config
|
18 |
+
|
19 |
+
@pytest.fixture
|
20 |
+
def retriever(rag_config):
|
21 |
+
return FinancialDataRetriever(rag_config)
|
22 |
+
|
23 |
+
@pytest.fixture
|
24 |
+
def generator(rag_config):
|
25 |
+
return RAGGenerator(rag_config)
|
26 |
+
|
27 |
+
def test_retriever_initialization(retriever, rag_config):
|
28 |
+
assert retriever.retriever_type == rag_config['rag']['retriever']
|
29 |
+
assert retriever.max_documents == rag_config['rag']['max_documents']
|
30 |
+
|
31 |
+
def test_document_indexing(retriever):
|
32 |
+
test_documents = [
|
33 |
+
{'text': 'Financial report 2023', 'id': 1},
|
34 |
+
{'text': 'Market analysis Q4', 'id': 2},
|
35 |
+
{'text': 'Investment strategy', 'id': 3}
|
36 |
+
]
|
37 |
+
|
38 |
+
retriever.index_documents(test_documents)
|
39 |
+
assert retriever.index.ntotal == len(test_documents)
|
40 |
+
|
41 |
+
def test_document_retrieval(retriever):
|
42 |
+
# Index test documents
|
43 |
+
test_documents = [
|
44 |
+
{'text': 'Financial report 2023', 'id': 1},
|
45 |
+
{'text': 'Market analysis Q4', 'id': 2},
|
46 |
+
{'text': 'Investment strategy', 'id': 3}
|
47 |
+
]
|
48 |
+
retriever.index_documents(test_documents)
|
49 |
+
|
50 |
+
# Test retrieval
|
51 |
+
query = "financial report"
|
52 |
+
results = retriever.retrieve(query)
|
53 |
+
assert len(results) > 0
|
54 |
+
assert all('document' in result for result in results)
|
55 |
+
assert all('score' in result for result in results)
|
56 |
+
|
57 |
+
def test_generator_initialization(generator):
|
58 |
+
assert hasattr(generator, 'model')
|
59 |
+
assert hasattr(generator, 'tokenizer')
|
60 |
+
|
61 |
+
def test_text_generation(generator):
|
62 |
+
retrieved_docs = [
|
63 |
+
{
|
64 |
+
'document': {'text': 'Financial market analysis shows positive trends'},
|
65 |
+
'score': 0.9
|
66 |
+
}
|
67 |
+
]
|
68 |
+
|
69 |
+
generated_text = generator.generate(
|
70 |
+
query="Summarize market trends",
|
71 |
+
retrieved_docs=retrieved_docs
|
72 |
+
)
|
73 |
+
|
74 |
+
assert isinstance(generated_text, str)
|
75 |
+
assert len(generated_text) > 0
|
76 |
+
|
77 |
+
def test_context_preparation(generator):
|
78 |
+
retrieved_docs = [
|
79 |
+
{
|
80 |
+
'document': {'text': 'Doc 1 content'},
|
81 |
+
'score': 0.9
|
82 |
+
},
|
83 |
+
{
|
84 |
+
'document': {'text': 'Doc 2 content'},
|
85 |
+
'score': 0.8
|
86 |
+
}
|
87 |
+
]
|
88 |
+
|
89 |
+
context = generator.prepare_context(retrieved_docs)
|
90 |
+
assert isinstance(context, str)
|
91 |
+
assert 'Doc 1 content' in context
|
92 |
+
assert 'Doc 2 content' in context
|
93 |
+
|
tests/test_server.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""test_server.py module."""
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
import numpy as np
|
5 |
+
import tensorflow as tf
|
6 |
+
from src.server.coordinator import FederatedCoordinator
|
7 |
+
from src.server.aggregator import FederatedAggregator
|
8 |
+
import yaml
|
9 |
+
|
10 |
+
@pytest.fixture
|
11 |
+
def server_config():
|
12 |
+
with open('config/server_config.yaml', 'r') as f:
|
13 |
+
return yaml.safe_load(f)['server']
|
14 |
+
|
15 |
+
@pytest.fixture
|
16 |
+
def coordinator(server_config):
|
17 |
+
return FederatedCoordinator(server_config)
|
18 |
+
|
19 |
+
@pytest.fixture
|
20 |
+
def aggregator(server_config):
|
21 |
+
return FederatedAggregator(server_config)
|
22 |
+
|
23 |
+
def test_coordinator_initialization(coordinator, server_config):
|
24 |
+
assert coordinator.min_clients == server_config['federated']['min_clients']
|
25 |
+
assert coordinator.rounds == server_config['federated']['rounds']
|
26 |
+
assert coordinator.sample_fraction == server_config['federated']['sample_fraction']
|
27 |
+
|
28 |
+
def test_client_registration(coordinator):
|
29 |
+
client_id = 1
|
30 |
+
client_size = 1000
|
31 |
+
coordinator.register_client(client_id, client_size)
|
32 |
+
assert client_id in coordinator.clients
|
33 |
+
assert coordinator.clients[client_id]['size'] == client_size
|
34 |
+
|
35 |
+
def test_client_selection(coordinator):
|
36 |
+
# Register multiple clients
|
37 |
+
for i in range(5):
|
38 |
+
coordinator.register_client(i, 1000)
|
39 |
+
|
40 |
+
selected_clients = coordinator.select_clients()
|
41 |
+
assert len(selected_clients) >= coordinator.min_clients
|
42 |
+
assert all(client_id in coordinator.clients for client_id in selected_clients)
|
43 |
+
|
44 |
+
def test_weight_aggregation(aggregator):
|
45 |
+
# Create mock client updates
|
46 |
+
client_updates = [
|
47 |
+
{
|
48 |
+
'client_id': i,
|
49 |
+
'weights': [np.random.randn(10, 10) for _ in range(3)],
|
50 |
+
'metrics': {'loss': 0.5}
|
51 |
+
}
|
52 |
+
for i in range(3)
|
53 |
+
]
|
54 |
+
|
55 |
+
aggregated_weights = aggregator.compute_metrics(client_updates)
|
56 |
+
assert isinstance(aggregated_weights, dict)
|
57 |
+
|