Spaces:
Running
Running
File size: 3,190 Bytes
f2e1fb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
// Handles cluster plot (UMAP scatter) event
import { getLineEmbeddings } from './embedding.js';
import { plotScatter } from './plotting.js';
// Instruction passed to getLineEmbeddings alongside the lines to embed.
const task = "Given a textual input sentence, retrieve relevant categories that best describe it.";

// Marker colours, assigned to clusters by index and cycled when there are more clusters than colours.
const CLUSTER_COLORS = ["red", "blue", "green", "orange", "purple", "cyan", "magenta", "yellow", "brown", "black", "lime", "navy", "teal", "olive", "maroon", "pink", "gray", "gold", "aqua", "indigo"];

/**
 * Parse the textarea contents into clusters.
 *
 * Clusters are separated by two or more blank lines (empty or spaces/tabs only).
 * Within a cluster, an optional `## Name` line names it; every other non-blank
 * line (trimmed) is an item.
 *
 * @param {string} text - Raw textarea contents.
 * @returns {{ lines: string[], labels: number[], names: string[], grouped: string[][] }}
 *   `lines`: all items in document order; `labels[i]`: cluster index of `lines[i]`;
 *   `names[c]`: header name of cluster c, or `Cluster c+1` when it has none;
 *   `grouped[c]`: the items of cluster c.
 */
function parseClusters(text) {
  // Trim first so leading/trailing blank runs don't produce empty phantom clusters.
  const groups = text.trim().split(/\n(?:[ \t]*\n){2,}/);
  const grouped = groups.map((group) =>
    group
      .split("\n")
      .map((x) => x.trim())
      .filter((x) => x && !x.startsWith("##")),
  );
  const names = groups.map((group, i) => {
    // [ \t]* rather than \s*: an empty "##" header must not swallow the next line as its name.
    const m = group.match(/^[ \t]*##[ \t]*(.*)$/m);
    return m?.[1].trim() || `Cluster ${i + 1}`;
  });
  return {
    lines: grouped.flat(),
    labels: grouped.flatMap((items, c) => items.map(() => c)),
    names,
    grouped,
  };
}

/**
 * Cosine distance (1 - cosine similarity) between two equal-length numeric vectors.
 * Two zero vectors are at distance 0; a zero vector and a non-zero one are at distance 1.
 */
function cosineDistance(a, b) {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; ++i) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  if (normA === 0 && normB === 0) {
    return 0;
  }
  if (normA === 0 || normB === 0) {
    return 1;
  }
  // Clamp: rounding can push 1 - sim slightly below zero for near-identical vectors.
  return Math.max(0, 1 - dot / Math.sqrt(normA * normB));
}

/**
 * Build one scatter trace per cluster from the 2-D projection.
 *
 * @param {number[][]} proj - proj[i] is the [x, y] position of lines[i].
 * @param {string[]} lines - Item texts, used as hover text.
 * @param {number[]} labels - Cluster index of each line.
 * @param {string[]} names - Cluster names; one trace is created per name.
 * @returns {object[]} Trace objects (x/y/text arrays plus mode, type, name, marker).
 */
function buildTraces(proj, lines, labels, names) {
  const traces = names.map((name, c) => ({
    x: [], y: [], text: [],
    mode: "markers", type: "scatter",
    name,
    marker: { color: CLUSTER_COLORS[c % CLUSTER_COLORS.length], size: 12, line: { width: 1, color: "#333" } },
  }));
  lines.forEach((line, i) => {
    const trace = traces[labels[i]];
    trace.x.push(proj[i][0]);
    trace.y.push(proj[i][1]);
    trace.text.push(line);
  });
  return traces;
}

/**
 * Embed the textarea lines, project them to 2-D with UMAP and plot one scatter trace per cluster.
 *
 * Reads the `#input` textarea, drives `#progress-bar` / `#progress-bar-inner`,
 * calls `plotScatter(traces, clusterCount)`, stores the traces on `window.traces`,
 * and rewrites the textarea with a `## Name` header per cluster. The textarea is
 * rewritten only if it was not edited while the embeddings were being computed.
 *
 * If there are fewer than two lines, or the embedding count does not match the
 * line count, nothing is plotted. In that case, and when an error is thrown,
 * the progress bar is hidden instead of being left stuck; errors still propagate.
 */
export async function handleClusterPlotEvent() {
  const progressBar = document.getElementById("progress-bar");
  const progressBarInner = document.getElementById("progress-bar-inner");
  progressBar.style.display = "block";
  progressBarInner.style.width = "0%";
  let completed = false;
  try {
    const input = document.getElementById("input");
    const text = input.value;
    const { lines, labels, names, grouped } = parseClusters(text);
    // A projection of fewer than two points is meaningless; skip the embedding work entirely.
    if (lines.length < 2) {
      return;
    }
    const embeddings = await getLineEmbeddings(lines, task);
    // One embedding per line is required to map projected points back to lines and labels.
    if (embeddings.length !== lines.length) {
      return;
    }

    // UMAP projection. umap-js takes its metric via `distanceFn` (it has no `metric` option).
    const { UMAP } = await import('https://cdn.jsdelivr.net/npm/umap-js@1.4.0/+esm');
    const nNeighbors = Math.max(1, Math.min(lines.length - 1, 15));
    const umap = new UMAP({ nComponents: 2, nNeighbors, minDist: 0.2, distanceFn: cosineDistance });
    const proj = umap.fit(embeddings);

    const traces = buildTraces(proj, lines, labels, names);
    plotScatter(traces, names.length);
    window.traces = traces;

    // Normalise the textarea into "## Name" sections, unless the user edited it during the awaits above.
    if (input.value === text) {
      input.value = grouped
        .map((items, i) => `## ${names[i]}\n${items.join("\n")}`)
        .join("\n\n\n");
    }
    progressBarInner.style.width = "100%";
    completed = true;
  } finally {
    if (!completed) {
      progressBar.style.display = "none";
    }
  }
}
|