@@ -14,7 +14,7 @@
     ConversationalAgent,
     ConversationalChatAgent,
 )
-from langchain.chat_models import ChatAnthropic, ChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatAnthropic, ChatOpenAI
 from langchain.chat_models.base import BaseChatModel
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts.chat import MessagesPlaceholder
@@ -87,13 +87,33 @@ def _choose_llm(
                 "OpenAI API key missing. Set OPENAI_API_KEY env variable "
                 "or pass `openai_api_key` to session."
             )
-        return ChatOpenAI(
-            temperature=0.03,
-            model=model,
-            openai_api_key=openai_api_key,
-            max_retries=3,
-            request_timeout=60 * 3,
-        )  # type: ignore
+        openai_api_version = getenv("OPENAI_API_VERSION")
+        openai_api_base = getenv("OPENAI_API_BASE")
+        deployment_name = getenv("DEPLOYMENT_NAME")
+        openai_api_type = getenv("OPENAI_API_TYPE")
+        if (
+            openai_api_type == "azure"
+            and openai_api_version
+            and openai_api_base
+            and deployment_name
+        ):
+            return AzureChatOpenAI(
+                temperature=0.03,
+                openai_api_base=openai_api_base,
+                openai_api_version=openai_api_version,
+                deployment_name=deployment_name,
+                openai_api_key=openai_api_key,
+                max_retries=3,
+                request_timeout=60 * 3,
+            )
+        else:
+            return ChatOpenAI(
+                temperature=0.03,
+                model=model,
+                openai_api_key=openai_api_key,
+                max_retries=3,
+                request_timeout=60 * 3,
+            )
     elif "claude" in model:
         return ChatAnthropic(model=model)
     else:
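
For context, a minimal sketch of the environment configuration the new branch checks before falling back to ChatOpenAI. This assumes `getenv` resolves to `os.getenv`; every value below is a placeholder, not a real endpoint, deployment, or key.

import os

# Hypothetical placeholder values; real ones come from your Azure OpenAI resource.
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_VERSION"] = "2023-05-15"  # example Azure OpenAI API version
os.environ["OPENAI_API_BASE"] = "https://<resource>.openai.azure.com"
os.environ["DEPLOYMENT_NAME"] = "<chat-model-deployment>"
os.environ["OPENAI_API_KEY"] = "<azure-openai-key>"

# With OPENAI_API_TYPE == "azure" and the version, base URL, and deployment name
# all present, _choose_llm returns AzureChatOpenAI; if any of them is missing,
# it falls back to the plain ChatOpenAI client as before.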