1
Fork 0

Auto merge of #16226 - Veykril:lsp-server, r=Veykril

internal: Expose whether a channel has been dropped in lsp-server errors

Not the best way to expose this, but this should allow us to give somewhat better errors when the initialization request is malformed, as currently that just results in a channel disconnected error instead of the deserialization error. cc https://github.com/rust-lang/rust-analyzer/issues/15859
This commit is contained in:
bors 2024-01-01 13:13:38 +00:00
commit a8d935eedc
9 changed files with 79 additions and 41 deletions

14
Cargo.lock generated
View file

@ -945,24 +945,24 @@ checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
[[package]] [[package]]
name = "lsp-server" name = "lsp-server"
version = "0.7.4" version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b52dccdf3302eefab8c8a1273047f0a3c3dca4b527c8458d00c09484c8371928"
dependencies = [ dependencies = [
"crossbeam-channel", "crossbeam-channel",
"ctrlc",
"log", "log",
"lsp-types",
"serde", "serde",
"serde_json", "serde_json",
] ]
[[package]] [[package]]
name = "lsp-server" name = "lsp-server"
version = "0.7.5" version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "248f65b78f6db5d8e1b1604b4098a28b43d21a8eb1deeca22b1c421b276c7095"
dependencies = [ dependencies = [
"crossbeam-channel", "crossbeam-channel",
"ctrlc",
"log", "log",
"lsp-types",
"serde", "serde",
"serde_json", "serde_json",
] ]
@ -1526,7 +1526,7 @@ dependencies = [
"ide-ssr", "ide-ssr",
"itertools", "itertools",
"load-cargo", "load-cargo",
"lsp-server 0.7.4", "lsp-server 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)",
"lsp-types", "lsp-types",
"mbe", "mbe",
"mimalloc", "mimalloc",

View file

@ -88,7 +88,7 @@ test-utils = { path = "./crates/test-utils" }
# In-tree crates that are published separately and follow semver. See lib/README.md # In-tree crates that are published separately and follow semver. See lib/README.md
line-index = { version = "0.1.1" } line-index = { version = "0.1.1" }
la-arena = { version = "0.3.1" } la-arena = { version = "0.3.1" }
lsp-server = { version = "0.7.4" } lsp-server = { version = "0.7.6" }
# non-local crates # non-local crates
anyhow = "1.0.75" anyhow = "1.0.75"

View file

@ -172,7 +172,15 @@ fn run_server() -> anyhow::Result<()> {
let (connection, io_threads) = Connection::stdio(); let (connection, io_threads) = Connection::stdio();
let (initialize_id, initialize_params) = connection.initialize_start()?; let (initialize_id, initialize_params) = match connection.initialize_start() {
Ok(it) => it,
Err(e) => {
if e.channel_is_disconnected() {
io_threads.join()?;
}
return Err(e.into());
}
};
tracing::info!("InitializeParams: {}", initialize_params); tracing::info!("InitializeParams: {}", initialize_params);
let lsp_types::InitializeParams { let lsp_types::InitializeParams {
root_uri, root_uri,
@ -240,7 +248,12 @@ fn run_server() -> anyhow::Result<()> {
let initialize_result = serde_json::to_value(initialize_result).unwrap(); let initialize_result = serde_json::to_value(initialize_result).unwrap();
connection.initialize_finish(initialize_id, initialize_result)?; if let Err(e) = connection.initialize_finish(initialize_id, initialize_result) {
if e.channel_is_disconnected() {
io_threads.join()?;
}
return Err(e.into());
}
if !config.has_linked_projects() && config.detached_files().is_empty() { if !config.has_linked_projects() && config.detached_files().is_empty() {
config.rediscover_workspaces(); config.rediscover_workspaces();

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lsp-server" name = "lsp-server"
version = "0.7.5" version = "0.7.6"
description = "Generic LSP server scaffold." description = "Generic LSP server scaffold."
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/lsp-server" repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/lsp-server"
@ -10,7 +10,7 @@ edition = "2021"
log = "0.4.17" log = "0.4.17"
serde_json = "1.0.108" serde_json = "1.0.108"
serde = { version = "1.0.192", features = ["derive"] } serde = { version = "1.0.192", features = ["derive"] }
crossbeam-channel = "0.5.6" crossbeam-channel = "0.5.8"
[dev-dependencies] [dev-dependencies]
lsp-types = "=0.95" lsp-types = "=0.95"

View file

@ -64,7 +64,15 @@ fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
..Default::default() ..Default::default()
}) })
.unwrap(); .unwrap();
let initialization_params = connection.initialize(server_capabilities)?; let initialization_params = match connection.initialize(server_capabilities) {
Ok(it) => it,
Err(e) => {
if e.channel_is_disconnected() {
io_threads.join()?;
}
return Err(e.into());
}
};
main_loop(connection, initialization_params)?; main_loop(connection, initialization_params)?;
io_threads.join()?; io_threads.join()?;

View file

@ -3,7 +3,22 @@ use std::fmt;
use crate::{Notification, Request}; use crate::{Notification, Request};
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct ProtocolError(pub(crate) String); pub struct ProtocolError(String, bool);
impl ProtocolError {
pub(crate) fn new(msg: impl Into<String>) -> Self {
ProtocolError(msg.into(), false)
}
pub(crate) fn disconnected() -> ProtocolError {
ProtocolError("disconnected channel".into(), true)
}
/// Whether this error occured due to a disconnected channel.
pub fn channel_is_disconnected(&self) -> bool {
self.1
}
}
impl std::error::Error for ProtocolError {} impl std::error::Error for ProtocolError {}

View file

@ -17,7 +17,7 @@ use std::{
net::{TcpListener, TcpStream, ToSocketAddrs}, net::{TcpListener, TcpStream, ToSocketAddrs},
}; };
use crossbeam_channel::{Receiver, RecvTimeoutError, Sender}; use crossbeam_channel::{Receiver, RecvError, RecvTimeoutError, Sender};
pub use crate::{ pub use crate::{
error::{ExtractError, ProtocolError}, error::{ExtractError, ProtocolError},
@ -158,11 +158,7 @@ impl Connection {
Err(RecvTimeoutError::Timeout) => { Err(RecvTimeoutError::Timeout) => {
continue; continue;
} }
Err(e) => { Err(RecvTimeoutError::Disconnected) => return Err(ProtocolError::disconnected()),
return Err(ProtocolError(format!(
"expected initialize request, got error: {e}"
)))
}
}; };
match msg { match msg {
@ -181,12 +177,14 @@ impl Connection {
continue; continue;
} }
msg => { msg => {
return Err(ProtocolError(format!("expected initialize request, got {msg:?}"))); return Err(ProtocolError::new(format!(
"expected initialize request, got {msg:?}"
)));
} }
}; };
} }
return Err(ProtocolError(String::from( return Err(ProtocolError::new(String::from(
"Initialization has been aborted during initialization", "Initialization has been aborted during initialization",
))); )));
} }
@ -201,12 +199,10 @@ impl Connection {
self.sender.send(resp.into()).unwrap(); self.sender.send(resp.into()).unwrap();
match &self.receiver.recv() { match &self.receiver.recv() {
Ok(Message::Notification(n)) if n.is_initialized() => Ok(()), Ok(Message::Notification(n)) if n.is_initialized() => Ok(()),
Ok(msg) => { Ok(msg) => Err(ProtocolError::new(format!(
Err(ProtocolError(format!(r#"expected initialized notification, got: {msg:?}"#))) r#"expected initialized notification, got: {msg:?}"#
} ))),
Err(e) => { Err(RecvError) => Err(ProtocolError::disconnected()),
Err(ProtocolError(format!("expected initialized notification, got error: {e}",)))
}
} }
} }
@ -231,10 +227,8 @@ impl Connection {
Err(RecvTimeoutError::Timeout) => { Err(RecvTimeoutError::Timeout) => {
continue; continue;
} }
Err(e) => { Err(RecvTimeoutError::Disconnected) => {
return Err(ProtocolError(format!( return Err(ProtocolError::disconnected());
"expected initialized notification, got error: {e}",
)));
} }
}; };
@ -243,14 +237,14 @@ impl Connection {
return Ok(()); return Ok(());
} }
msg => { msg => {
return Err(ProtocolError(format!( return Err(ProtocolError::new(format!(
r#"expected initialized notification, got: {msg:?}"# r#"expected initialized notification, got: {msg:?}"#
))); )));
} }
} }
} }
return Err(ProtocolError(String::from( return Err(ProtocolError::new(String::from(
"Initialization has been aborted during initialization", "Initialization has been aborted during initialization",
))); )));
} }
@ -359,9 +353,18 @@ impl Connection {
match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) { match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) {
Ok(Message::Notification(n)) if n.is_exit() => (), Ok(Message::Notification(n)) if n.is_exit() => (),
Ok(msg) => { Ok(msg) => {
return Err(ProtocolError(format!("unexpected message during shutdown: {msg:?}"))) return Err(ProtocolError::new(format!(
"unexpected message during shutdown: {msg:?}"
)))
}
Err(RecvTimeoutError::Timeout) => {
return Err(ProtocolError::new(format!("timed out waiting for exit notification")))
}
Err(RecvTimeoutError::Disconnected) => {
return Err(ProtocolError::new(format!(
"channel disconnected waiting for exit notification"
)))
} }
Err(e) => return Err(ProtocolError(format!("unexpected error during shutdown: {e}"))),
} }
Ok(true) Ok(true)
} }
@ -426,7 +429,7 @@ mod tests {
initialize_start_test(TestCase { initialize_start_test(TestCase {
test_messages: vec![notification_msg.clone()], test_messages: vec![notification_msg.clone()],
expected_resp: Err(ProtocolError(format!( expected_resp: Err(ProtocolError::new(format!(
"expected initialize request, got {:?}", "expected initialize request, got {:?}",
notification_msg notification_msg
))), ))),

View file

@ -264,12 +264,12 @@ fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> {
let mut parts = buf.splitn(2, ": "); let mut parts = buf.splitn(2, ": ");
let header_name = parts.next().unwrap(); let header_name = parts.next().unwrap();
let header_value = let header_value =
parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; parts.next().ok_or_else(|| invalid_data(format!("malformed header: {:?}", buf)))?;
if header_name.eq_ignore_ascii_case("Content-Length") { if header_name.eq_ignore_ascii_case("Content-Length") {
size = Some(header_value.parse::<usize>().map_err(invalid_data)?); size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
} }
} }
let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; let size: usize = size.ok_or_else(|| invalid_data("no Content-Length".to_string()))?;
let mut buf = buf.into_bytes(); let mut buf = buf.into_bytes();
buf.resize(size, 0); buf.resize(size, 0);
inp.read_exact(&mut buf)?; inp.read_exact(&mut buf)?;

View file

@ -15,8 +15,7 @@ pub(crate) fn stdio_transport() -> (Sender<Message>, Receiver<Message>, IoThread
let writer = thread::spawn(move || { let writer = thread::spawn(move || {
let stdout = stdout(); let stdout = stdout();
let mut stdout = stdout.lock(); let mut stdout = stdout.lock();
writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))?; writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))
Ok(())
}); });
let (reader_sender, reader_receiver) = bounded::<Message>(0); let (reader_sender, reader_receiver) = bounded::<Message>(0);
let reader = thread::spawn(move || { let reader = thread::spawn(move || {