Skip to content

Commit d701388

Browse files
committed
refactor: env to args
1 parent 76034d1 commit d701388

File tree

1 file changed

+118
-111
lines changed

1 file changed

+118
-111
lines changed

src/main.rs

Lines changed: 118 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,23 @@ mod transport; // Declare transport module
99

1010
// Use necessary items from modules and crates
1111
use crate::{
12-
doc_loader::load_documents,
13-
embeddings::{generate_embeddings, OPENAI_CLIENT},
14-
embeddings::SerializableEmbedding,
12+
doc_loader::{Document}, // Import the Document struct only; load_documents is called via the doc_loader:: path
13+
embeddings::{generate_embeddings, SerializableEmbedding, OPENAI_CLIENT}, // Group imports
1514
error::ServerError,
1615
server::RustDocsServer,
17-
transport::StdioTransport, // Import StdioTransport
16+
transport::StdioTransport,
1817
};
1918
use async_openai::Client as OpenAIClient;
20-
use rmcp::{
21-
serve_server,
22-
transport::io::stdio, // Keep stdio function
23-
// Remove TransportAdapterAsyncCombinedRW import
24-
};
25-
use std::env;
19+
use bincode::config; // Keep config
2620
use ndarray::Array1;
27-
use std::fs::{self, File};
28-
use std::io::BufReader; // Removed unused BufWriter
29-
use std::path::PathBuf; // Removed unused Path
30-
use xdg::BaseDirectories;
31-
use bincode::{
32-
config,
33-
// serde::OwnedSerdeDecoder, // No longer needed
34-
// decode_from_reader, // Removed unused import
35-
// encode_to_vec, // Removed unused import
36-
// Encode, Decode, // No longer needed directly
21+
use rmcp::{serve_server, transport::io::stdio};
22+
use std::{
23+
env,
24+
fs::{self, File},
25+
io::BufReader,
26+
path::PathBuf,
3727
};
28+
use xdg::BaseDirectories;
3829

3930
#[tokio::main]
4031
async fn main() -> Result<(), ServerError> {
@@ -52,163 +43,179 @@ async fn main() -> Result<(), ServerError> {
5243
ServerError::MissingArgument("CRATE_VERSION".to_string())
5344
})?;
5445

55-
let _openai_api_key = env::var("OPENAI_API_KEY")
56-
.map_err(|_| ServerError::MissingEnvVar("OPENAI_API_KEY".to_string()))?; // Needed later
57-
58-
// Load documents by generating them dynamically
59-
println!("Loading documents for crate: {}", crate_name);
60-
let documents = load_documents(&crate_name, &crate_version)?; // Pass crate_name and crate_version
61-
println!("Loaded {} documents.", documents.len());
62-
63-
// Initialize OpenAI client and set it in the OnceLock
64-
let openai_client = OpenAIClient::new();
65-
OPENAI_CLIENT
66-
.set(openai_client.clone()) // Clone for generate_embeddings
67-
.expect("Failed to set OpenAI client");
68-
69-
// --- Persistence Logic ---
70-
// Use XDG Base Directory specification for data storage
46+
// --- Determine Paths ---
7147
let xdg_dirs = BaseDirectories::with_prefix("rustdocs-mcp-server")
72-
.map_err(|e| ServerError::Xdg(format!("Failed to get XDG directories: {}", e)))?; // Use the new Xdg variant
48+
.map_err(|e| ServerError::Xdg(format!("Failed to get XDG directories: {}", e)))?;
7349

74-
// Construct the path within the XDG data directory, including the crate name
75-
let relative_path = PathBuf::from(&crate_name).join("embeddings.bin");
50+
// Construct the path for embeddings file
51+
let embeddings_relative_path = PathBuf::from(&crate_name).join("embeddings.bin");
52+
let embeddings_file_path = xdg_dirs
53+
.place_data_file(embeddings_relative_path)
54+
.map_err(ServerError::Io)?;
7655

77-
// Use place_data_file to get the full path and ensure parent directories exist
78-
let embeddings_file_path = xdg_dirs.place_data_file(relative_path)
79-
.map_err(ServerError::Io)?; // Map IO error if directory creation fails
56+
// --- Try Loading Embeddings ---
57+
let mut loaded_from_cache = false;
58+
let mut loaded_embeddings: Option<Vec<(String, Array1<f32>)>> = None;
8059

81-
let embeddings = if embeddings_file_path.exists() {
82-
println!("Loading embeddings from: {:?}", embeddings_file_path);
60+
if embeddings_file_path.exists() {
61+
println!("Attempting to load embeddings from: {:?}", embeddings_file_path);
8362
match File::open(&embeddings_file_path) {
8463
Ok(file) => {
8564
let reader = BufReader::new(file);
86-
// Use top-level decode_from_reader now that bincode serde feature is enabled
87-
match bincode::decode_from_reader::<Vec<SerializableEmbedding>, _, _>(reader, config::standard()) {
65+
match bincode::decode_from_reader::<Vec<SerializableEmbedding>, _, _>(
66+
reader,
67+
config::standard(),
68+
) {
8869
Ok(loaded_serializable) => {
89-
println!("Successfully loaded embeddings. Converting format...");
90-
// Convert back to Vec<(String, Array1<f32>)>
91-
let converted_embeddings = loaded_serializable
70+
println!("Successfully loaded embeddings from cache. Converting format...");
71+
let converted = loaded_serializable
9272
.into_iter()
93-
.map(|se| (se.path, Array1::from(se.vector))) // Convert Vec to Array1
73+
.map(|se| (se.path, Array1::from(se.vector)))
9474
.collect::<Vec<_>>();
95-
Some(converted_embeddings) // Wrap in Option for the outer match
75+
loaded_embeddings = Some(converted);
76+
loaded_from_cache = true; // Set flag
9677
}
9778
Err(e) => {
98-
println!("Failed to decode embeddings: {}. Regenerating...", e);
99-
// Fall through to regeneration
100-
None
79+
println!(
80+
"Failed to decode embeddings file: {}. Will regenerate.",
81+
e
82+
);
83+
// Proceed to generation
10184
}
10285
}
10386
}
10487
Err(e) => {
105-
println!("Failed to open embeddings file: {}. Regenerating...", e);
106-
// Fall through to regeneration
107-
None
88+
println!(
89+
"Failed to open embeddings file: {}. Will regenerate.",
90+
e
91+
);
92+
// Proceed to generation
10893
}
10994
}
11095
} else {
111-
println!("Embeddings file not found. Generating...");
112-
None
113-
};
96+
println!("Embeddings file not found. Will generate.");
97+
// Proceed to generation
98+
}
11499

115-
// Use loaded embeddings or generate new ones if loading failed or file didn't exist
116-
// Variables to store generation stats if needed
100+
// --- Generate or Use Loaded Embeddings ---
117101
let mut generated_tokens: Option<usize> = None;
118102
let mut generation_cost: Option<f64> = None;
103+
let mut documents_for_server: Vec<Document> = Vec::new(); // Empty by default
119104

120-
let embeddings = match embeddings {
121-
Some(e) => e,
105+
let final_embeddings = match loaded_embeddings {
106+
Some(embeddings) => {
107+
println!("Using embeddings loaded from cache.");
108+
embeddings // Use the ones loaded from the file
109+
}
122110
None => {
123-
// Directory creation is handled by xdg_dirs.place_data_file
124-
125-
// Generate embeddings
111+
// --- Generation Path ---
112+
println!("Proceeding with documentation loading and embedding generation.");
113+
114+
// Ensure OpenAI API key is available ONLY if generating
115+
let _openai_api_key = env::var("OPENAI_API_KEY")
116+
.map_err(|_| ServerError::MissingEnvVar("OPENAI_API_KEY".to_string()))?;
117+
118+
// Initialize OpenAI client ONLY if generating
119+
let openai_client = OpenAIClient::new();
120+
OPENAI_CLIENT
121+
.set(openai_client.clone())
122+
.expect("Failed to set OpenAI client");
123+
124+
// 1. Load documents
125+
println!("Loading documents for crate: {}", crate_name);
126+
// Call load_documents through its module path (it is not imported by name)
127+
let loaded_documents = doc_loader::load_documents(&crate_name, &crate_version)?;
128+
println!("Loaded {} documents.", loaded_documents.len());
129+
documents_for_server = loaded_documents.clone(); // Keep a copy for the server; loaded_documents is still borrowed by generate_embeddings below
130+
131+
// 2. Generate embeddings
126132
println!("Generating embeddings...");
127-
// Capture the returned tuple (embeddings, total_tokens)
128-
let (generated_embeddings, total_tokens) =
129-
generate_embeddings(&openai_client, &documents, "text-embedding-3-small").await?;
130-
131-
// Calculate and print cost
132-
// Price: $0.02 / 1M tokens for text-embedding-3-small
133+
let (generated_embeddings, total_tokens) = generate_embeddings(
134+
&openai_client,
135+
&loaded_documents, // Use the just-loaded documents
136+
"text-embedding-3-small",
137+
)
138+
.await?;
139+
140+
// Calculate and store cost
133141
let cost_per_million = 0.02;
134142
let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million;
135143
println!(
136-
"Embedding generation cost for {} tokens: ${:.6}", // Format for cents/fractions
144+
"Embedding generation cost for {} tokens: ${:.6}",
137145
total_tokens, estimated_cost
138146
);
139-
// Store generation stats
140147
generated_tokens = Some(total_tokens);
141148
generation_cost = Some(estimated_cost);
142149

143-
144-
println!("Embeddings generated. Saving to: {:?}", embeddings_file_path);
145-
146-
// Convert to serializable format
147-
let serializable_embeddings: Vec<SerializableEmbedding> = generated_embeddings // Use the embeddings from the tuple
150+
// 3. Save embeddings
151+
println!("Saving generated embeddings to: {:?}", embeddings_file_path);
152+
let serializable_embeddings: Vec<SerializableEmbedding> = generated_embeddings
148153
.iter()
149154
.map(|(path, array)| SerializableEmbedding {
150155
path: path.clone(),
151-
vector: array.to_vec(), // Convert Array1 to Vec
156+
vector: array.to_vec(),
152157
})
153158
.collect();
154159

155-
// Encode directly to Vec<u8>
156160
match bincode::encode_to_vec(&serializable_embeddings, config::standard()) {
157161
Ok(encoded_bytes) => {
158-
// Write the bytes to the file
159162
if let Err(e) = fs::write(&embeddings_file_path, encoded_bytes) {
160163
println!("Warning: Failed to write embeddings file: {}", e);
161164
} else {
162165
println!("Embeddings saved successfully.");
163166
}
164167
}
165168
Err(e) => {
166-
// Log error but continue
167169
println!("Warning: Failed to encode embeddings to vec: {}", e);
168170
}
169171
}
170-
generated_embeddings
172+
generated_embeddings // Return the generated embeddings
171173
}
172174
};
173-
// --- End Persistence Logic ---
174-
175175

176+
// --- Initialize and Start Server ---
176177
println!("Initializing server for crate: {}", crate_name);
177178

178-
// Create the service instance, passing embeddings
179179
// Prepare the startup summary message
180-
let startup_message = {
181-
let doc_count = documents.len();
182-
match (generated_tokens, generation_cost) {
183-
(Some(tokens), Some(cost)) => {
184-
// Embeddings were generated
185-
format!(
186-
"Server for crate '{}' initialized. Loaded {} documents. Generated embeddings for {} tokens (Est. Cost: ${:.6}).",
187-
crate_name, doc_count, tokens, cost
188-
)
189-
}
190-
_ => {
191-
// Embeddings were loaded from cache
192-
format!(
193-
"Server for crate '{}' initialized. Loaded {} documents from cache.",
194-
crate_name, doc_count
195-
)
196-
}
197-
}
180+
let startup_message = if loaded_from_cache {
181+
format!(
182+
"Server for crate '{}' initialized. Loaded {} embeddings from cache.",
183+
crate_name,
184+
final_embeddings.len() // Use count from loaded/generated embeddings
185+
)
186+
} else {
187+
// Embeddings were generated
188+
let tokens = generated_tokens.unwrap_or(0);
189+
let cost = generation_cost.unwrap_or(0.0);
190+
format!(
191+
"Server for crate '{}' initialized. Generated {} embeddings for {} tokens (Est. Cost: ${:.6}).",
192+
crate_name,
193+
final_embeddings.len(), // Use count from loaded/generated embeddings
194+
tokens,
195+
cost
196+
)
198197
};
199198

200-
// Note: We still pass 'documents' which were loaded regardless of embedding source
201-
let service = RustDocsServer::new(crate_name, documents, embeddings, startup_message)?;
199+
// Create the service instance
200+
// Pass the final embeddings; the documents Vec is empty when embeddings came from the cache and populated when they were freshly generated
201+
let service = RustDocsServer::new(
202+
crate_name,
203+
documents_for_server, // Pass the (potentially empty) documents vec
204+
final_embeddings,
205+
startup_message,
206+
)?;
202207

203208
// Create the stdio transport
204209
let (stdin, stdout) = stdio();
205-
// Use the custom StdioTransport wrapper
206-
let transport = StdioTransport { reader: stdin, writer: stdout };
210+
let transport = StdioTransport {
211+
reader: stdin,
212+
writer: stdout,
213+
};
207214

208215
println!("Rust Docs MCP server starting...");
209216

210217
// Serve the server
211-
serve_server(service, transport).await?; // Use imported serve_server
218+
serve_server(service, transport).await?;
212219

213220
println!("Rust Docs MCP server stopped.");
214221
Ok(())

0 commit comments

Comments
 (0)