Skip to content

Commit 6007021

Browse files
committed
update RemoteLanguageModel to support cohere and enhance the design
- Update the model names to be more flexible. - Update RemoteLanguageModel to support CohereLanguageResponse. - Implement unit testing cases. - Update the library version.
1 parent 3c2a6eb commit 6007021

File tree

4 files changed

+244
-103
lines changed

4 files changed

+244
-103
lines changed

core/com.intellijava.core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>io.github.barqawiz</groupId>
88
<artifactId>intellijava.core</artifactId>
9-
<version>0.5.5</version>
9+
<version>0.6.0</version>
1010

1111
<name>Intellijava</name>
1212
<description>IntelliJava allows java developers to easily integrate with the latest language models, image generation, and deep learning frameworks.</description>

core/com.intellijava.core/src/main/java/com/intellijava/core/controller/RemoteLanguageModel.java

Lines changed: 155 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,93 +16,195 @@
1616
package com.intellijava.core.controller;
1717

1818
import java.io.IOException;
19+
import java.util.ArrayList;
1920
import java.util.HashMap;
21+
import java.util.List;
2022
import java.util.Map;
21-
23+
import com.intellijava.core.model.CohereLanguageResponse;
2224
import com.intellijava.core.model.OpenaiLanguageResponse;
25+
import com.intellijava.core.model.SupportedLangModels;
2326
import com.intellijava.core.model.input.LanguageModelInput;
27+
import com.intellijava.core.wrappers.CohereAIWrapper;
2428
import com.intellijava.core.wrappers.OpenAIWrapper;
2529

