Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding Custom Url Endpoints and Headers #2232

Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ trait HasCustomAuthHeader extends HasServiceParams {
}
}

trait HasCustomHeaders extends HasServiceParams {

val customHeaders = new ServiceParam[Map[String, String]](
this, "customHeaders", "Map of Custom Header Key-Value Tuples."
)

def setCustomHeaders(v: Map[String, String]): this.type = {
setScalarParam(customHeaders, v)
}

def getCustomHeaders: Map[String, String] = getScalarParam(customHeaders)
}

trait HasCustomCogServiceDomain extends Wrappable with HasURL with HasUrlPath {
def setCustomServiceName(v: String): this.type = {
setUrl(s"https://$v.cognitiveservices.azure.com/" + urlPath.stripPrefix("/"))
Expand Down Expand Up @@ -256,7 +269,15 @@ object URLEncodingUtils {
}

trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with SynapseMLLogging {
with HasCustomHeaders with SynapseMLLogging {

val customUrlRoot: Param[String] = new Param[String](
this, "customUrlRoot", "The custom URL root for the service. " +
"This will not append OpenAI specific model path completions (i.e. /chat/completions) to the URL.")

def getCustomUrlRoot: String = $(customUrlRoot)

def setCustomUrlRoot(v: String): this.type = set(customUrlRoot, v)

protected def paramNameToPayloadName(p: Param[_]): String = p match {
case p: ServiceParam[_] => p.payloadName
Expand All @@ -281,7 +302,11 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
} else {
""
}
prepareUrlRoot(row) + appended
if (get(customUrlRoot).nonEmpty) {
$(customUrlRoot)
} else {
prepareUrlRoot(row) + appended
}
}
}

Expand All @@ -296,20 +321,25 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
protected def contentType: Row => String = { _ => "application/json" }

protected def getCustomAuthHeader(row: Row): Option[String] = {
val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomHeader .isEmpty && PlatformDetails.runningOnFabric()) {
val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader)
if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) {
logInfo("Using Default AAD Token On Fabric")
Option(TokenLibrary.getAuthHeader)
} else {
providedCustomHeader
providedCustomAuthHeader
}
}

protected def getCustomHeaders(row: Row): Option[Map[String, String]] = {
getValueOpt(row, customHeaders)
}

protected def addHeaders(req: HttpRequestBase,
subscriptionKey: Option[String],
aadToken: Option[String],
contentType: String = "",
customAuthHeader: Option[String] = None): Unit = {
customAuthHeader: Option[String] = None,
customHeaders: Option[Map[String, String]] = None): Unit = {

if (subscriptionKey.nonEmpty) {
req.setHeader(subscriptionKeyHeaderName, subscriptionKey.get)
Expand All @@ -326,6 +356,13 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
req.setHeader("x-ms-workload-resource-moniker", UUID.randomUUID().toString)
})
}
if (customHeaders.nonEmpty) {
customHeaders.foreach(m => {
m.foreach {
case (headerName, headerValue) => req.setHeader(headerName, headerValue)
}
})
}
if (contentType != "") req.setHeader("Content-Type", contentType)
}

Expand All @@ -342,7 +379,8 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
getValueOpt(row, subscriptionKey),
getValueOpt(row, AADToken),
contentType(row),
getCustomAuthHeader(row))
getCustomAuthHeader(row),
getCustomHeaders(row))

req match {
case er: HttpEntityEnclosingRequestBase =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ trait HasPromptInputs extends HasServiceParams {
trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = true)
this, "deploymentName", "The name of the deployment", isRequired = false)

def getDeploymentName: String = getScalarParam(deploymentName)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,29 @@ class OpenAIChatCompletionSuite extends TransformerFuzzing[OpenAIChatCompletion]
assert(Option(results.apply(2).getAs[Row]("out")).isEmpty)
}

ignore("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")

val customEndpointCompletion = new OpenAIChatCompletion()
.setCustomUrlRoot(customRootUrlValue)
.setOutputCol("out")
.setMessagesCol("messages")
.setTemperature(0)

if (accessToken.isEmpty) {
customEndpointCompletion.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
} else {
customEndpointCompletion.setAADToken(accessToken)
.setCustomHeaders(Map("X-ModelType" -> "gpt-4-turbo-chat-completions",
"X-ScenarioGUID" -> "7687c733-45b0-425b-82b3-05eb4eb70247"))
}

testCompletion(customEndpointCompletion, goodDf)
}

def testCompletion(completion: OpenAIChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = {
val fromRow = ChatCompletionResponse.makeFromRowConverter
completion.transform(df).collect().foreach(r =>
Expand Down
Loading