mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-11 21:13:02 +00:00
Implement Copilot sign in and sign out
This commit is contained in:
parent
797bb7d780
commit
59d9277a74
4 changed files with 223 additions and 48 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1340,6 +1340,7 @@ dependencies = [
|
|||
"client",
|
||||
"futures 0.3.25",
|
||||
"gpui",
|
||||
"log",
|
||||
"lsp",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
|
|
|
@ -17,6 +17,7 @@ client = { path = "../client" }
|
|||
workspace = { path = "../workspace" }
|
||||
async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] }
|
||||
anyhow = "1.0"
|
||||
log = "0.4"
|
||||
serde = { workspace = true }
|
||||
serde_derive = { workspace = true }
|
||||
smol = "1.2.5"
|
||||
|
|
|
@ -3,7 +3,7 @@ mod request;
|
|||
use anyhow::{anyhow, Result};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use client::Client;
|
||||
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext};
|
||||
use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
|
||||
use lsp::LanguageServer;
|
||||
use smol::{fs, io::BufReader, stream::StreamExt};
|
||||
use std::{
|
||||
|
@ -15,11 +15,32 @@ use util::{
|
|||
fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt,
|
||||
};
|
||||
|
||||
actions!(copilot, [SignIn]);
|
||||
actions!(copilot, [SignIn, SignOut]);
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut MutableAppContext) {
|
||||
let copilot = cx.add_model(|cx| Copilot::start(client.http_client(), cx));
|
||||
let (copilot, task) = Copilot::start(client.http_client(), cx);
|
||||
cx.set_global(copilot);
|
||||
cx.spawn(|mut cx| async move {
|
||||
task.await?;
|
||||
cx.update(|cx| {
|
||||
cx.add_global_action(|_: &SignIn, cx: &mut MutableAppContext| {
|
||||
if let Some(copilot) = Copilot::global(cx) {
|
||||
copilot
|
||||
.update(cx, |copilot, cx| copilot.sign_in(cx))
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
});
|
||||
cx.add_global_action(|_: &SignOut, cx: &mut MutableAppContext| {
|
||||
if let Some(copilot) = Copilot::global(cx) {
|
||||
copilot
|
||||
.update(cx, |copilot, cx| copilot.sign_out(cx))
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
});
|
||||
});
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
enum CopilotServer {
|
||||
|
@ -31,18 +52,26 @@ enum CopilotServer {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
enum SignInStatus {
|
||||
Authorized,
|
||||
Unauthorized,
|
||||
Authorized { user: String },
|
||||
Unauthorized { user: String },
|
||||
SignedOut,
|
||||
}
|
||||
|
||||
pub enum Event {
|
||||
PromptUserDeviceFlow {
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
},
|
||||
}
|
||||
|
||||
struct Copilot {
|
||||
server: CopilotServer,
|
||||
}
|
||||
|
||||
impl Entity for Copilot {
|
||||
type Event = ();
|
||||
type Event = Event;
|
||||
}
|
||||
|
||||
impl Copilot {
|
||||
|
@ -54,46 +83,123 @@ impl Copilot {
|
|||
}
|
||||
}
|
||||
|
||||
fn start(http: Arc<dyn HttpClient>, cx: &mut ModelContext<Self>) -> Self {
|
||||
let copilot = Self {
|
||||
fn start(
|
||||
http: Arc<dyn HttpClient>,
|
||||
cx: &mut MutableAppContext,
|
||||
) -> (ModelHandle<Self>, Task<Result<()>>) {
|
||||
let this = cx.add_model(|_| Self {
|
||||
server: CopilotServer::Downloading,
|
||||
};
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let start_language_server = async {
|
||||
let server_path = get_lsp_binary(http).await?;
|
||||
let server =
|
||||
LanguageServer::new(0, &server_path, &["--stdio"], Path::new("/"), cx.clone())?;
|
||||
let server = server.initialize(Default::default()).await?;
|
||||
let status = server
|
||||
.request::<request::CheckStatus>(request::CheckStatusParams {
|
||||
local_checks_only: false,
|
||||
})
|
||||
.await?;
|
||||
let status = match status.status.as_str() {
|
||||
"OK" | "MaybeOk" => SignInStatus::Authorized,
|
||||
"NotAuthorized" => SignInStatus::Unauthorized,
|
||||
_ => SignInStatus::SignedOut,
|
||||
});
|
||||
let task = cx.spawn({
|
||||
let this = this.clone();
|
||||
|mut cx| async move {
|
||||
let start_language_server = async {
|
||||
let server_path = get_lsp_binary(http).await?;
|
||||
let server = LanguageServer::new(
|
||||
0,
|
||||
&server_path,
|
||||
&["--stdio"],
|
||||
Path::new("/"),
|
||||
cx.clone(),
|
||||
)?;
|
||||
let server = server.initialize(Default::default()).await?;
|
||||
let status = server
|
||||
.request::<request::CheckStatus>(request::CheckStatusParams {
|
||||
local_checks_only: false,
|
||||
})
|
||||
.await?;
|
||||
anyhow::Ok((server, status))
|
||||
};
|
||||
anyhow::Ok((server, status))
|
||||
};
|
||||
|
||||
let server = start_language_server.await;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
cx.notify();
|
||||
match server {
|
||||
Ok((server, status)) => {
|
||||
this.server = CopilotServer::Started { server, status };
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
this.server = CopilotServer::Error(error.to_string());
|
||||
Err(error)
|
||||
let server = start_language_server.await;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
cx.notify();
|
||||
match server {
|
||||
Ok((server, status)) => {
|
||||
this.server = CopilotServer::Started {
|
||||
server,
|
||||
status: SignInStatus::SignedOut,
|
||||
};
|
||||
this.update_sign_in_status(status, cx);
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
this.server = CopilotServer::Error(error.to_string());
|
||||
Err(error)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
});
|
||||
(this, task)
|
||||
}
|
||||
|
||||
fn sign_in(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
if let CopilotServer::Started { server, .. } = &self.server {
|
||||
let server = server.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
let sign_in = server
|
||||
.request::<request::SignInInitiate>(request::SignInInitiateParams {})
|
||||
.await?;
|
||||
if let request::SignInInitiateResult::PromptUserDeviceFlow(flow) = sign_in {
|
||||
this.update(&mut cx, |_, cx| {
|
||||
cx.emit(Event::PromptUserDeviceFlow {
|
||||
user_code: flow.user_code.clone(),
|
||||
verification_uri: flow.verification_uri,
|
||||
});
|
||||
});
|
||||
let response = server
|
||||
.request::<request::SignInConfirm>(request::SignInConfirmParams {
|
||||
user_code: flow.user_code,
|
||||
})
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| this.update_sign_in_status(response, cx));
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
copilot
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("copilot hasn't started yet")))
|
||||
}
|
||||
}
|
||||
|
||||
fn sign_out(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
|
||||
if let CopilotServer::Started { server, .. } = &self.server {
|
||||
let server = server.clone();
|
||||
cx.spawn(|this, mut cx| async move {
|
||||
server
|
||||
.request::<request::SignOut>(request::SignOutParams {})
|
||||
.await?;
|
||||
this.update(&mut cx, |this, cx| {
|
||||
if let CopilotServer::Started { status, .. } = &mut this.server {
|
||||
*status = SignInStatus::SignedOut;
|
||||
cx.notify();
|
||||
}
|
||||
});
|
||||
|
||||
anyhow::Ok(())
|
||||
})
|
||||
} else {
|
||||
Task::ready(Err(anyhow!("copilot hasn't started yet")))
|
||||
}
|
||||
}
|
||||
|
||||
fn update_sign_in_status(
|
||||
&mut self,
|
||||
lsp_status: request::SignInStatus,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
if let CopilotServer::Started { status, .. } = &mut self.server {
|
||||
*status = match lsp_status {
|
||||
request::SignInStatus::Ok { user } | request::SignInStatus::MaybeOk { user } => {
|
||||
SignInStatus::Authorized { user }
|
||||
}
|
||||
request::SignInStatus::NotAuthorized { user } => {
|
||||
SignInStatus::Unauthorized { user }
|
||||
}
|
||||
_ => SignInStatus::SignedOut,
|
||||
};
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,15 +8,82 @@ pub struct CheckStatusParams {
|
|||
pub local_checks_only: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CheckStatusResult {
|
||||
pub status: String,
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
impl lsp::request::Request for CheckStatus {
|
||||
type Params = CheckStatusParams;
|
||||
type Result = CheckStatusResult;
|
||||
type Result = SignInStatus;
|
||||
const METHOD: &'static str = "checkStatus";
|
||||
}
|
||||
|
||||
pub enum SignInInitiate {}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SignInInitiateParams {}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "status")]
|
||||
pub enum SignInInitiateResult {
|
||||
AlreadySignedIn { user: String },
|
||||
PromptUserDeviceFlow(PromptUserDeviceFlow),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptUserDeviceFlow {
|
||||
pub user_code: String,
|
||||
pub verification_uri: String,
|
||||
}
|
||||
|
||||
impl lsp::request::Request for SignInInitiate {
|
||||
type Params = SignInInitiateParams;
|
||||
type Result = SignInInitiateResult;
|
||||
const METHOD: &'static str = "signInInitiate";
|
||||
}
|
||||
|
||||
pub enum SignInConfirm {}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SignInConfirmParams {
|
||||
pub user_code: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "status")]
|
||||
pub enum SignInStatus {
|
||||
#[serde(rename = "OK")]
|
||||
Ok {
|
||||
user: String,
|
||||
},
|
||||
MaybeOk {
|
||||
user: String,
|
||||
},
|
||||
AlreadySignedIn {
|
||||
user: String,
|
||||
},
|
||||
NotAuthorized {
|
||||
user: String,
|
||||
},
|
||||
NotSignedIn,
|
||||
}
|
||||
|
||||
impl lsp::request::Request for SignInConfirm {
|
||||
type Params = SignInConfirmParams;
|
||||
type Result = SignInStatus;
|
||||
const METHOD: &'static str = "signInConfirm";
|
||||
}
|
||||
|
||||
pub enum SignOut {}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SignOutParams {}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SignOutResult {}
|
||||
|
||||
impl lsp::request::Request for SignOut {
|
||||
type Params = SignOutParams;
|
||||
type Result = SignOutResult;
|
||||
const METHOD: &'static str = "signOut";
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue