diff --git a/Cargo.lock b/Cargo.lock index 7310ecc8582..c7d110eafb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -945,24 +945,24 @@ checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" [[package]] name = "lsp-server" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b52dccdf3302eefab8c8a1273047f0a3c3dca4b527c8458d00c09484c8371928" +version = "0.7.6" dependencies = [ "crossbeam-channel", + "ctrlc", "log", + "lsp-types", "serde", "serde_json", ] [[package]] name = "lsp-server" -version = "0.7.5" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248f65b78f6db5d8e1b1604b4098a28b43d21a8eb1deeca22b1c421b276c7095" dependencies = [ "crossbeam-channel", - "ctrlc", "log", - "lsp-types", "serde", "serde_json", ] @@ -1526,7 +1526,7 @@ dependencies = [ "ide-ssr", "itertools", "load-cargo", - "lsp-server 0.7.4", + "lsp-server 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)", "lsp-types", "mbe", "mimalloc", diff --git a/Cargo.toml b/Cargo.toml index d4cff420bcb..e82a14d16e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,7 +88,7 @@ test-utils = { path = "./crates/test-utils" } # In-tree crates that are published separately and follow semver. See lib/README.md line-index = { version = "0.1.1" } la-arena = { version = "0.3.1" } -lsp-server = { version = "0.7.4" } +lsp-server = { version = "0.7.6" } # non-local crates anyhow = "1.0.75" diff --git a/crates/rust-analyzer/src/bin/main.rs b/crates/rust-analyzer/src/bin/main.rs index 8472e49de98..6f40a4c88ed 100644 --- a/crates/rust-analyzer/src/bin/main.rs +++ b/crates/rust-analyzer/src/bin/main.rs @@ -172,7 +172,15 @@ fn run_server() -> anyhow::Result<()> { let (connection, io_threads) = Connection::stdio(); - let (initialize_id, initialize_params) = connection.initialize_start()?; + let (initialize_id, initialize_params) = match connection.initialize_start() { + Ok(it) => it, + Err(e) => { + if e.channel_is_disconnected() { + io_threads.join()?; + } + return Err(e.into()); + } + }; tracing::info!("InitializeParams: {}", initialize_params); let lsp_types::InitializeParams { root_uri, @@ -240,7 +248,12 @@ fn run_server() -> anyhow::Result<()> { let initialize_result = serde_json::to_value(initialize_result).unwrap(); - connection.initialize_finish(initialize_id, initialize_result)?; + if let Err(e) = connection.initialize_finish(initialize_id, initialize_result) { + if e.channel_is_disconnected() { + io_threads.join()?; + } + return Err(e.into()); + } if !config.has_linked_projects() && config.detached_files().is_empty() { config.rediscover_workspaces(); diff --git a/lib/lsp-server/Cargo.toml b/lib/lsp-server/Cargo.toml index e802bf185b3..116b376b0b0 100644 --- a/lib/lsp-server/Cargo.toml +++ b/lib/lsp-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "lsp-server" -version = "0.7.5" +version = "0.7.6" description = "Generic LSP server scaffold." license = "MIT OR Apache-2.0" repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/lsp-server" @@ -10,7 +10,7 @@ edition = "2021" log = "0.4.17" serde_json = "1.0.108" serde = { version = "1.0.192", features = ["derive"] } -crossbeam-channel = "0.5.6" +crossbeam-channel = "0.5.8" [dev-dependencies] lsp-types = "=0.95" diff --git a/lib/lsp-server/examples/goto_def.rs b/lib/lsp-server/examples/goto_def.rs index 2f270afbbf1..71f66254069 100644 --- a/lib/lsp-server/examples/goto_def.rs +++ b/lib/lsp-server/examples/goto_def.rs @@ -64,7 +64,15 @@ fn main() -> Result<(), Box> { ..Default::default() }) .unwrap(); - let initialization_params = connection.initialize(server_capabilities)?; + let initialization_params = match connection.initialize(server_capabilities) { + Ok(it) => it, + Err(e) => { + if e.channel_is_disconnected() { + io_threads.join()?; + } + return Err(e.into()); + } + }; main_loop(connection, initialization_params)?; io_threads.join()?; diff --git a/lib/lsp-server/src/error.rs b/lib/lsp-server/src/error.rs index 755b3fd9596..ebdd153b5b3 100644 --- a/lib/lsp-server/src/error.rs +++ b/lib/lsp-server/src/error.rs @@ -3,7 +3,22 @@ use std::fmt; use crate::{Notification, Request}; #[derive(Debug, Clone, PartialEq)] -pub struct ProtocolError(pub(crate) String); +pub struct ProtocolError(String, bool); + +impl ProtocolError { + pub(crate) fn new(msg: impl Into) -> Self { + ProtocolError(msg.into(), false) + } + + pub(crate) fn disconnected() -> ProtocolError { + ProtocolError("disconnected channel".into(), true) + } + + /// Whether this error occured due to a disconnected channel. + pub fn channel_is_disconnected(&self) -> bool { + self.1 + } +} impl std::error::Error for ProtocolError {} diff --git a/lib/lsp-server/src/lib.rs b/lib/lsp-server/src/lib.rs index 2797a6b60de..6b732d47029 100644 --- a/lib/lsp-server/src/lib.rs +++ b/lib/lsp-server/src/lib.rs @@ -17,7 +17,7 @@ use std::{ net::{TcpListener, TcpStream, ToSocketAddrs}, }; -use crossbeam_channel::{Receiver, RecvTimeoutError, Sender}; +use crossbeam_channel::{Receiver, RecvError, RecvTimeoutError, Sender}; pub use crate::{ error::{ExtractError, ProtocolError}, @@ -158,11 +158,7 @@ impl Connection { Err(RecvTimeoutError::Timeout) => { continue; } - Err(e) => { - return Err(ProtocolError(format!( - "expected initialize request, got error: {e}" - ))) - } + Err(RecvTimeoutError::Disconnected) => return Err(ProtocolError::disconnected()), }; match msg { @@ -181,12 +177,14 @@ impl Connection { continue; } msg => { - return Err(ProtocolError(format!("expected initialize request, got {msg:?}"))); + return Err(ProtocolError::new(format!( + "expected initialize request, got {msg:?}" + ))); } }; } - return Err(ProtocolError(String::from( + return Err(ProtocolError::new(String::from( "Initialization has been aborted during initialization", ))); } @@ -201,12 +199,10 @@ impl Connection { self.sender.send(resp.into()).unwrap(); match &self.receiver.recv() { Ok(Message::Notification(n)) if n.is_initialized() => Ok(()), - Ok(msg) => { - Err(ProtocolError(format!(r#"expected initialized notification, got: {msg:?}"#))) - } - Err(e) => { - Err(ProtocolError(format!("expected initialized notification, got error: {e}",))) - } + Ok(msg) => Err(ProtocolError::new(format!( + r#"expected initialized notification, got: {msg:?}"# + ))), + Err(RecvError) => Err(ProtocolError::disconnected()), } } @@ -231,10 +227,8 @@ impl Connection { Err(RecvTimeoutError::Timeout) => { continue; } - Err(e) => { - return Err(ProtocolError(format!( - "expected initialized notification, got error: {e}", - ))); + Err(RecvTimeoutError::Disconnected) => { + return Err(ProtocolError::disconnected()); } }; @@ -243,14 +237,14 @@ impl Connection { return Ok(()); } msg => { - return Err(ProtocolError(format!( + return Err(ProtocolError::new(format!( r#"expected initialized notification, got: {msg:?}"# ))); } } } - return Err(ProtocolError(String::from( + return Err(ProtocolError::new(String::from( "Initialization has been aborted during initialization", ))); } @@ -359,9 +353,18 @@ impl Connection { match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) { Ok(Message::Notification(n)) if n.is_exit() => (), Ok(msg) => { - return Err(ProtocolError(format!("unexpected message during shutdown: {msg:?}"))) + return Err(ProtocolError::new(format!( + "unexpected message during shutdown: {msg:?}" + ))) + } + Err(RecvTimeoutError::Timeout) => { + return Err(ProtocolError::new(format!("timed out waiting for exit notification"))) + } + Err(RecvTimeoutError::Disconnected) => { + return Err(ProtocolError::new(format!( + "channel disconnected waiting for exit notification" + ))) } - Err(e) => return Err(ProtocolError(format!("unexpected error during shutdown: {e}"))), } Ok(true) } @@ -426,7 +429,7 @@ mod tests { initialize_start_test(TestCase { test_messages: vec![notification_msg.clone()], - expected_resp: Err(ProtocolError(format!( + expected_resp: Err(ProtocolError::new(format!( "expected initialize request, got {:?}", notification_msg ))), diff --git a/lib/lsp-server/src/msg.rs b/lib/lsp-server/src/msg.rs index 730ad51f424..ba318dd1690 100644 --- a/lib/lsp-server/src/msg.rs +++ b/lib/lsp-server/src/msg.rs @@ -264,12 +264,12 @@ fn read_msg_text(inp: &mut dyn BufRead) -> io::Result> { let mut parts = buf.splitn(2, ": "); let header_name = parts.next().unwrap(); let header_value = - parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; + parts.next().ok_or_else(|| invalid_data(format!("malformed header: {:?}", buf)))?; if header_name.eq_ignore_ascii_case("Content-Length") { size = Some(header_value.parse::().map_err(invalid_data)?); } } - let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; + let size: usize = size.ok_or_else(|| invalid_data("no Content-Length".to_string()))?; let mut buf = buf.into_bytes(); buf.resize(size, 0); inp.read_exact(&mut buf)?; diff --git a/lib/lsp-server/src/stdio.rs b/lib/lsp-server/src/stdio.rs index e487b9b4622..cea199d0293 100644 --- a/lib/lsp-server/src/stdio.rs +++ b/lib/lsp-server/src/stdio.rs @@ -15,8 +15,7 @@ pub(crate) fn stdio_transport() -> (Sender, Receiver, IoThread let writer = thread::spawn(move || { let stdout = stdout(); let mut stdout = stdout.lock(); - writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))?; - Ok(()) + writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout)) }); let (reader_sender, reader_receiver) = bounded::(0); let reader = thread::spawn(move || {