File size: 3,190 Bytes
f2e1fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Handles cluster plot (UMAP scatter) event
import { getLineEmbeddings } from './embedding.js';
import { plotScatter } from './plotting.js';

// Task description string passed as the second argument to getLineEmbeddings
// whenever the textarea lines are embedded for the cluster plot.
const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";

/**
 * Splits the textarea content into cluster groups.
 *
 * Groups are separated by three or more consecutive newlines. Inside a group,
 * every line is trimmed, blank lines are dropped, and lines starting with "##"
 * are treated as headers rather than content.
 *
 * @param {string} text - Raw textarea content.
 * @returns {{ name: (string|null), lines: string[] }[]} One entry per group, in
 *   document order. `name` is the text after the "##" of the first header line
 *   (null when the group has no header; may be an empty string for a bare "##").
 *   Chunks with neither a header nor any content line are skipped.
 */
function parseClusterGroups(text) {
    const groups = [];
    for (const chunk of text.split(/\n{3,}/)) {
        const trimmed = chunk.split("\n").map(line => line.trim()).filter(Boolean);
        // The name is taken from the header line only, so a bare "##" can never
        // swallow the following content line as its name.
        const header = trimmed.find(line => line.startsWith("##"));
        const lines = trimmed.filter(line => !line.startsWith("##"));
        // Leading/trailing separators yield empty chunks; they are not clusters.
        if (header === undefined && lines.length === 0) continue;
        groups.push({
            name: header === undefined ? null : header.slice(2).trim(),
            lines,
        });
    }
    return groups;
}

/**
 * Handles the cluster plot event: embeds the current textarea lines, projects
 * them to 2D with UMAP, and draws one scatter trace per cluster group.
 *
 * Side effects: shows the progress bar, calls plotScatter, stores the traces on
 * `window.traces`, and rewrites the textarea so that every group starts with a
 * "## <name>" header and groups are separated by two blank lines.
 *
 * Nothing is plotted (and the textarea is left untouched) when fewer than two
 * embeddings are available or when the number of embeddings does not match the
 * number of content lines. The progress bar is set to 100% on every exit path,
 * including errors, so it never stays stuck at 0%.
 *
 * @returns {Promise<void>}
 */
export async function handleClusterPlotEvent() {
    const progressBar = document.getElementById("progress-bar");
    const progressBarInner = document.getElementById("progress-bar-inner");
    progressBar.style.display = "block";
    progressBarInner.style.width = "0%";

    try {
        // Parse the textarea once: content lines and their cluster index come
        // from the same pass, so they cannot drift apart.
        const input = document.getElementById("input");
        const groups = parseClusterGroups(input.value);
        const k = groups.length;
        const lines = groups.flatMap(g => g.lines);
        const labels = groups.flatMap((g, c) => g.lines.map(() => c));

        // Recalculate embeddings from the current textarea content
        // (header lines are excluded).
        const embeddings = await getLineEmbeddings(lines, task);
        const n = embeddings.length;
        // A projection needs at least two points, and one embedding per line.
        if (n < 2 || n !== lines.length) return;

        // UMAP projection; nNeighbors must stay below the number of points.
        const { UMAP } = await import('https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm');
        const nNeighbors = Math.max(1, Math.min(n - 1, 15));
        const umap = new UMAP({ nComponents: 2, nNeighbors, minDist: 0.2, metric: "cosine" });
        const proj = umap.fit(embeddings);

        // Prepare scatter plot traces; colors cycle when there are more
        // clusters than palette entries.
        const colors = ["red", "blue", "green", "orange", "purple", "cyan", "magenta", "yellow", "brown", "black", "lime", "navy", "teal", "olive", "maroon", "pink", "gray", "gold", "aqua", "indigo"];
        // `||` is intentional: an empty header name also falls back to the placeholder.
        const placeholderNames = groups.map((g, i) => g.name || `Cluster ${i + 1}`);
        const traces = groups.map((_, c) => ({
            x: [], y: [], text: [],
            mode: "markers", type: "scatter",
            name: placeholderNames[c],
            marker: { color: colors[c % colors.length], size: 12, line: { width: 1, color: "#333" } }
        }));
        lines.forEach((line, i) => {
            const trace = traces[labels[i]];
            trace.x.push(proj[i][0]);
            trace.y.push(proj[i][1]);
            trace.text.push(line);
        });
        plotScatter(traces, k);
        // Exposed globally as in the original; presumably read by other code or
        // used for debugging — confirm before removing.
        window.traces = traces;

        // Write the groups back with their (possibly placeholder) names as markdown headers.
        input.value = groups.map((g, i) =>
            `## ${placeholderNames[i]}\n${g.lines.join("\n")}`
        ).join("\n\n\n");
    } finally {
        progressBarInner.style.width = "100%";
    }
}