Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/cocoindex/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class EmbedText(op.FunctionSpec):
output_dimension: int | None = None
task_type: str | None = None
api_config: llm.VertexAiConfig | None = None
api_key: str | None = None


class ExtractByLlm(op.FunctionSpec):
Expand Down
1 change: 1 addition & 0 deletions python/cocoindex/functions/_engine_builtin_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class EmbedText(op.FunctionSpec):
output_dimension: int | None = None
task_type: str | None = None
api_config: llm.VertexAiConfig | None = None
api_key: str | None = None


class ExtractByLlm(op.FunctionSpec):
Expand Down
1 change: 1 addition & 0 deletions python/cocoindex/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ class LlmSpec:
api_type: LlmApiType
model: str
address: str | None = None
api_key: str | None = None
api_config: VertexAiConfig | OpenAiConfig | None = None
13 changes: 9 additions & 4 deletions src/llm/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@ pub struct Client {
}

impl Client {
pub async fn new(address: Option<String>) -> Result<Self> {
pub async fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
if address.is_some() {
api_bail!("Anthropic doesn't support custom API address");
}
let api_key = match std::env::var("ANTHROPIC_API_KEY") {
Ok(val) => val,
Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),

let api_key = if let Some(key) = api_key {
key
} else {
std::env::var("ANTHROPIC_API_KEY").map_err(|_| {
anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set")
})?
};

Ok(Self {
api_key,
client: reqwest::Client::new(),
Expand Down
13 changes: 9 additions & 4 deletions src/llm/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,18 @@ pub struct AiStudioClient {
}

impl AiStudioClient {
pub fn new(address: Option<String>) -> Result<Self> {
pub fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
if address.is_some() {
api_bail!("Gemini doesn't support custom API address");
}
let api_key = match std::env::var("GEMINI_API_KEY") {
Ok(val) => val,
Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),

let api_key = if let Some(key) = api_key {
key
} else {
std::env::var("GEMINI_API_KEY")
.map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))?
};

Ok(Self {
api_key,
client: reqwest::Client::new(),
Expand Down Expand Up @@ -271,6 +275,7 @@ static SHARED_RETRY_THROTTLER: LazyLock<SharedRetryThrottler> =
impl VertexAiClient {
pub async fn new(
address: Option<String>,
_api_key: Option<String>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since `api_key` is not supported here, let's validate the value and raise an error if one is provided.

similar to this

api_config: Option<super::LlmApiConfig>,
) -> Result<Self> {
if address.is_some() {
Expand Down
9 changes: 7 additions & 2 deletions src/llm/litellm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
pub use super::openai::Client;

impl Client {
pub async fn new_litellm(address: Option<String>) -> anyhow::Result<Self> {
pub async fn new_litellm(
address: Option<String>,
api_key: Option<String>,
) -> anyhow::Result<Self> {
let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
let api_key = std::env::var("LITELLM_API_KEY").ok();

let api_key = api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok());

let mut config = OpenAIConfig::new().with_api_base(address);
if let Some(api_key) = api_key {
config = config.with_api_key(api_key);
Expand Down
48 changes: 26 additions & 22 deletions src/llm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub struct LlmSpec {
pub api_type: LlmApiType,
pub address: Option<String>,
pub model: String,
pub api_key: Option<String>,
pub api_config: Option<LlmApiConfig>,
}

Expand Down Expand Up @@ -119,61 +120,64 @@ mod voyage;
pub async fn new_llm_generation_client(
api_type: LlmApiType,
address: Option<String>,
api_key: Option<String>,
api_config: Option<LlmApiConfig>,
) -> Result<Box<dyn LlmGenerationClient>> {
let client = match api_type {
LlmApiType::Ollama => {
Box::new(ollama::Client::new(address).await?) as Box<dyn LlmGenerationClient>
}
LlmApiType::OpenAi => {
Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmGenerationClient>
}
LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?)
as Box<dyn LlmGenerationClient>,
LlmApiType::Gemini => {
Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmGenerationClient>
Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box<dyn LlmGenerationClient>
}
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
as Box<dyn LlmGenerationClient>,
LlmApiType::Anthropic => {
Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
LlmApiType::VertexAi => {
Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?)
as Box<dyn LlmGenerationClient>
}
LlmApiType::Anthropic => Box::new(anthropic::Client::new(address, api_key).await?)
as Box<dyn LlmGenerationClient>,
LlmApiType::Bedrock => {
Box::new(bedrock::Client::new(address).await?) as Box<dyn LlmGenerationClient>
}
LlmApiType::LiteLlm => {
Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmGenerationClient>
}
LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(address).await?)
LlmApiType::LiteLlm => Box::new(litellm::Client::new_litellm(address, api_key).await?)
as Box<dyn LlmGenerationClient>,
LlmApiType::OpenRouter => {
Box::new(openrouter::Client::new_openrouter(address, api_key).await?)
as Box<dyn LlmGenerationClient>
}
LlmApiType::Voyage => {
api_bail!("Voyage is not supported for generation")
}
LlmApiType::Vllm => {
Box::new(vllm::Client::new_vllm(address).await?) as Box<dyn LlmGenerationClient>
}
LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_key).await?)
as Box<dyn LlmGenerationClient>,
};
Ok(client)
}

pub async fn new_llm_embedding_client(
api_type: LlmApiType,
address: Option<String>,
api_key: Option<String>,
api_config: Option<LlmApiConfig>,
) -> Result<Box<dyn LlmEmbeddingClient>> {
let client = match api_type {
LlmApiType::Ollama => {
Box::new(ollama::Client::new(address).await?) as Box<dyn LlmEmbeddingClient>
}
LlmApiType::Gemini => {
Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient>
}
LlmApiType::OpenAi => {
Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmEmbeddingClient>
Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box<dyn LlmEmbeddingClient>
}
LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?)
as Box<dyn LlmEmbeddingClient>,
LlmApiType::Voyage => {
Box::new(voyage::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
Box::new(voyage::Client::new(address, api_key)?) as Box<dyn LlmEmbeddingClient>
}
LlmApiType::VertexAi => {
Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?)
as Box<dyn LlmEmbeddingClient>
}
LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
as Box<dyn LlmEmbeddingClient>,
LlmApiType::OpenRouter
| LlmApiType::LiteLlm
| LlmApiType::Vllm
Expand Down
19 changes: 13 additions & 6 deletions src/llm/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ impl Client {
Self { client }
}

pub fn new(address: Option<String>, api_config: Option<super::LlmApiConfig>) -> Result<Self> {
pub fn new(
address: Option<String>,
api_key: Option<String>,
api_config: Option<super::LlmApiConfig>,
) -> Result<Self> {
let config = match api_config {
Some(super::LlmApiConfig::OpenAi(config)) => config,
Some(_) => api_bail!("unexpected config type, expected OpenAiConfig"),
Expand All @@ -49,13 +53,16 @@ impl Client {
if let Some(project_id) = config.project_id {
openai_config = openai_config.with_project_id(project_id);
}

// Verify API key is set
if std::env::var("OPENAI_API_KEY").is_err() {
api_bail!("OPENAI_API_KEY environment variable must be set");
if let Some(key) = api_key {
openai_config = openai_config.with_api_key(key);
} else {
// Verify API key is set in environment if not provided in config
if std::env::var("OPENAI_API_KEY").is_err() {
api_bail!("OPENAI_API_KEY environment variable must be set");
}
}

Ok(Self {
// OpenAI client will use OPENAI_API_KEY and OPENAI_API_BASE env variables by default
client: OpenAIClient::with_config(openai_config),
})
}
Expand Down
9 changes: 7 additions & 2 deletions src/llm/openrouter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
pub use super::openai::Client;

impl Client {
pub async fn new_openrouter(address: Option<String>) -> anyhow::Result<Self> {
pub async fn new_openrouter(
address: Option<String>,
api_key: Option<String>,
) -> anyhow::Result<Self> {
let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
let api_key = std::env::var("OPENROUTER_API_KEY").ok();

let api_key = api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok());

let mut config = OpenAIConfig::new().with_api_base(address);
if let Some(api_key) = api_key {
config = config.with_api_key(api_key);
Expand Down
9 changes: 7 additions & 2 deletions src/llm/vllm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
pub use super::openai::Client;

impl Client {
pub async fn new_vllm(address: Option<String>) -> anyhow::Result<Self> {
pub async fn new_vllm(
address: Option<String>,
api_key: Option<String>,
) -> anyhow::Result<Self> {
let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string());
let api_key = std::env::var("VLLM_API_KEY").ok();

let api_key = api_key.or_else(|| std::env::var("VLLM_API_KEY").ok());

let mut config = OpenAIConfig::new().with_api_base(address);
if let Some(api_key) = api_key {
config = config.with_api_key(api_key);
Expand Down
12 changes: 8 additions & 4 deletions src/llm/voyage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,18 @@ pub struct Client {
}

impl Client {
pub fn new(address: Option<String>) -> Result<Self> {
pub fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
if address.is_some() {
api_bail!("Voyage AI doesn't support custom API address");
}
let api_key = match std::env::var("VOYAGE_API_KEY") {
Ok(val) => val,
Err(_) => api_bail!("VOYAGE_API_KEY environment variable must be set"),

let api_key = if let Some(key) = api_key {
key
} else {
std::env::var("VOYAGE_API_KEY")
.map_err(|_| anyhow::anyhow!("VOYAGE_API_KEY environment variable must be set"))?
};

Ok(Self {
api_key,
client: reqwest::Client::new(),
Expand Down
13 changes: 10 additions & 3 deletions src/ops/functions/embed_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ struct Spec {
api_config: Option<LlmApiConfig>,
output_dimension: Option<u32>,
task_type: Option<String>,
api_key: Option<String>,
}

struct Args {
Expand Down Expand Up @@ -91,9 +92,14 @@ impl SimpleFunctionFactoryBase for Factory {
.next_arg("text")?
.expect_type(&ValueType::Basic(BasicValueType::Str))?
.required()?;
let client =
new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
.await?;

let client = new_llm_embedding_client(
spec.api_type,
spec.address.clone(),
spec.api_key.clone(),
spec.api_config.clone(),
)
.await?;
let output_dimension = match spec.output_dimension {
Some(output_dimension) => output_dimension,
None => {
Expand Down Expand Up @@ -144,6 +150,7 @@ mod tests {
api_config: None,
output_dimension: None,
task_type: None,
api_key: None,
};

let factory = Arc::new(Factory);
Expand Down
3 changes: 3 additions & 0 deletions src/ops/functions/extract_by_llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl Executor {
let client = new_llm_generation_client(
spec.llm_spec.api_type,
spec.llm_spec.address,
spec.llm_spec.api_key,
spec.llm_spec.api_config,
)
.await?;
Expand Down Expand Up @@ -204,6 +205,7 @@ mod tests {
api_type: crate::llm::LlmApiType::OpenAi,
model: "gpt-4o".to_string(),
address: None,
api_key: None,
api_config: None,
},
output_type: output_type_spec,
Expand Down Expand Up @@ -274,6 +276,7 @@ mod tests {
api_type: crate::llm::LlmApiType::OpenAi,
model: "gpt-4o".to_string(),
address: None,
api_key: None,
api_config: None,
},
output_type: make_output_type(BasicValueType::Str),
Expand Down