Refactor: Restructure collab main function to prepare for new subcommand: serve llm (#15824)

This is just a refactor that we're landing ahead of any functional
changes to make sure we haven't broken anything.

Release Notes:

- N/A

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Jason <jason@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-08-05 12:07:38 -07:00 committed by GitHub
parent 705f7e7a03
commit 27779e33fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 142 additions and 85 deletions

View file

@ -1,3 +1,3 @@
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
livekit: livekit-server --dev livekit: livekit-server --dev
blob_store: ./script/run-local-minio blob_store: ./script/run-local-minio

View file

@ -61,7 +61,7 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
} }
} }
pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> { pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new() Router::new()
.route("/user", get(get_authenticated_user)) .route("/user", get(get_authenticated_user))
.route("/users/:id/access_tokens", post(create_access_token)) .route("/users/:id/access_tokens", post(create_access_token))
@ -70,7 +70,6 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
.merge(contributors::router()) .merge(contributors::router())
.layer( .layer(
ServiceBuilder::new() ServiceBuilder::new()
.layer(Extension(state))
.layer(Extension(rpc_server)) .layer(Extension(rpc_server))
.layer(middleware::from_fn(validate_api_token)), .layer(middleware::from_fn(validate_api_token)),
) )
@ -152,12 +151,8 @@ struct CreateUserParams {
} }
async fn get_rpc_server_snapshot( async fn get_rpc_server_snapshot(
Extension(rpc_server): Extension<Option<Arc<rpc::Server>>>, Extension(rpc_server): Extension<Arc<rpc::Server>>,
) -> Result<ErasedJson> { ) -> Result<ErasedJson> {
let Some(rpc_server) = rpc_server else {
return Err(Error::Internal(anyhow!("rpc server is not available")));
};
Ok(ErasedJson::pretty(rpc_server.snapshot().await)) Ok(ErasedJson::pretty(rpc_server.snapshot().await))
} }

View file

@ -3,6 +3,7 @@ pub mod auth;
pub mod db; pub mod db;
pub mod env; pub mod env;
pub mod executor; pub mod executor;
pub mod llm;
mod rate_limiter; mod rate_limiter;
pub mod rpc; pub mod rpc;
pub mod seed; pub mod seed;
@ -124,7 +125,7 @@ impl std::fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
#[derive(Deserialize)] #[derive(Clone, Deserialize)]
pub struct Config { pub struct Config {
pub http_port: u16, pub http_port: u16,
pub database_url: String, pub database_url: String,
@ -176,6 +177,29 @@ impl Config {
} }
} }
/// The service mode that collab should run in.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ServiceMode {
Api,
Collab,
Llm,
All,
}
impl ServiceMode {
pub fn is_collab(&self) -> bool {
matches!(self, Self::Collab | Self::All)
}
pub fn is_api(&self) -> bool {
matches!(self, Self::Api | Self::All)
}
pub fn is_llm(&self) -> bool {
matches!(self, Self::Llm | Self::All)
}
}
pub struct AppState { pub struct AppState {
pub db: Arc<Database>, pub db: Arc<Database>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>, pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,

16
crates/collab/src/llm.rs Normal file
View file

@ -0,0 +1,16 @@
use std::sync::Arc;
use crate::{executor::Executor, Config, Result};
pub struct LlmState {
pub config: Config,
pub executor: Executor,
}
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let this = Self { config, executor };
Ok(Arc::new(this))
}
}

View file

@ -5,7 +5,7 @@ use axum::{
routing::get, routing::get,
Extension, Router, Extension, Router,
}; };
use collab::api::billing::poll_stripe_events_periodically; use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
use collab::{ use collab::{
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
rpc::ResultExt, AppState, Config, RateLimiter, Result, rpc::ResultExt, AppState, Config, RateLimiter, Result,
@ -56,88 +56,99 @@ async fn main() -> Result<()> {
collab::seed::seed(&config, &db, true).await?; collab::seed::seed(&config, &db, true).await?;
} }
Some("serve") => { Some("serve") => {
let (is_api, is_collab) = if let Some(next) = args.next() { let mode = match args.next().as_deref() {
(next == "api", next == "collab") Some("collab") => ServiceMode::Collab,
} else { Some("api") => ServiceMode::Api,
(true, true) Some("llm") => ServiceMode::Llm,
Some("all") => ServiceMode::All,
_ => {
return Err(anyhow!(
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
))?;
}
}; };
if !is_api && !is_collab {
Err(anyhow!(
"usage: collab <version | migrate | seed | serve [api|collab]>"
))?;
}
let config = envy::from_env::<Config>().expect("error loading config"); let config = envy::from_env::<Config>().expect("error loading config");
init_tracing(&config); init_tracing(&config);
let mut app = Router::new()
.route("/", get(handle_root))
.route("/healthz", get(handle_liveness_probe))
.layer(Extension(mode));
run_migrations(&config).await?; let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
let state = AppState::new(config, Executor::Production).await?;
let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
.expect("failed to bind TCP listener"); .expect("failed to bind TCP listener");
let rpc_server = if is_collab { let mut on_shutdown = None;
let epoch = state
.db
.create_server(&state.config.zed_environment)
.await?;
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
Some(rpc_server) if mode.is_llm() {
} else { let state = LlmState::new(config.clone(), Executor::Production).await?;
None
};
if is_collab { app = app.layer(Extension(state.clone()));
state.db.purge_old_embeddings().await.trace_err();
RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
} }
if is_api { if mode.is_collab() || mode.is_api() {
poll_stripe_events_periodically(state.clone()); run_migrations(&config).await?;
fetch_extensions_from_blob_store_periodically(state.clone());
}
let mut app = collab::api::routes(rpc_server.clone(), state.clone()); let state = AppState::new(config, Executor::Production).await?;
if let Some(rpc_server) = rpc_server.clone() {
app = app.merge(collab::rpc::routes(rpc_server)) if mode.is_collab() {
} state.db.purge_old_embeddings().await.trace_err();
app = app RateLimiter::save_periodically(
.merge( state.rate_limiter.clone(),
Router::new() state.executor.clone(),
.route("/", get(handle_root)) );
.route("/healthz", get(handle_liveness_probe))
.merge(collab::api::extensions::router()) let epoch = state
.db
.create_server(&state.config.zed_environment)
.await?;
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
app = app
.merge(collab::api::routes(rpc_server.clone()))
.merge(collab::rpc::routes(rpc_server.clone()));
on_shutdown = Some(Box::new(move || rpc_server.teardown()));
}
if mode.is_api() {
poll_stripe_events_periodically(state.clone());
fetch_extensions_from_blob_store_periodically(state.clone());
app = app
.merge(collab::api::events::router()) .merge(collab::api::events::router())
.layer(Extension(state.clone())), .merge(collab::api::extensions::router())
) }
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
let matched_path = request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str);
tracing::info_span!( app = app.layer(Extension(state.clone()));
"http_request", }
method = ?request.method(),
matched_path, app = app.layer(
) TraceLayer::new_for_http()
}) .make_span_with(|request: &Request<_>| {
.on_response( let matched_path = request
|response: &Response<_>, latency: Duration, _: &tracing::Span| { .extensions()
let duration_ms = latency.as_micros() as f64 / 1000.; .get::<MatchedPath>()
tracing::info!( .map(MatchedPath::as_str);
duration_ms,
status = response.status().as_u16(), tracing::info_span!(
"finished processing request" "http_request",
); method = ?request.method(),
}, matched_path,
), )
); })
.on_response(
|response: &Response<_>, latency: Duration, _: &tracing::Span| {
let duration_ms = latency.as_micros() as f64 / 1000.;
tracing::info!(
duration_ms,
status = response.status().as_u16(),
"finished processing request"
);
},
),
);
#[cfg(unix)] #[cfg(unix)]
let signal = async move { let signal = async move {
@ -174,8 +185,8 @@ async fn main() -> Result<()> {
signal.await; signal.await;
tracing::info!("Received interrupt signal"); tracing::info!("Received interrupt signal");
if let Some(rpc_server) = rpc_server { if let Some(on_shutdown) = on_shutdown {
rpc_server.teardown(); on_shutdown();
} }
}) })
.await .await
@ -183,7 +194,7 @@ async fn main() -> Result<()> {
} }
_ => { _ => {
Err(anyhow!( Err(anyhow!(
"usage: collab <version | migrate | seed | serve [api|collab]>" "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
))?; ))?;
} }
} }
@ -222,12 +233,23 @@ async fn run_migrations(config: &Config) -> Result<()> {
return Ok(()); return Ok(());
} }
async fn handle_root() -> String { async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown")) format!(
"collab {mode:?} v{VERSION} ({})",
REVISION.unwrap_or("unknown")
)
} }
async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> { async fn handle_liveness_probe(
state.db.get_all_users(0, 1).await?; app_state: Option<Extension<Arc<AppState>>>,
llm_state: Option<Extension<Arc<LlmState>>>,
) -> Result<String> {
if let Some(state) = app_state {
state.db.get_all_users(0, 1).await?;
}
if let Some(_llm_state) = llm_state {}
Ok("ok".to_string()) Ok("ok".to_string())
} }