2630
/**
27-
* A class to call the most sophisticated remote language models.
28-
*
29-
* This class provides an API for interacting with OpenAI's GPT-3 language model. It is designed to be easily extensible
30-
* to support other models in the future.
31+
* RemoteLanguageModel class to call the most sophisticated remote language
32+
* models.
33+
*
34+
* This class supports: - Openai: - url: openai.com - description: provides an
35+
* API for interacting with OpenAI's GPT-3 language model. - model names :
36+
* text-davinci-003, text-curie-001, text-babbage-001, more.
37+
*
38+
* - cohere: - url: cohere.ai - description: provides an API for interacting
39+
* with the generate language model. - it is recommended to fine-tune your model or
40+
* add an example of the response in the prompt when calling cohere models. - model
41+
* names : medium or xlarge
3142
*
3243
* @author github.com/Barqawiz
3344
*
3445
*/
3546
public class RemoteLanguageModel {
36-
37-
private String keyType;
47+
48+
private SupportedLangModels keyType;
3849
private OpenAIWrapper openaiWrapper;
39-
50+
private CohereAIWrapper cohereWrapper;
51+
4052
/**
41-
* Constructor for the RemoteLanguageModel class.
42-
*
43-
* Creates an instance of the class and sets up the API key and the key type.
44-
* Currently, only the "openai" key type is supported.
45-
*
46-
* @param keyValue the API key.
47-
* @param keyType support openai only.
48-
*
49-
* @throws IllegalArgumentException if the keyType passed is not "openai".
50-
*
51-
*/
52-
public RemoteLanguageModel(String keyValue, String keyType) {
53-
54-
if (keyType.isEmpty() || keyType.equals("openai")) {
55-
this.keyType = "openai";
56-
openaiWrapper = new OpenAIWrapper(keyValue);
53+
* Constructor for the RemoteLanguageModel class.
54+
*
55+
* Creates an instance of the class and sets up the key and the API type.
56+
*
57+
* @param keyValue the API key.
58+
* @param keyTypeString either openai (default) or cohere or send empty string for
59+
* default value.
60+
*
61+
* @throws IllegalArgumentException if the keyType passed is not a supported model name.
62+
*
63+
*/
64+
public RemoteLanguageModel(String keyValue, String keyTypeString) {

    // A missing key type selects the default provider (openai). The null
    // check avoids a NullPointerException from calling isEmpty() on null.
    if (keyTypeString == null || keyTypeString.isEmpty()) {
        keyTypeString = SupportedLangModels.openai.toString();
    }

    List<String> supportedModels = this.getSupportedModels();

    // Reject names that are not in the supported list before converting
    // the string to its enum constant.
    if (!supportedModels.contains(keyTypeString)) {
        String models = String.join(" - ", supportedModels);
        // The rejected argument is the key type, not the API key value.
        throw new IllegalArgumentException(
                "The received keyType is not supported. Send any model from: " + models);
    }

    this.initiate(keyValue, SupportedLangModels.valueOf(keyTypeString));
}
79+
80+
/**
81+
* Get the supported models names as array of string
82+
*
83+
* @return supportedModels
84+
*/
85+
public List<String> getSupportedModels() {
86+
SupportedLangModels[] values = SupportedLangModels.values();
87+
List<String> enumValues = new ArrayList<>();
88+
89+
for (int i = 0; i < values.length; i++) {
90+
enumValues.add(values[i].name());
91+
}
92+
93+
return enumValues;
94+
}
6195

62-
96+
/**
 * Constructor for the RemoteLanguageModel class.
 *
 * Creates an instance of the class and sets up the API key and the enum key
 * type.
 *
 * @param keyValue the API key.
 * @param keyType  enum version of the key type (SupportedLangModels).
 *
 * @throws IllegalArgumentException if the keyType passed is null.
 *
 */
public RemoteLanguageModel(String keyValue, SupportedLangModels keyType) {
    // Fail with a descriptive error here instead of a NullPointerException
    // raised later when initiate() dereferences the key type.
    if (keyType == null) {
        throw new IllegalArgumentException("The keyType must not be null");
    }
    this.initiate(keyValue, keyType);
}
111+
112+
private void initiate(String keyValue, SupportedLangModels keyType) {
113+
// set the model type
114+
this.keyType = keyType;
115+
116+
// generate the related model
117+
if (keyType.equals(SupportedLangModels.openai)) {
118+
this.openaiWrapper = new OpenAIWrapper(keyValue);
119+
} else if (keyType.equals(SupportedLangModels.cohere)) {
120+
this.cohereWrapper = new CohereAIWrapper(keyValue);
121+
}
122+
}
123+
63124
/**
64125
*
65126
* Call a remote large model to generate any text based on the received prompt.
66127
*
67128
* @param langInput flexible builder for language model parameters.
129+
*
68130
* @return string for the model response.
69-
* @throws IOException if there is an error when connecting to the OpenAI API.
70-
* @throws IllegalArgumentException if the keyType passed in the constructor is not "openai".
131+
* @throws IOException if there is an error when connecting to the
132+
* remote model API.
133+
* @throws IllegalArgumentException if the keyType passed in the constructor is
134+
* not a supported model ("openai" or "cohere").
71135
*
72136
*/
73-
public String generateText(LanguageModelInput langInput) throws IOException {
74-
75-
if (this.keyType.equals("openai")) {
76-
return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(),
77-
langInput.getTemperature(), langInput.getMaxTokens());
137+
public String generateText(LanguageModelInput langInput) throws IOException {
138+
139+
if (this.keyType.equals(SupportedLangModels.openai)) {
140+
return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(),
141+
langInput.getMaxTokens());
142+
} else if (this.keyType.equals(SupportedLangModels.cohere)) {
143+
return this.generateCohereText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(),
144+
langInput.getMaxTokens());
78145
} else {
79146
throw new IllegalArgumentException("This version support openai keyType only");
80147
}
81-
148+
82149
}
83150

84151
/**
85-
* Private helper method for generating text from OpenAI GPT-3 model.
86-
*
87-
* @param model the model name, example: text-davinci-002. For more details about GPT-3 models, see: https://beta.openai.com/docs/models/gpt-3
88-
* @param prompt text of the required action or the question.
89-
* @param temperature higher values means more risks and creativity.
90-
* @param maxTokens maximum size of the model input and output.
91-
* @return string model response.
92-
* @throws IOException if there is an error when connecting to the OpenAI API.
93-
*
94-
*/
95-
private String generateOpenaiText(String model, String prompt, float temperature, int maxTokens) throws IOException {
96-
152+
* Private helper method for generating text from OpenAI GPT-3 model.
153+
*
154+
* @param model the model name, example: text-davinci-003. For more
155+
* details about GPT-3 models, see:
156+
* https://beta.openai.com/docs/models/gpt-3
157+
* @param prompt text of the required action or the question.
158+
* @param temperature higher values means more risks and creativity.
159+
* @param maxTokens maximum size of the model input and output.
160+
* @return string model response.
161+
* @throws IOException if there is an error when connecting to the OpenAI API.
162+
*
163+
*/
164+
private String generateOpenaiText(String model, String prompt, float temperature, int maxTokens)
165+
throws IOException {
166+
167+
if (model.equals(""))
168+
model = "text-davinci-003";
169+
97170
Map<String, Object> params = new HashMap<>();
98-
params.put("model", model);
99-
params.put("prompt", prompt);
100-
params.put("temperature", temperature);
101-
params.put("max_tokens", maxTokens);
102-
171+
params.put("model", model);
172+
params.put("prompt", prompt);
173+
params.put("temperature", temperature);
174+
params.put("max_tokens", maxTokens);
175+
103176
OpenaiLanguageResponse resModel = (OpenaiLanguageResponse) openaiWrapper.generateText(params);
104-
177+
105178
return resModel.getChoices().get(0).getText();
106-
179+
180+
}
181+
182+
/**
183+
* Private helper method for generating text from Cohere model.
184+
*
185+
* @param model the model name, either medium or xlarge.
186+
* @param prompt text of the required action or the question.
187+
* @param temperature higher values means more risks and creativity.
188+
* @param maxTokens maximum size of the model input and output.
189+
* @return string model response.
190+
* @throws IOException if there is an error when connecting to the API.
191+
*
192+
*/
193+
private String generateCohereText(String model, String prompt, float temperature, int maxTokens)
194+
throws IOException {
195+
196+
if (model.equals(""))
197+
model = "xlarge";
198+
199+
Map<String, Object> params = new HashMap<>();
200+
params.put("model", model);
201+
params.put("prompt", prompt);
202+
params.put("temperature", temperature);
203+
params.put("max_tokens", maxTokens);
204+
205+
CohereLanguageResponse resModel = (CohereLanguageResponse) cohereWrapper.generateText(params);
206+
207+
return resModel.getGenerations().get(0).getText();
208+
107209
}
108210
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package com.intellijava.core.model;
2+
3+
/**
 * Remote language model providers that can be selected by key type.
 *
 * The constant names are lowercase on purpose: they are matched against
 * user-supplied strings through valueOf().
 */
public enum SupportedLangModels {
    /** OpenAI provider. */
    openai,
    /** Cohere provider. */
    cohere
}

0 commit comments

Comments
 (0)