diff --git a/Cargo.lock b/Cargo.lock index 0044279ed..fd12e85ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" @@ -216,6 +225,21 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "backtrace" +version = "0.3.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.21.7" @@ -311,11 +335,11 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.83" +version = "1.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "58e804ac3194a48bb129643eb1d62fcc20d18c6b8c181704489353d13120bcd1" dependencies = [ - "libc", + "shlex", ] [[package]] @@ -816,6 +840,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + [[package]] name = "glob" version = "0.3.1" @@ -1250,6 +1280,15 @@ dependencies = [ "libc", ] +[[package]] +name = "object" +version = "0.36.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -1447,6 +1486,8 @@ dependencies = [ "sqlx", "text-size", "threadpool", + "tokio", + "tokio-util", ] [[package]] @@ -1879,6 +1920,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -2452,6 +2499,41 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokio" +version = "1.40.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +dependencies = [ + "backtrace", + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.40" diff --git a/crates/pg_lsp/Cargo.toml b/crates/pg_lsp/Cargo.toml index 122e3ccdf..9d23bce9b 100644 --- a/crates/pg_lsp/Cargo.toml +++ b/crates/pg_lsp/Cargo.toml @@ -32,6 +32,8 @@ pg_base_db.workspace = true pg_schema_cache.workspace = true pg_workspace.workspace = true pg_diagnostics.workspace = true +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "sync"] } +tokio-util = "0.7.12" [dev-dependencies] diff --git a/crates/pg_lsp/src/client.rs b/crates/pg_lsp/src/client.rs index 85ff5a61e..851769425 100644 --- a/crates/pg_lsp/src/client.rs +++ b/crates/pg_lsp/src/client.rs @@ -50,6 +50,14 @@ impl LspClient { Ok(()) } + /// This will ignore any errors that occur while sending the notification. + pub fn send_info_notification(&self, message: &str) { + let _ = self.send_notification::(ShowMessageParams { + message: message.into(), + typ: MessageType::INFO, + }); + } + pub fn send_request(&self, params: R::Params) -> Result where R: lsp_types::request::Request, diff --git a/crates/pg_lsp/src/client/client_flags.rs b/crates/pg_lsp/src/client/client_flags.rs index 8fca812d3..6209f443c 100644 --- a/crates/pg_lsp/src/client/client_flags.rs +++ b/crates/pg_lsp/src/client/client_flags.rs @@ -1,10 +1,36 @@ +use lsp_types::InitializeParams; + /// Contains information about the client's capabilities. /// This is used to determine which features the server can use. #[derive(Debug, Clone)] pub struct ClientFlags { - /// If `true`, the server can pull the configuration from the client. - pub configuration_pull: bool, + /// If `true`, the server can pull configuration from the client. + pub has_configuration: bool, + + /// If `true`, the client notifies the server when its configuration changes. + pub will_push_configuration: bool, +} + +impl ClientFlags { + pub(crate) fn from_initialize_request_params(params: &InitializeParams) -> Self { + let has_configuration = params + .capabilities + .workspace + .as_ref() + .and_then(|w| w.configuration) + .unwrap_or(false); + + let will_push_configuration = params + .capabilities + .workspace + .as_ref() + .and_then(|w| w.did_change_configuration) + .and_then(|c| c.dynamic_registration) + .unwrap_or(false); - /// If `true`, the client notifies the server when the configuration changes. - pub configuration_push: bool, + Self { + has_configuration, + will_push_configuration, + } + } } diff --git a/crates/pg_lsp/src/db_connection.rs b/crates/pg_lsp/src/db_connection.rs new file mode 100644 index 000000000..51ba633dd --- /dev/null +++ b/crates/pg_lsp/src/db_connection.rs @@ -0,0 +1,66 @@ +use pg_schema_cache::SchemaCache; +use sqlx::{postgres::PgListener, PgPool}; +use tokio::task::JoinHandle; + +#[derive(Debug)] +pub(crate) struct DbConnection { + pub pool: PgPool, + connection_string: String, + schema_update_handle: Option>, +} + +impl DbConnection { + pub(crate) async fn new(connection_string: String) -> Result { + let pool = PgPool::connect(&connection_string).await?; + Ok(Self { + pool, + connection_string: connection_string, + schema_update_handle: None, + }) + } + + pub(crate) fn connected_to(&self, connection_string: &str) -> bool { + connection_string == self.connection_string + } + + pub(crate) async fn close(self) { + if self.schema_update_handle.is_some() { + self.schema_update_handle.unwrap().abort(); + } + self.pool.close().await; + } + + pub(crate) async fn listen_for_schema_updates( + &mut self, + on_schema_update: F, + ) -> anyhow::Result<()> + where + F: Fn(SchemaCache) -> () + Send + 'static, + { + let mut listener = PgListener::connect_with(&self.pool).await?; + listener.listen_all(["postgres_lsp", "pgrst"]).await?; + + let pool = self.pool.clone(); + + let handle: JoinHandle<()> = tokio::spawn(async move { + loop { + match listener.recv().await { + Ok(not) => { + if not.payload().to_string() == "reload schema" { + let schema_cache = SchemaCache::load(&pool).await; + on_schema_update(schema_cache); + }; + } + Err(why) => { + eprintln!("Error receiving notification: {:?}", why); + break; + } + } + } + }); + + self.schema_update_handle = Some(handle); + + Ok(()) + } +} diff --git a/crates/pg_lsp/src/lib.rs b/crates/pg_lsp/src/lib.rs index ac95c9133..97474d524 100644 --- a/crates/pg_lsp/src/lib.rs +++ b/crates/pg_lsp/src/lib.rs @@ -1,3 +1,4 @@ mod client; +mod db_connection; pub mod server; mod utils; diff --git a/crates/pg_lsp/src/main.rs b/crates/pg_lsp/src/main.rs index eb5eddb62..9c678fac5 100644 --- a/crates/pg_lsp/src/main.rs +++ b/crates/pg_lsp/src/main.rs @@ -1,9 +1,12 @@ use lsp_server::Connection; use pg_lsp::server::Server; -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { let (connection, threads) = Connection::stdio(); - Server::init(connection)?; + + let server = Server::init(connection)?; + server.run().await?; threads.join()?; Ok(()) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 927d7f168..65d2a3543 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -2,8 +2,6 @@ mod debouncer; mod dispatch; pub mod options; -use async_std::task::{self}; -use crossbeam_channel::{unbounded, Receiver, Sender}; use lsp_server::{Connection, ErrorCode, Message, RequestId}; use lsp_types::{ notification::{ @@ -29,51 +27,69 @@ use pg_hover::HoverParams; use pg_schema_cache::SchemaCache; use pg_workspace::Workspace; use serde::{de::DeserializeOwned, Serialize}; -use std::{collections::HashSet, sync::Arc, time::Duration}; +use std::{collections::HashSet, future::Future, sync::Arc, time::Duration}; use text_size::TextSize; -use threadpool::ThreadPool; + +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; use crate::{ client::{client_flags::ClientFlags, LspClient}, + db_connection::DbConnection, utils::{file_path, from_proto, line_index_ext::LineIndexExt, normalize_uri, to_proto}, }; use self::{debouncer::EventDebouncer, options::Options}; -use sqlx::{ - postgres::{PgListener, PgPool}, - Executor, -}; +use sqlx::{postgres::PgPool, Executor}; #[derive(Debug)] enum InternalMessage { PublishDiagnostics(lsp_types::Url), SetOptions(Options), - RefreshSchemaCache, SetSchemaCache(SchemaCache), + SetDatabaseConnection(DbConnection), } -#[derive(Debug)] -struct DbConnection { - pub pool: PgPool, - connection_string: String, -} +/// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime. +/// For now, we move it into a separate task and use tokio's channels to communicate. +fn get_client_receiver( + connection: Connection, + cancel_token: Arc, +) -> mpsc::UnboundedReceiver { + let (message_tx, message_rx) = mpsc::unbounded_channel(); -impl DbConnection { - pub async fn new(connection_string: &str) -> Result { - let pool = PgPool::connect(connection_string).await?; - Ok(Self { - pool, - connection_string: connection_string.to_owned(), - }) - } + tokio::task::spawn(async move { + loop { + let msg = match connection.receiver.recv() { + Ok(msg) => msg, + Err(e) => { + eprint!("Connection was closed by LSP client: {}", e); + cancel_token.cancel(); + return; + } + }; + + match msg { + Message::Request(r) if connection.handle_shutdown(&r).unwrap() => { + cancel_token.cancel(); + return; + } + + // any non-shutdown request is forwarded to the server + _ => message_tx.send(msg).unwrap(), + }; + } + }); + + message_rx } pub struct Server { - connection: Arc, + client_rx: mpsc::UnboundedReceiver, + cancel_token: Arc, client: LspClient, - internal_tx: Sender, - internal_rx: Receiver, - pool: Arc, + internal_tx: mpsc::UnboundedSender, + internal_rx: mpsc::UnboundedReceiver, client_flags: Arc, ide: Arc, db_conn: Option, @@ -81,41 +97,29 @@ pub struct Server { } impl Server { - pub fn init(connection: Connection) -> anyhow::Result<()> { + pub fn init(connection: Connection) -> anyhow::Result { let client = LspClient::new(connection.sender.clone()); + let cancel_token = Arc::new(CancellationToken::new()); - let (internal_tx, internal_rx) = unbounded(); + let (client_flags, client_rx) = Self::establish_client_connection(connection, &cancel_token)?; - let (id, params) = connection.initialize_start()?; - let params: InitializeParams = serde_json::from_value(params)?; - - let result = InitializeResult { - capabilities: Self::capabilities(), - server_info: Some(ServerInfo { - name: "Postgres LSP".to_owned(), - version: Some(env!("CARGO_PKG_VERSION").to_owned()), - }), - }; - - connection.initialize_finish(id, serde_json::to_value(result)?)?; - - let client_flags = Arc::new(from_proto::client_flags(params.capabilities)); - - let pool = Arc::new(threadpool::Builder::new().build()); let ide = Arc::new(Workspace::new()); + let (internal_tx, internal_rx) = mpsc::unbounded_channel(); + let cloned_tx = internal_tx.clone(); let cloned_ide = ide.clone(); - let cloned_pool = pool.clone(); + let pool = Arc::new(threadpool::Builder::new().build()); let cloned_client = client.clone(); let server = Self { - connection: Arc::new(connection), + cancel_token, + client_rx, internal_rx, internal_tx, client, - client_flags, + client_flags: Arc::new(client_flags), db_conn: None, ide, compute_debouncer: EventDebouncer::new( @@ -124,7 +128,7 @@ impl Server { let inner_cloned_ide = cloned_ide.clone(); let inner_cloned_tx = cloned_tx.clone(); let inner_cloned_client = cloned_client.clone(); - cloned_pool.execute(move || { + pool.execute(move || { inner_cloned_client .send_notification::(ShowMessageParams { typ: lsp_types::MessageType::INFO, @@ -155,11 +159,9 @@ impl Server { }); }, ), - pool, }; - server.run()?; - Ok(()) + Ok(server) } fn compute_now(&self) { @@ -170,7 +172,7 @@ impl Server { self.compute_debouncer.clear(); - self.pool.execute(move || { + self.spawn_with_cancel(async move { client .send_notification::(ShowMessageParams { typ: lsp_types::MessageType::INFO, @@ -209,71 +211,58 @@ impl Server { }); } - fn start_listening(&self) { - if self.db_conn.is_none() { - return; + fn update_db_connection(&self, options: Options) -> anyhow::Result<()> { + if options.db_connection_string.is_none() + || self + .db_conn + .as_ref() + // if the connection is already connected to the same database, do nothing + .is_some_and(|c| c.connected_to(options.db_connection_string.as_ref().unwrap())) + { + return Ok(()); } - let pool = self.db_conn.as_ref().unwrap().pool.clone(); - let tx = self.internal_tx.clone(); + let connection_string = options.db_connection_string.unwrap(); - task::spawn(async move { - let mut listener = PgListener::connect_with(&pool).await.unwrap(); - listener - .listen_all(["postgres_lsp", "pgrst"]) - .await - .unwrap(); - - loop { - match listener.recv().await { - Ok(notification) => { - if notification.payload().to_string() == "reload schema" { - tx.send(InternalMessage::RefreshSchemaCache).unwrap(); - } - } - Err(e) => { - eprintln!("Listener error: {}", e); - break; - } + let internal_tx = self.internal_tx.clone(); + let client = self.client.clone(); + self.spawn_with_cancel(async move { + match DbConnection::new(connection_string.into()).await { + Ok(conn) => { + internal_tx + .send(InternalMessage::SetDatabaseConnection(conn)) + .unwrap(); + } + Err(why) => { + client.send_info_notification(&format!( + "Unable to update database connection: {}", + why + )); } } }); - } - - async fn update_db_connection(&mut self, connection_string: Option) { - if connection_string == self.db_conn.as_ref().map(|c| c.connection_string.clone()) { - return; - } - if let Some(conn) = self.db_conn.take() { - conn.pool.close().await; - } - - if connection_string.is_none() { - return; - } - let new_conn = DbConnection::new(connection_string.unwrap().as_str()).await; + Ok(()) + } - if new_conn.is_err() { - return; + async fn listen_for_schema_updates(&mut self) -> anyhow::Result<()> { + if self.db_conn.is_none() { + eprintln!("Error trying to listen for schema updates: No database connection"); + return Ok(()); } - self.db_conn = Some(new_conn.unwrap()); - - self.client - .send_notification::(ShowMessageParams { - typ: lsp_types::MessageType::INFO, - message: "Connection to database established".to_string(), + let internal_tx = self.internal_tx.clone(); + self.db_conn + .as_mut() + .unwrap() + .listen_for_schema_updates(move |schema_cache| { + internal_tx + .send(InternalMessage::SetSchemaCache(schema_cache)) + .expect("LSP Server: Failed to send internal message."); }) - .unwrap(); - - self.refresh_schema_cache(); + .await?; - self.start_listening(); - } - - fn update_options(&mut self, options: Options) { - async_std::task::block_on(self.update_db_connection(options.db_connection_string)); + Ok(()) } fn capabilities() -> ServerCapabilities { @@ -687,15 +676,17 @@ impl Server { Q: FnOnce() -> anyhow::Result + Send + 'static, { let client = self.client.clone(); - self.pool.execute(move || match query() { - Ok(result) => { - let response = lsp_server::Response::new_ok(id, result); - client.send_response(response).unwrap(); - } - Err(why) => { - client - .send_error(id, ErrorCode::InternalError, why.to_string()) - .unwrap(); + self.spawn_with_cancel(async move { + match query() { + Ok(result) => { + let response = lsp_server::Response::new_ok(id, result); + client.send_response(response).unwrap(); + } + Err(why) => { + client + .send_error(id, ErrorCode::InternalError, why.to_string()) + .unwrap(); + } } }); } @@ -721,31 +712,11 @@ impl Server { let client = self.client.clone(); let ide = Arc::clone(&self.ide); - self.pool.execute(move || { + self.spawn_with_cancel(async move { let response = lsp_server::Response::new_ok(id, query(&ide)); - client.send_response(response).unwrap(); - }); - } - - fn refresh_schema_cache(&self) { - if self.db_conn.is_none() { - return; - } - - let tx = self.internal_tx.clone(); - let conn = self.db_conn.as_ref().unwrap().pool.clone(); - let client = self.client.clone(); - - async_std::task::spawn(async move { client - .send_notification::(ShowMessageParams { - typ: lsp_types::MessageType::INFO, - message: "Refreshing schema cache...".to_string(), - }) - .unwrap(); - let schema_cache = SchemaCache::load(&conn).await; - tx.send(InternalMessage::SetSchemaCache(schema_cache)) - .unwrap(); + .send_response(response) + .expect("Failed to send query to client"); }); } @@ -753,84 +724,103 @@ impl Server { &mut self, params: DidChangeConfigurationParams, ) -> anyhow::Result<()> { - if self.client_flags.configuration_pull { + if self.client_flags.has_configuration { self.pull_options(); } else { let options = self.client.parse_options(params.settings)?; - self.update_options(options); + self.update_db_connection(options)?; } Ok(()) } - fn process_messages(&mut self) -> anyhow::Result<()> { + async fn process_messages(&mut self) -> anyhow::Result<()> { loop { - crossbeam_channel::select! { - recv(&self.connection.receiver) -> msg => { - match msg? { - Message::Request(request) => { - if self.connection.handle_shutdown(&request)? { - return Ok(()); - } - - if let Some(response) = dispatch::RequestDispatcher::new(request) - .on::(|id, params| self.inlay_hint(id, params))? - .on::(|id, params| self.hover(id, params))? - .on::(|id, params| self.execute_command(id, params))? - .on::(|id, params| { - self.completion(id, params) - })? - .on::(|id, params| { - self.code_actions(id, params) - })? - .default() - { - self.client.send_response(response)?; - } - } - Message::Notification(notification) => { - dispatch::NotificationDispatcher::new(notification) - .on::(|params| { - self.did_change_configuration(params) - })? - .on::(|params| self.did_close(params))? - .on::(|params| self.did_open(params))? - .on::(|params| self.did_change(params))? - .on::(|params| self.did_save(params))? - .on::(|params| self.did_close(params))? - .default(); - } - Message::Response(response) => { - self.client.recv_response(response)?; - } - }; + tokio::select! { + _ = self.cancel_token.cancelled() => { + // Close the loop, proceed to shutdown. + return Ok(()) }, - recv(&self.internal_rx) -> msg => { - match msg? { - InternalMessage::SetSchemaCache(c) => { - self.ide.set_schema_cache(c); - self.compute_now(); - } - InternalMessage::RefreshSchemaCache => { - self.refresh_schema_cache(); - } - InternalMessage::PublishDiagnostics(uri) => { - self.publish_diagnostics(uri)?; - } - InternalMessage::SetOptions(options) => { - self.update_options(options); - } - }; + + msg = self.internal_rx.recv() => { + match msg { + None => panic!("The LSP's internal sender closed. This should never happen."), + Some(m) => self.handle_internal_message(m).await + } + }, + + msg = self.client_rx.recv() => { + match msg { + None => panic!("The LSP's client closed, but not via an 'exit' method. This should never happen."), + Some(m) => self.handle_message(m).await + } + }, + }?; + } + } + + async fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> { + match msg { + Message::Request(request) => { + if let Some(response) = dispatch::RequestDispatcher::new(request) + .on::(|id, params| self.inlay_hint(id, params))? + .on::(|id, params| self.hover(id, params))? + .on::(|id, params| self.execute_command(id, params))? + .on::(|id, params| self.completion(id, params))? + .on::(|id, params| self.code_actions(id, params))? + .default() + { + self.client.send_response(response)?; } - }; + } + Message::Notification(notification) => { + dispatch::NotificationDispatcher::new(notification) + .on::(|params| { + self.did_change_configuration(params) + })? + .on::(|params| self.did_close(params))? + .on::(|params| self.did_open(params))? + .on::(|params| self.did_change(params))? + .on::(|params| self.did_save(params))? + .on::(|params| self.did_close(params))? + .default(); + } + Message::Response(response) => { + self.client.recv_response(response)?; + } } + + Ok(()) } - fn pull_options(&mut self) { - if !self.client_flags.configuration_pull { - return; + async fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> { + match msg { + InternalMessage::SetSchemaCache(c) => { + self.client + .send_info_notification("Refreshing Schema Cache..."); + self.ide.set_schema_cache(c); + self.client.send_info_notification("Updated Schema Cache."); + self.compute_now(); + } + InternalMessage::PublishDiagnostics(uri) => { + self.publish_diagnostics(uri)?; + } + InternalMessage::SetOptions(options) => { + self.update_db_connection(options)?; + } + InternalMessage::SetDatabaseConnection(conn) => { + let current = self.db_conn.replace(conn); + if current.is_some() { + current.unwrap().close().await + } + self.listen_for_schema_updates().await?; + } } + Ok(()) + } + + fn pull_options(&mut self) { let params = ConfigurationParams { items: vec![ConfigurationItem { section: Some("postgres_lsp".to_string()), @@ -839,53 +829,97 @@ impl Server { }; let client = self.client.clone(); - let sender = self.internal_tx.clone(); - self.pool.execute(move || { + let internal_tx = self.internal_tx.clone(); + self.spawn_with_cancel(async move { match client.send_request::(params) { Ok(mut json) => { let options = client .parse_options(json.pop().expect("invalid configuration request")) .unwrap(); - sender.send(InternalMessage::SetOptions(options)).unwrap(); + if let Err(why) = internal_tx.send(InternalMessage::SetOptions(options)) { + println!("Failed to set internal options: {}", why); + } } - Err(_why) => { - // log::error!("Retrieving configuration failed: {}", why); + Err(why) => { + println!("Retrieving configuration failed: {}", why); } }; }); } fn register_configuration(&mut self) { - if self.client_flags.configuration_push { - let registration = Registration { - id: "pull-config".to_string(), - method: DidChangeConfiguration::METHOD.to_string(), - register_options: None, - }; + let registration = Registration { + id: "pull-config".to_string(), + method: DidChangeConfiguration::METHOD.to_string(), + register_options: None, + }; - let params = RegistrationParams { - registrations: vec![registration], - }; + let params = RegistrationParams { + registrations: vec![registration], + }; - let client = self.client.clone(); - self.pool.execute(move || { - if let Err(_why) = client.send_request::(params) { - // log::error!( - // "Failed to register \"{}\" notification: {}", - // DidChangeConfiguration::METHOD, - // why - // ); - } - }); - } + let client = self.client.clone(); + self.spawn_with_cancel(async move { + if let Err(why) = client.send_request::(params) { + println!( + "Failed to register \"{}\" notification: {}", + DidChangeConfiguration::METHOD, + why + ); + } + }); } - pub fn run(mut self) -> anyhow::Result<()> { - self.register_configuration(); - self.pull_options(); - self.process_messages()?; - self.pool.join(); - Ok(()) + fn establish_client_connection( + connection: Connection, + cancel_token: &Arc, + ) -> anyhow::Result<(ClientFlags, mpsc::UnboundedReceiver)> { + let (id, params) = connection.initialize_start()?; + + let params: InitializeParams = serde_json::from_value(params)?; + + let result = InitializeResult { + capabilities: Self::capabilities(), + server_info: Some(ServerInfo { + name: "Postgres LSP".to_owned(), + version: Some(env!("CARGO_PKG_VERSION").to_owned()), + }), + }; + + connection.initialize_finish(id, serde_json::to_value(result)?)?; + + let client_rx = get_client_receiver(connection, cancel_token.clone()); + + let client_flags = ClientFlags::from_initialize_request_params(¶ms); + + Ok((client_flags, client_rx)) + } + + /// Spawns an asynchronous task that can be cancelled with the `Server`'s `cancel_token`. + fn spawn_with_cancel(&self, f: F) -> tokio::task::JoinHandle> + where + F: Future + Send + 'static, + O: Send + 'static, + { + let cancel_token = self.cancel_token.clone(); + tokio::spawn(async move { + tokio::select! { + _ = cancel_token.cancelled() => None, + output = f => Some(output) + } + }) + } + + pub async fn run(mut self) -> anyhow::Result<()> { + if self.client_flags.will_push_configuration { + self.register_configuration(); + } + + if self.client_flags.has_configuration { + self.pull_options(); + } + + self.process_messages().await } } diff --git a/crates/pg_lsp/src/server/debouncer/thread.rs b/crates/pg_lsp/src/server/debouncer/thread.rs index 1aa85939c..a7486f216 100644 --- a/crates/pg_lsp/src/server/debouncer/thread.rs +++ b/crates/pg_lsp/src/server/debouncer/thread.rs @@ -8,7 +8,6 @@ use super::buffer::{EventBuffer, Get, State}; struct DebouncerThread { mutex: Arc>, thread: JoinHandle<()>, - stopped: Arc, } impl DebouncerThread { @@ -33,16 +32,7 @@ impl DebouncerThread { } } }); - Self { - mutex, - thread, - stopped, - } - } - - fn stop(self) -> JoinHandle<()> { - self.stopped.store(true, Ordering::Relaxed); - self.thread + Self { mutex, thread } } } @@ -68,13 +58,6 @@ impl EventDebouncer { pub fn clear(&self) { self.0.mutex.lock().unwrap().clear(); } - - /// Signals the debouncer thread to quit and returns a - /// [std::thread::JoinHandle] which can be `.join()`ed in the consumer - /// thread. The common idiom is: `debouncer.stop().join().unwrap();` - pub fn stop(self) -> JoinHandle<()> { - self.0.stop() - } } #[cfg(test)] diff --git a/crates/pg_lsp/src/utils/from_proto.rs b/crates/pg_lsp/src/utils/from_proto.rs index 47708be71..eaae06ceb 100644 --- a/crates/pg_lsp/src/utils/from_proto.rs +++ b/crates/pg_lsp/src/utils/from_proto.rs @@ -1,5 +1,3 @@ -use crate::client::client_flags::ClientFlags; - use super::line_index_ext::LineIndexExt; use pg_base_db::{Change, Document}; @@ -17,23 +15,3 @@ pub fn content_changes( }) .collect() } - -pub fn client_flags(capabilities: lsp_types::ClientCapabilities) -> ClientFlags { - let configuration_pull = capabilities - .workspace - .as_ref() - .and_then(|cap| cap.configuration) - .unwrap_or(false); - - let configuration_push = capabilities - .workspace - .as_ref() - .and_then(|cap| cap.did_change_configuration) - .and_then(|cap| cap.dynamic_registration) - .unwrap_or(false); - - ClientFlags { - configuration_pull, - configuration_push, - } -} diff --git a/xtask/src/install.rs b/xtask/src/install.rs index 85c03e13d..c149bd5a3 100644 --- a/xtask/src/install.rs +++ b/xtask/src/install.rs @@ -137,10 +137,7 @@ fn install_client(sh: &Shell, client_opt: ClientOpt) -> anyhow::Result<()> { } fn install_server(sh: &Shell) -> anyhow::Result<()> { - let cmd = cmd!( - sh, - "cargo install --path crates/pg_lsp --locked --force" - ); + let cmd = cmd!(sh, "cargo install --path crates/pg_lsp --locked --force"); cmd.run()?; Ok(()) }