mirror of
https://github.com/zed-industries/zed.git
synced 2024-11-24 06:19:37 +00:00
Refactor: Restructure collab main function to prepare for new subcommand: serve llm
(#15824)
This is just a refactor that we're landing ahead of any functional changes to make sure we haven't broken anything. Release Notes: - N/A Co-authored-by: Marshall <marshall@zed.dev> Co-authored-by: Jason <jason@zed.dev>
This commit is contained in:
parent
705f7e7a03
commit
27779e33fb
5 changed files with 142 additions and 85 deletions
2
Procfile
2
Procfile
|
@ -1,3 +1,3 @@
|
||||||
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve
|
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
|
||||||
livekit: livekit-server --dev
|
livekit: livekit-server --dev
|
||||||
blob_store: ./script/run-local-minio
|
blob_store: ./script/run-local-minio
|
||||||
|
|
|
@ -61,7 +61,7 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
|
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/user", get(get_authenticated_user))
|
.route("/user", get(get_authenticated_user))
|
||||||
.route("/users/:id/access_tokens", post(create_access_token))
|
.route("/users/:id/access_tokens", post(create_access_token))
|
||||||
|
@ -70,7 +70,6 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
|
||||||
.merge(contributors::router())
|
.merge(contributors::router())
|
||||||
.layer(
|
.layer(
|
||||||
ServiceBuilder::new()
|
ServiceBuilder::new()
|
||||||
.layer(Extension(state))
|
|
||||||
.layer(Extension(rpc_server))
|
.layer(Extension(rpc_server))
|
||||||
.layer(middleware::from_fn(validate_api_token)),
|
.layer(middleware::from_fn(validate_api_token)),
|
||||||
)
|
)
|
||||||
|
@ -152,12 +151,8 @@ struct CreateUserParams {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_rpc_server_snapshot(
|
async fn get_rpc_server_snapshot(
|
||||||
Extension(rpc_server): Extension<Option<Arc<rpc::Server>>>,
|
Extension(rpc_server): Extension<Arc<rpc::Server>>,
|
||||||
) -> Result<ErasedJson> {
|
) -> Result<ErasedJson> {
|
||||||
let Some(rpc_server) = rpc_server else {
|
|
||||||
return Err(Error::Internal(anyhow!("rpc server is not available")));
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ErasedJson::pretty(rpc_server.snapshot().await))
|
Ok(ErasedJson::pretty(rpc_server.snapshot().await))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ pub mod auth;
|
||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod env;
|
pub mod env;
|
||||||
pub mod executor;
|
pub mod executor;
|
||||||
|
pub mod llm;
|
||||||
mod rate_limiter;
|
mod rate_limiter;
|
||||||
pub mod rpc;
|
pub mod rpc;
|
||||||
pub mod seed;
|
pub mod seed;
|
||||||
|
@ -124,7 +125,7 @@ impl std::fmt::Display for Error {
|
||||||
|
|
||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Clone, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub http_port: u16,
|
pub http_port: u16,
|
||||||
pub database_url: String,
|
pub database_url: String,
|
||||||
|
@ -176,6 +177,29 @@ impl Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The service mode that collab should run in.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum ServiceMode {
|
||||||
|
Api,
|
||||||
|
Collab,
|
||||||
|
Llm,
|
||||||
|
All,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServiceMode {
|
||||||
|
pub fn is_collab(&self) -> bool {
|
||||||
|
matches!(self, Self::Collab | Self::All)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_api(&self) -> bool {
|
||||||
|
matches!(self, Self::Api | Self::All)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_llm(&self) -> bool {
|
||||||
|
matches!(self, Self::Llm | Self::All)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub db: Arc<Database>,
|
pub db: Arc<Database>,
|
||||||
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
|
||||||
|
|
16
crates/collab/src/llm.rs
Normal file
16
crates/collab/src/llm.rs
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::{executor::Executor, Config, Result};
|
||||||
|
|
||||||
|
pub struct LlmState {
|
||||||
|
pub config: Config,
|
||||||
|
pub executor: Executor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlmState {
|
||||||
|
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||||
|
let this = Self { config, executor };
|
||||||
|
|
||||||
|
Ok(Arc::new(this))
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,7 +5,7 @@ use axum::{
|
||||||
routing::get,
|
routing::get,
|
||||||
Extension, Router,
|
Extension, Router,
|
||||||
};
|
};
|
||||||
use collab::api::billing::poll_stripe_events_periodically;
|
use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
|
||||||
use collab::{
|
use collab::{
|
||||||
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
|
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
|
||||||
rpc::ResultExt, AppState, Config, RateLimiter, Result,
|
rpc::ResultExt, AppState, Config, RateLimiter, Result,
|
||||||
|
@ -56,88 +56,99 @@ async fn main() -> Result<()> {
|
||||||
collab::seed::seed(&config, &db, true).await?;
|
collab::seed::seed(&config, &db, true).await?;
|
||||||
}
|
}
|
||||||
Some("serve") => {
|
Some("serve") => {
|
||||||
let (is_api, is_collab) = if let Some(next) = args.next() {
|
let mode = match args.next().as_deref() {
|
||||||
(next == "api", next == "collab")
|
Some("collab") => ServiceMode::Collab,
|
||||||
} else {
|
Some("api") => ServiceMode::Api,
|
||||||
(true, true)
|
Some("llm") => ServiceMode::Llm,
|
||||||
|
Some("all") => ServiceMode::All,
|
||||||
|
_ => {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
|
||||||
|
))?;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
if !is_api && !is_collab {
|
|
||||||
Err(anyhow!(
|
|
||||||
"usage: collab <version | migrate | seed | serve [api|collab]>"
|
|
||||||
))?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let config = envy::from_env::<Config>().expect("error loading config");
|
let config = envy::from_env::<Config>().expect("error loading config");
|
||||||
init_tracing(&config);
|
init_tracing(&config);
|
||||||
|
let mut app = Router::new()
|
||||||
|
.route("/", get(handle_root))
|
||||||
|
.route("/healthz", get(handle_liveness_probe))
|
||||||
|
.layer(Extension(mode));
|
||||||
|
|
||||||
run_migrations(&config).await?;
|
let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
|
||||||
|
|
||||||
let state = AppState::new(config, Executor::Production).await?;
|
|
||||||
|
|
||||||
let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
|
|
||||||
.expect("failed to bind TCP listener");
|
.expect("failed to bind TCP listener");
|
||||||
|
|
||||||
let rpc_server = if is_collab {
|
let mut on_shutdown = None;
|
||||||
let epoch = state
|
|
||||||
.db
|
|
||||||
.create_server(&state.config.zed_environment)
|
|
||||||
.await?;
|
|
||||||
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
|
||||||
rpc_server.start().await?;
|
|
||||||
|
|
||||||
Some(rpc_server)
|
if mode.is_llm() {
|
||||||
} else {
|
let state = LlmState::new(config.clone(), Executor::Production).await?;
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
if is_collab {
|
app = app.layer(Extension(state.clone()));
|
||||||
state.db.purge_old_embeddings().await.trace_err();
|
|
||||||
RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_api {
|
if mode.is_collab() || mode.is_api() {
|
||||||
poll_stripe_events_periodically(state.clone());
|
run_migrations(&config).await?;
|
||||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut app = collab::api::routes(rpc_server.clone(), state.clone());
|
let state = AppState::new(config, Executor::Production).await?;
|
||||||
if let Some(rpc_server) = rpc_server.clone() {
|
|
||||||
app = app.merge(collab::rpc::routes(rpc_server))
|
if mode.is_collab() {
|
||||||
}
|
state.db.purge_old_embeddings().await.trace_err();
|
||||||
app = app
|
RateLimiter::save_periodically(
|
||||||
.merge(
|
state.rate_limiter.clone(),
|
||||||
Router::new()
|
state.executor.clone(),
|
||||||
.route("/", get(handle_root))
|
);
|
||||||
.route("/healthz", get(handle_liveness_probe))
|
|
||||||
.merge(collab::api::extensions::router())
|
let epoch = state
|
||||||
|
.db
|
||||||
|
.create_server(&state.config.zed_environment)
|
||||||
|
.await?;
|
||||||
|
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
|
||||||
|
rpc_server.start().await?;
|
||||||
|
|
||||||
|
app = app
|
||||||
|
.merge(collab::api::routes(rpc_server.clone()))
|
||||||
|
.merge(collab::rpc::routes(rpc_server.clone()));
|
||||||
|
|
||||||
|
on_shutdown = Some(Box::new(move || rpc_server.teardown()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if mode.is_api() {
|
||||||
|
poll_stripe_events_periodically(state.clone());
|
||||||
|
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||||
|
|
||||||
|
app = app
|
||||||
.merge(collab::api::events::router())
|
.merge(collab::api::events::router())
|
||||||
.layer(Extension(state.clone())),
|
.merge(collab::api::extensions::router())
|
||||||
)
|
}
|
||||||
.layer(
|
|
||||||
TraceLayer::new_for_http()
|
|
||||||
.make_span_with(|request: &Request<_>| {
|
|
||||||
let matched_path = request
|
|
||||||
.extensions()
|
|
||||||
.get::<MatchedPath>()
|
|
||||||
.map(MatchedPath::as_str);
|
|
||||||
|
|
||||||
tracing::info_span!(
|
app = app.layer(Extension(state.clone()));
|
||||||
"http_request",
|
}
|
||||||
method = ?request.method(),
|
|
||||||
matched_path,
|
app = app.layer(
|
||||||
)
|
TraceLayer::new_for_http()
|
||||||
})
|
.make_span_with(|request: &Request<_>| {
|
||||||
.on_response(
|
let matched_path = request
|
||||||
|response: &Response<_>, latency: Duration, _: &tracing::Span| {
|
.extensions()
|
||||||
let duration_ms = latency.as_micros() as f64 / 1000.;
|
.get::<MatchedPath>()
|
||||||
tracing::info!(
|
.map(MatchedPath::as_str);
|
||||||
duration_ms,
|
|
||||||
status = response.status().as_u16(),
|
tracing::info_span!(
|
||||||
"finished processing request"
|
"http_request",
|
||||||
);
|
method = ?request.method(),
|
||||||
},
|
matched_path,
|
||||||
),
|
)
|
||||||
);
|
})
|
||||||
|
.on_response(
|
||||||
|
|response: &Response<_>, latency: Duration, _: &tracing::Span| {
|
||||||
|
let duration_ms = latency.as_micros() as f64 / 1000.;
|
||||||
|
tracing::info!(
|
||||||
|
duration_ms,
|
||||||
|
status = response.status().as_u16(),
|
||||||
|
"finished processing request"
|
||||||
|
);
|
||||||
|
},
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
let signal = async move {
|
let signal = async move {
|
||||||
|
@ -174,8 +185,8 @@ async fn main() -> Result<()> {
|
||||||
signal.await;
|
signal.await;
|
||||||
tracing::info!("Received interrupt signal");
|
tracing::info!("Received interrupt signal");
|
||||||
|
|
||||||
if let Some(rpc_server) = rpc_server {
|
if let Some(on_shutdown) = on_shutdown {
|
||||||
rpc_server.teardown();
|
on_shutdown();
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
@ -183,7 +194,7 @@ async fn main() -> Result<()> {
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
Err(anyhow!(
|
Err(anyhow!(
|
||||||
"usage: collab <version | migrate | seed | serve [api|collab]>"
|
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
|
||||||
))?;
|
))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -222,12 +233,23 @@ async fn run_migrations(config: &Config) -> Result<()> {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_root() -> String {
|
async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
|
||||||
format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
|
format!(
|
||||||
|
"collab {mode:?} v{VERSION} ({})",
|
||||||
|
REVISION.unwrap_or("unknown")
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
|
async fn handle_liveness_probe(
|
||||||
state.db.get_all_users(0, 1).await?;
|
app_state: Option<Extension<Arc<AppState>>>,
|
||||||
|
llm_state: Option<Extension<Arc<LlmState>>>,
|
||||||
|
) -> Result<String> {
|
||||||
|
if let Some(state) = app_state {
|
||||||
|
state.db.get_all_users(0, 1).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(_llm_state) = llm_state {}
|
||||||
|
|
||||||
Ok("ok".to_string())
|
Ok("ok".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue