Skip to content

Commit

Permalink
fix: enabled auth add support watsonx backend (#1190)
Browse files Browse the repository at this point in the history
Signed-off-by: Guangya Liu <[email protected]>
Signed-off-by: Alex Jones <[email protected]>
Co-authored-by: Alex Jones <[email protected]>
Co-authored-by: Matthis <[email protected]>
Signed-off-by: AlexsJones <[email protected]>
  • Loading branch information
3 people committed Oct 24, 2024
1 parent 3f070e4 commit 2fa925b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
7 changes: 5 additions & 2 deletions cmd/auth/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ var addCmd = &cobra.Command{
if strings.ToLower(backend) == "amazonbedrock" {
_ = cmd.MarkFlagRequired("providerRegion")
}
if strings.ToLower(backend) == "watsonxai" {
_ = cmd.MarkFlagRequired("providerId")
}
},
Run: func(cmd *cobra.Command, args []string) {

Expand Down Expand Up @@ -173,8 +176,8 @@ func init() {
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
//add flag for amazonbedrock region name
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, googlevertexai backend)")
//add flag for vertexAI Project ID
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai backend)")
//add flag for vertexAI/WatsonxAI Project ID
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai/watsonxai backend)")
//add flag for OCI Compartment ID
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
// add flag for openai organization
Expand Down
2 changes: 1 addition & 1 deletion pkg/ai/iai.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (p *AIProvider) GetCustomHeaders() []http.Header {
return p.CustomHeaders
}

var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci", "watsonxai"}
var passwordlessProviders = []string{"localai", "ollama", "amazonsagemaker", "amazonbedrock", "googlevertexai", "oci"}

func NeedPassword(backend string) bool {
for _, b := range passwordlessProviders {
Expand Down
13 changes: 5 additions & 8 deletions pkg/ai/watsonxai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"errors"
"fmt"
"os"

wx "github.com/IBM/watsonx-go/pkg/models"
)

Expand Down Expand Up @@ -42,20 +40,19 @@ func (c *WatsonxAIClient) Configure(config IAIConfig) error {
c.topP = config.GetTopP()
c.topK = config.GetTopK()

// WatsonxAPIKeyEnvVarName = "WATSONX_API_KEY"
// WatsonxProjectIDEnvVarName = "WATSONX_PROJECT_ID"
apiKey, projectID := os.Getenv(wx.WatsonxAPIKeyEnvVarName), os.Getenv(wx.WatsonxProjectIDEnvVarName)

apiKey := config.GetPassword()
if apiKey == "" {
return errors.New("No watsonx API key provided")
}
if projectID == "" {

projectId := config.GetProviderId()
if projectId == "" {
return errors.New("No watsonx project ID provided")
}

client, err := wx.NewClient(
wx.WithWatsonxAPIKey(apiKey),
wx.WithWatsonxProjectID(projectID),
wx.WithWatsonxProjectID(projectId),
)
if err != nil {
return fmt.Errorf("Failed to create client for testing. Error: %v", err)
Expand Down

0 comments on commit 2fa925b

Please sign in to comment.