Skip to content

Commit 3c2a6eb

Browse files
committed
add support Cohere API wrapper
- implement CohereAIWrapper. - update the configurations to support Cohere requirements. - implement supporting classes like unit testing and response handler. - Rename old files to become more flexible with multi model vision. - clean API string.
1 parent cd8baf5 commit 3c2a6eb

File tree

7 files changed

+175
-24
lines changed

7 files changed

+175
-24
lines changed

core/com.intellijava.core/src/main/java/com/intellijava/core/model/CohereTextResponse.java renamed to core/com.intellijava.core/src/main/java/com/intellijava/core/model/CohereLanguageResponse.java

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
* @author github.com/Barqawiz
1010
*
1111
*/
12-
public class CohereTextResponse {
12+
public class CohereLanguageResponse extends BaseRemoteModel{
1313

1414
/** A unique identifier for the response.*/
15-
private String id;
1615
private List<Generation> generations;
1716
private String prompt;
1817

@@ -57,24 +56,6 @@ public void setText(String text) {
5756
}
5857
}
5958

60-
/**
61-
* Get the unique identifier for the response.
62-
*
63-
* @return the unique identifier for the response.
64-
*/
65-
public String getId() {
66-
return id;
67-
}
68-
69-
/**
70-
* Sets the unique identifier for the response.
71-
*
72-
* @param id the unique identifier for the response.
73-
*/
74-
public void setId(String id) {
75-
this.id = id;
76-
}
77-
7859
/**
7960
* Get the list of generated texts.
8061
*

core/com.intellijava.core/src/main/java/com/intellijava/core/model/input/LanguageModelInput.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
*
55
* LanguageModelInput handle the input parameters for the majority of the remote language models.
66
*
7+
* Language models documentations:
8+
* - Openai : https://beta.openai.com/docs/api-reference/completions.
9+
* - Cohere : https://docs.cohere.ai/reference/generate
10+
*
711
* @author github.com/Barqawiz
812
*
913
*/
@@ -45,7 +49,11 @@ public Builder(String prompt) {
4549

4650
/**
4751
* Setter for model.
48-
* @param model the model name. The largest OpenAI model is text-davinci-002.
52+
* @param model the model name.
53+
*
54+
* The largest OpenAI model is text-davinci-003.
55+
* The largest cohere model is xlarge.
56+
*
4957
* @return instance of Builder
5058
*/
5159
public Builder setModel(String model) {
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package com.intellijava.core.wrappers;
2+
3+
import java.io.IOException;
4+
import java.io.OutputStream;
5+
import java.net.HttpURLConnection;
6+
import java.net.URL;
7+
import java.nio.charset.StandardCharsets;
8+
import java.util.Map;
9+
10+
import com.intellijava.core.model.BaseRemoteModel;
11+
import com.intellijava.core.model.CohereLanguageResponse;
12+
import com.intellijava.core.model.OpenaiLanguageResponse;
13+
import com.intellijava.core.utils.Config2;
14+
import com.intellijava.core.utils.ConnHelper;
15+
16+
/**
17+
*
18+
*
19+
* @author github.com/Barqawiz
20+
*
21+
*/
22+
public class CohereAIWrapper implements LanguageModelInterface{
23+
24+
private final String API_BASE_URL = Config2.getInstance().getProperty("url.cohere.base");
25+
private final String COHERE_VERSION = Config2.getInstance().getProperty("url.cohere.version");
26+
private String API_KEY;
27+
28+
/**
29+
* CohereAIWrapper constructor with the API key
30+
*
31+
* @param apiKey cohere API key, generate if from your account.
32+
*/
33+
public CohereAIWrapper(String apiKey) {
34+
this.API_KEY = apiKey;
35+
}
36+
37+
/**
38+
*
39+
* Generate text from remote large language model based on the received prompt.
40+
*
41+
* @param params key and value for the API parameters
42+
* model the model name, either medium or xlarge.
43+
* prompt text of the required action or the question.
44+
* temperature higher values means more risks and creativity.
45+
* max_tokens maximum size of the model input and output.
46+
* @return BaseRemoteModel for model response
47+
* @throws IOException if there is an error when connecting to the OpenAI API.
48+
*/
49+
@Override
50+
public BaseRemoteModel generateText(Map<String, Object> params) throws IOException {
51+
52+
String url = API_BASE_URL + Config2.getInstance().getProperty("url.cohere.completions");
53+
54+
String json = ConnHelper.convertMaptToJson(params);
55+
56+
HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();
57+
connection.setRequestMethod("POST");
58+
connection.setRequestProperty("Content-Type", "application/json");
59+
connection.setRequestProperty("Authorization", "Bearer " + API_KEY);
60+
connection.setRequestProperty("Cohere-Version", COHERE_VERSION);
61+
connection.setDoOutput(true);
62+
63+
try (OutputStream outputStream = connection.getOutputStream()) {
64+
outputStream.write(json.getBytes(StandardCharsets.UTF_8));
65+
}
66+
67+
if (connection.getResponseCode() != HttpURLConnection.HTTP_OK) {
68+
String errorMessage = ConnHelper.getErrorMessage(connection);
69+
throw new IOException(errorMessage);
70+
}
71+
72+
// get the response and convert to model
73+
CohereLanguageResponse resModel = ConnHelper.convertSteamToModel(connection.getInputStream(), CohereLanguageResponse.class);
74+
return resModel;
75+
}
76+
77+
}

core/com.intellijava.core/src/main/resources/config.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ url.openai.testkey=
55
url.cohere.base=https://api.cohere.ai
66
url.cohere.completions=/generate
77
url.cohere.version=2022-12-06
8-
url.cohere.testkey=SWcvjpq7tCetHIuNaQL35CWXBOr4WIkaOR7EfmZ8
8+
url.cohere.testkey=
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package com.intellijava.core;
2+
3+
import static org.junit.Assert.fail;
4+
5+
import java.io.IOException;
6+
import java.util.HashMap;
7+
import java.util.List;
8+
import java.util.Map;
9+
import org.junit.Test;
10+
import com.intellijava.core.model.CohereLanguageResponse;
11+
import com.intellijava.core.model.CohereLanguageResponse.Generation;
12+
import com.intellijava.core.utils.Config2;
13+
import com.intellijava.core.wrappers.CohereAIWrapper;
14+
15+
public class CohereModelConnectionTest {
16+
17+
/**
18+
* coherKey - change the coherKey
19+
*/
20+
private final String coherKey = Config2.getInstance().getProperty("url.cohere.testkey");
21+
22+
23+
@Test
24+
public void testLanguageWrapper() {
25+
26+
// prepare the object
27+
CohereAIWrapper cohereWrapper = new CohereAIWrapper(coherKey);
28+
29+
// prepare the prompt with training data
30+
String targetIndustryIdea = "electric cars";
31+
String prompt = "This program generates startup idea and name given the industry." +
32+
"\n\nIndustry: Workplace" +
33+
"\nStartup Idea: A platform that generates slide deck contents automatically based on a given outline" +
34+
"\nStartup Name: Deckerize" +
35+
"\n--" +
36+
"\nIndustry: Home Decor" +
37+
"\nStartup Idea: An app that calculates the best position of your indoor plants for your apartment" +
38+
"\nStartup Name: Planteasy" +
39+
"\n--" +
40+
"\nIndustry: Healthcare" +
41+
"\nStartup Idea: A hearing aid for the elderly that automatically adjusts its levels and with a battery lasting a whole week" +
42+
"\nStartup Name: Hearspan" +
43+
"\n\n--" +
44+
"\nIndustry: Education" +
45+
"\nStartup Idea: An online school that lets students mix and match their own curriculum based on their interests and goals" +
46+
"\nStartup Name: Prime Age" +
47+
"\n\n--" +
48+
"\nIndustry: {industry}".replace("{industry}", targetIndustryIdea);
49+
50+
// prepare the input parameters
51+
Map<String, Object> params = new HashMap<>();
52+
params.put("prompt", prompt);
53+
params.put("model", "xlarge");
54+
params.put("max_tokens", 40);
55+
params.put("truncate", "END");
56+
params.put("return_likelihoods", "NONE");
57+
58+
59+
60+
// call the API
61+
try {
62+
if (coherKey.isBlank()) return;
63+
64+
CohereLanguageResponse resModel = (CohereLanguageResponse) cohereWrapper.generateText(params);
65+
66+
List<Generation> responses = resModel.getGenerations();
67+
68+
69+
assert responses.size() > 0;
70+
71+
for (Generation data: responses) {
72+
System.out.println(data.getText().toString());
73+
}
74+
75+
76+
} catch (IOException e) {
77+
if (coherKey.isBlank()) {
78+
System.out.print("testLanguageWrapper: set the API key to run the test case.");
79+
} else {
80+
fail("testLanguageWrapper failed with exception: " + e.getMessage());
81+
}
82+
}
83+
}
84+
85+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
/**
3737
* Unit test for Remote Language Model
3838
*/
39-
public class RemoteModelConnectionTest {
39+
public class OpenaiModelConnectionTest {
4040

4141
/**
4242
* openaiKey - change the openaiKey

core/com.intellijava.core/target/classes/config.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ url.openai.testkey=
55
url.cohere.base=https://api.cohere.ai
66
url.cohere.completions=/generate
77
url.cohere.version=2022-12-06
8-
url.cohere.testkey=SWcvjpq7tCetHIuNaQL35CWXBOr4WIkaOR7EfmZ8
8+
url.cohere.testkey=

0 commit comments

Comments
 (0)