-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* 1 * 2 * 3 * Implementing llama, added endpoint, client, completion request, response,and controller
- Loading branch information
1 parent
f6f813d
commit e24c63b
Showing
6 changed files
with
407 additions
and
0 deletions.
There are no files selected for viewing
170 changes: 170 additions & 0 deletions
170
...pring/edgechain-app/src/main/java/com/edgechain/lib/endpoint/impl/llm/Llama2Endpoint.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
package com.edgechain.lib.endpoint.impl.llm; | ||
|
||
import com.edgechain.lib.endpoint.Endpoint; | ||
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; | ||
import com.edgechain.lib.openai.response.ChatCompletionResponse; | ||
import com.edgechain.lib.request.ArkRequest; | ||
import com.edgechain.lib.retrofit.Llama2Service; | ||
import com.edgechain.lib.retrofit.client.RetrofitClientInstance; | ||
import com.edgechain.lib.rxjava.retry.RetryPolicy; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import io.reactivex.rxjava3.core.Observable; | ||
import org.json.JSONObject; | ||
import org.modelmapper.ModelMapper; | ||
import retrofit2.Retrofit; | ||
|
||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public class Llama2Endpoint extends Endpoint { | ||
private final Retrofit retrofit = RetrofitClientInstance.getInstance(); | ||
private final Llama2Service llama2Service = retrofit.create(Llama2Service.class); | ||
|
||
private final ModelMapper modelMapper = new ModelMapper(); | ||
|
||
private String inputs; | ||
private JSONObject parameters; | ||
private Double temperature; | ||
@JsonProperty("top_k") | ||
private Integer topK; | ||
@JsonProperty("top_p") | ||
private Double topP; | ||
|
||
@JsonProperty("do_sample") | ||
private Boolean doSample; | ||
@JsonProperty("max_new_tokens") | ||
private Integer maxNewTokens; | ||
@JsonProperty("repetition_penalty") | ||
private Double repetitionPenalty; | ||
private List<String> stop; | ||
private String chainName; | ||
private String callIdentifier; | ||
|
||
public Llama2Endpoint() { | ||
} | ||
|
||
public Llama2Endpoint(String url, RetryPolicy retryPolicy, | ||
Double temperature, Integer topK, Double topP, | ||
Boolean doSample, Integer maxNewTokens, Double repetitionPenalty, | ||
List<String> stop) { | ||
super(url, retryPolicy); | ||
this.temperature = temperature; | ||
this.topK = topK; | ||
this.topP = topP; | ||
this.doSample = doSample; | ||
this.maxNewTokens = maxNewTokens; | ||
this.repetitionPenalty = repetitionPenalty; | ||
this.stop = stop; | ||
} | ||
|
||
public Llama2Endpoint(String url, RetryPolicy retryPolicy) { | ||
super(url, retryPolicy); | ||
this.temperature = 0.7; | ||
this.maxNewTokens = 512; | ||
} | ||
|
||
public String getInputs() { | ||
return inputs; | ||
} | ||
|
||
public void setInputs(String inputs) { | ||
this.inputs = inputs; | ||
} | ||
|
||
public JSONObject getParameters() { | ||
return parameters; | ||
} | ||
|
||
public void setParameters(JSONObject parameters) { | ||
this.parameters = parameters; | ||
} | ||
|
||
public Double getTemperature() { | ||
return temperature; | ||
} | ||
|
||
public void setTemperature(Double temperature) { | ||
this.temperature = temperature; | ||
} | ||
|
||
public Integer getTopK() { | ||
return topK; | ||
} | ||
|
||
public void setTopK(Integer topK) { | ||
this.topK = topK; | ||
} | ||
|
||
public Double getTopP() { | ||
return topP; | ||
} | ||
|
||
public void setTopP(Double topP) { | ||
this.topP = topP; | ||
} | ||
|
||
public Boolean getDoSample() { | ||
return doSample; | ||
} | ||
|
||
public void setDoSample(Boolean doSample) { | ||
this.doSample = doSample; | ||
} | ||
|
||
public Integer getMaxNewTokens() { | ||
return maxNewTokens; | ||
} | ||
|
||
public void setMaxNewTokens(Integer maxNewTokens) { | ||
this.maxNewTokens = maxNewTokens; | ||
} | ||
|
||
public Double getRepetitionPenalty() { | ||
return repetitionPenalty; | ||
} | ||
|
||
public void setRepetitionPenalty(Double repetitionPenalty) { | ||
this.repetitionPenalty = repetitionPenalty; | ||
} | ||
|
||
public List<String> getStop() { | ||
return stop; | ||
} | ||
|
||
public void setStop(List<String> stop) { | ||
this.stop = stop; | ||
} | ||
|
||
public String getChainName() { | ||
return chainName; | ||
} | ||
|
||
public void setChainName(String chainName) { | ||
this.chainName = chainName; | ||
} | ||
|
||
public String getCallIdentifier() { | ||
return callIdentifier; | ||
} | ||
|
||
public void setCallIdentifier(String callIdentifier) { | ||
this.callIdentifier = callIdentifier; | ||
} | ||
|
||
public Observable<List<Llama2ChatCompletionResponse>> chatCompletion( | ||
String inputs,String chainName, ArkRequest arkRequest) { | ||
|
||
Llama2Endpoint mapper = modelMapper.map(this, Llama2Endpoint.class); | ||
mapper.setInputs(inputs); | ||
mapper.setChainName(chainName); | ||
return chatCompletion(mapper, arkRequest); | ||
} | ||
|
||
private Observable<List<Llama2ChatCompletionResponse>> chatCompletion(Llama2Endpoint mapper, ArkRequest arkRequest) { | ||
|
||
if (Objects.nonNull(arkRequest)) mapper.setCallIdentifier(arkRequest.getRequestURI()); | ||
else mapper.setCallIdentifier("URI wasn't provided"); | ||
|
||
return Observable.fromSingle(this.llama2Service.chatCompletion(mapper)); | ||
} | ||
} |
67 changes: 67 additions & 0 deletions
67
FlySpring/edgechain-app/src/main/java/com/edgechain/lib/llama2/Llama2Client.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package com.edgechain.lib.llama2; | ||
|
||
|
||
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; | ||
import com.edgechain.lib.llama2.request.Llama2ChatCompletionRequest; | ||
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; | ||
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain; | ||
import com.edgechain.lib.utils.JsonUtils; | ||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import io.reactivex.rxjava3.core.Observable; | ||
import org.apache.commons.lang3.StringUtils; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import org.springframework.beans.factory.annotation.Autowired; | ||
import org.springframework.http.*; | ||
import org.springframework.stereotype.Service; | ||
import org.springframework.web.client.RestTemplate; | ||
|
||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
@Service | ||
public class Llama2Client { | ||
@Autowired | ||
private ObjectMapper objectMapper; | ||
private final Logger logger = LoggerFactory.getLogger(getClass()); | ||
private final RestTemplate restTemplate = new RestTemplate(); | ||
public EdgeChain<List<Llama2ChatCompletionResponse>> createChatCompletion( | ||
Llama2ChatCompletionRequest request, Llama2Endpoint endpoint) { | ||
return new EdgeChain<>( | ||
Observable.create( | ||
emitter -> { | ||
try { | ||
|
||
logger.info("Logging ChatCompletion...."); | ||
|
||
logger.info("==============REQUEST DATA================"); | ||
logger.info(request.toString()); | ||
|
||
// Llama2ChatCompletionRequest llamaRequest = new Llama2ChatCompletionRequest(); | ||
// | ||
// llamaRequest.setInputs(request.getInputs()); | ||
// llamaRequest.setParameters(request.getParameters()); | ||
|
||
|
||
|
||
// Create headers | ||
HttpHeaders headers = new HttpHeaders(); | ||
headers.setContentType(MediaType.APPLICATION_JSON); | ||
HttpEntity<Llama2ChatCompletionRequest> entity = new HttpEntity<>(request, headers); | ||
// | ||
String response = restTemplate.postForObject(endpoint.getUrl(), entity, String.class); | ||
|
||
List<Llama2ChatCompletionResponse> chatCompletionResponse = | ||
objectMapper.readValue(response, new TypeReference<List<Llama2ChatCompletionResponse>>() {}); | ||
emitter.onNext(chatCompletionResponse); | ||
emitter.onComplete(); | ||
|
||
} catch (final Exception e) { | ||
emitter.onError(e); | ||
} | ||
}), | ||
endpoint); | ||
} | ||
|
||
} |
80 changes: 80 additions & 0 deletions
80
...chain-app/src/main/java/com/edgechain/lib/llama2/request/Llama2ChatCompletionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package com.edgechain.lib.llama2.request; | ||
|
||
import com.edgechain.lib.openai.request.ChatCompletionRequest; | ||
import com.edgechain.lib.openai.request.ChatMessage; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import org.json.JSONObject; | ||
|
||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.StringJoiner; | ||
|
||
public class Llama2ChatCompletionRequest { | ||
|
||
private String inputs; | ||
private JSONObject parameters; | ||
|
||
public Llama2ChatCompletionRequest() { | ||
} | ||
|
||
public Llama2ChatCompletionRequest(String inputs, JSONObject parameters) { | ||
this.inputs = inputs; | ||
this.parameters = parameters; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return new StringJoiner(", ", Llama2ChatCompletionRequest.class.getSimpleName() + "[{", "}]") | ||
.add("\"inputs:\"" + inputs) | ||
.add("\"parameters:\"" + parameters) | ||
.toString(); | ||
} | ||
|
||
public static Llama2ChatCompletionRequestBuilder builder() { | ||
return new Llama2ChatCompletionRequestBuilder(); | ||
} | ||
|
||
|
||
public String getInputs() { | ||
return inputs; | ||
} | ||
|
||
public void setInputs(String inputs) { | ||
this.inputs = inputs; | ||
} | ||
|
||
public JSONObject getParameters() { | ||
return parameters; | ||
} | ||
|
||
public void setParameters(JSONObject parameters) { | ||
this.parameters = parameters; | ||
} | ||
|
||
|
||
public static class Llama2ChatCompletionRequestBuilder { | ||
private String inputs; | ||
private JSONObject parameters; | ||
|
||
private Llama2ChatCompletionRequestBuilder() { | ||
} | ||
|
||
|
||
public Llama2ChatCompletionRequestBuilder inputs(String inputs) { | ||
this.inputs = inputs; | ||
return this; | ||
} | ||
|
||
public Llama2ChatCompletionRequestBuilder parameters(JSONObject parameters) { | ||
this.parameters = parameters; | ||
return this; | ||
} | ||
|
||
public Llama2ChatCompletionRequest build() { | ||
return new Llama2ChatCompletionRequest( | ||
inputs, | ||
parameters); | ||
} | ||
} | ||
} |
18 changes: 18 additions & 0 deletions
18
...ain-app/src/main/java/com/edgechain/lib/llama2/response/Llama2ChatCompletionResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package com.edgechain.lib.llama2.response; | ||
|
||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
|
||
public class Llama2ChatCompletionResponse { | ||
@JsonProperty("generated_text") | ||
private String generatedText; | ||
|
||
public Llama2ChatCompletionResponse() {} | ||
|
||
public String getGeneratedText() { | ||
return generatedText; | ||
} | ||
|
||
public void setGeneratedText(String generatedText) { | ||
this.generatedText = generatedText; | ||
} | ||
} |
14 changes: 14 additions & 0 deletions
14
FlySpring/edgechain-app/src/main/java/com/edgechain/lib/retrofit/Llama2Service.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package com.edgechain.lib.retrofit; | ||
|
||
import com.edgechain.lib.endpoint.impl.llm.Llama2Endpoint; | ||
import com.edgechain.lib.llama2.response.Llama2ChatCompletionResponse; | ||
import io.reactivex.rxjava3.core.Single; | ||
import retrofit2.http.Body; | ||
import retrofit2.http.POST; | ||
|
||
import java.util.List; | ||
|
||
public interface Llama2Service { | ||
@POST(value = "llama2/chat-completion") | ||
Single<List<Llama2ChatCompletionResponse>> chatCompletion(@Body Llama2Endpoint llama2Endpoint); | ||
} |
Oops, something went wrong.