mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-10 20:41:59 +00:00
Parallel vector db (#2792)
Parallelize Vector Database calls for project semantic search. Release Notes: (Preview-only) - Parallelize Vector database calls for project semantic search. Cuts query time by 2/3rds. - Removed default keymap for old semantic search modal.
This commit is contained in:
commit
ee66f99ce6
3 changed files with 84 additions and 46 deletions
|
@ -411,7 +411,6 @@
|
|||
"cmd-k cmd-t": "theme_selector::Toggle",
|
||||
"cmd-k cmd-s": "zed::OpenKeymap",
|
||||
"cmd-t": "project_symbols::Toggle",
|
||||
"cmd-ctrl-t": "semantic_search::Toggle",
|
||||
"cmd-p": "file_finder::Toggle",
|
||||
"cmd-shift-p": "command_palette::Toggle",
|
||||
"cmd-shift-m": "diagnostics::Deploy",
|
||||
|
|
|
@ -267,41 +267,32 @@ impl VectorDatabase {
|
|||
|
||||
pub fn top_k_search(
|
||||
&self,
|
||||
worktree_ids: &[i64],
|
||||
query_embedding: &Vec<f32>,
|
||||
limit: usize,
|
||||
include_globs: Vec<GlobMatcher>,
|
||||
exclude_globs: Vec<GlobMatcher>,
|
||||
) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
file_ids: &[i64],
|
||||
) -> Result<Vec<(i64, f32)>> {
|
||||
let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
|
||||
self.for_each_document(
|
||||
&worktree_ids,
|
||||
include_globs,
|
||||
exclude_globs,
|
||||
|id, embedding| {
|
||||
let similarity = dot(&embedding, &query_embedding);
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
},
|
||||
)?;
|
||||
self.for_each_document(file_ids, |id, embedding| {
|
||||
let similarity = dot(&embedding, &query_embedding);
|
||||
let ix = match results
|
||||
.binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
|
||||
{
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
})?;
|
||||
|
||||
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
|
||||
self.get_documents_by_ids(&ids)
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn for_each_document(
|
||||
pub fn retrieve_included_file_ids(
|
||||
&self,
|
||||
worktree_ids: &[i64],
|
||||
include_globs: Vec<GlobMatcher>,
|
||||
exclude_globs: Vec<GlobMatcher>,
|
||||
mut f: impl FnMut(i64, Vec<f32>),
|
||||
) -> Result<()> {
|
||||
) -> Result<Vec<i64>> {
|
||||
let mut file_query = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
|
@ -315,6 +306,7 @@ impl VectorDatabase {
|
|||
|
||||
let mut file_ids = Vec::<i64>::new();
|
||||
let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
|
||||
|
||||
while let Some(row) = rows.next()? {
|
||||
let file_id = row.get(0)?;
|
||||
let relative_path = row.get_ref(1)?.as_str()?;
|
||||
|
@ -330,6 +322,10 @@ impl VectorDatabase {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(file_ids)
|
||||
}
|
||||
|
||||
fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
|
||||
let mut query_statement = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
|
@ -350,7 +346,7 @@ impl VectorDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
|
||||
let mut statement = self.db.prepare(
|
||||
"
|
||||
SELECT
|
||||
|
|
|
@ -20,6 +20,7 @@ use postage::watch;
|
|||
use project::{Fs, Project, WorktreeId};
|
||||
use smol::channel;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::HashMap,
|
||||
mem,
|
||||
ops::Range,
|
||||
|
@ -704,27 +705,69 @@ impl SemanticIndex {
|
|||
let database_url = self.database_url.clone();
|
||||
let fs = self.fs.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let documents = cx
|
||||
.background()
|
||||
.spawn(async move {
|
||||
let database = VectorDatabase::new(fs, database_url).await?;
|
||||
let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
|
||||
|
||||
let phrase_embedding = embedding_provider
|
||||
.embed_batch(vec![&phrase])
|
||||
.await?
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
let phrase_embedding = embedding_provider
|
||||
.embed_batch(vec![&phrase])
|
||||
.await?
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
|
||||
database.top_k_search(
|
||||
&worktree_db_ids,
|
||||
&phrase_embedding,
|
||||
limit,
|
||||
include_globs,
|
||||
exclude_globs,
|
||||
)
|
||||
})
|
||||
.await?;
|
||||
let file_ids = database.retrieve_included_file_ids(
|
||||
&worktree_db_ids,
|
||||
include_globs,
|
||||
exclude_globs,
|
||||
)?;
|
||||
|
||||
let batch_n = cx.background().num_cpus();
|
||||
let ids_len = file_ids.clone().len();
|
||||
let batch_size = if ids_len <= batch_n {
|
||||
ids_len
|
||||
} else {
|
||||
ids_len / batch_n
|
||||
};
|
||||
|
||||
let mut result_tasks = Vec::new();
|
||||
for batch in file_ids.chunks(batch_size) {
|
||||
let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
|
||||
let limit = limit.clone();
|
||||
let fs = fs.clone();
|
||||
let database_url = database_url.clone();
|
||||
let phrase_embedding = phrase_embedding.clone();
|
||||
let task = cx.background().spawn(async move {
|
||||
let database = VectorDatabase::new(fs, database_url).await.log_err();
|
||||
if database.is_none() {
|
||||
return Err(anyhow!("failed to acquire database connection"));
|
||||
} else {
|
||||
database
|
||||
.unwrap()
|
||||
.top_k_search(&phrase_embedding, limit, batch.as_slice())
|
||||
}
|
||||
});
|
||||
result_tasks.push(task);
|
||||
}
|
||||
|
||||
let batch_results = futures::future::join_all(result_tasks).await;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for batch_result in batch_results {
|
||||
if batch_result.is_ok() {
|
||||
for (id, similarity) in batch_result.unwrap() {
|
||||
let ix = match results.binary_search_by(|(_, s)| {
|
||||
similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
|
||||
}) {
|
||||
Ok(ix) => ix,
|
||||
Err(ix) => ix,
|
||||
};
|
||||
results.insert(ix, (id, similarity));
|
||||
results.truncate(limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
|
||||
let documents = database.get_documents_by_ids(ids.as_slice())?;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
let mut ranges = Vec::new();
|
||||
|
|
Loading…
Reference in a new issue