Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(platform): Simplify Credentials UX #8524

Merged
merged 16 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,21 @@
)
openai_credentials = APIKeyCredentials(
id="53c25cb8-e3ee-465c-a4d1-e75a4c899c2a",
provider="llm",
provider="openai",
api_key=SecretStr(settings.secrets.openai_api_key),
title="Use Credits for OpenAI",
expires_at=None,
)
anthropic_credentials = APIKeyCredentials(
id="24e5d942-d9e3-4798-8151-90143ee55629",
provider="llm",
provider="anthropic",
api_key=SecretStr(settings.secrets.anthropic_api_key),
title="Use Credits for Anthropic",
expires_at=None,
)
groq_credentials = APIKeyCredentials(
id="4ec22295-8f97-4dd1-b42b-2c6957a02545",
provider="llm",
provider="groq",
api_key=SecretStr(settings.secrets.groq_api_key),
title="Use Credits for Groq",
expires_at=None,
Expand Down
11 changes: 8 additions & 3 deletions autogpt_platform/backend/backend/blocks/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
# "ollama": BlockSecret(value=""),
# }

AICredentials = CredentialsMetaInput[Literal["llm"], Literal["api_key"]]
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama"]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]

TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
provider="llm",
provider="openai",
api_key=SecretStr("mock-openai-api-key"),
title="Mock OpenAI API key",
expires_at=None,
Expand All @@ -50,8 +51,12 @@
def AICredentialsField() -> AICredentials:
return CredentialsField(
description="API key for the LLM provider.",
provider="llm",
provider=["anthropic", "groq", "openai", "ollama"],
supported_credential_types={"api_key"},
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
)


Expand Down
10 changes: 8 additions & 2 deletions autogpt_platform/backend/backend/data/model.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note: this really has to be considered a tempfix. Scopes are provider-specific. So is supported_credential_types. The CredentialsField was made to handle credentials from a specific provider. If a block supports multiple providers, the discrimination logic should be handled outside the CredentialsField.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also #8524 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'd like to leave this discussion here for reference in the follow-up PR)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'm happy for us to merge this as a much needed temp hotfix and we can adjust that in a follow up.

Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):


def CredentialsField(
provider: CP,
provider: CP | list[CP],
supported_credential_types: set[CT],
required_scopes: set[str] = set(),
*,
discriminator: Optional[str] = None,
discriminator_mapping: Optional[dict[str, Any]] = None,
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
Expand All @@ -167,9 +169,13 @@ def CredentialsField(
json_extra = {
k: v
for k, v in {
"credentials_provider": provider,
"credentials_provider": (
[provider] if isinstance(provider, str) else provider
),
"credentials_scopes": list(required_scopes) or None, # omit if empty
"credentials_types": list(supported_credential_types),
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
}.items()
if v is not None
}
Expand Down
33 changes: 23 additions & 10 deletions autogpt_platform/frontend/src/app/profile/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@ import { useSupabase } from "@/components/SupabaseProvider";
import { Button } from "@/components/ui/button";
import useUser from "@/hooks/useUser";
import { useRouter } from "next/navigation";
import { useCallback, useContext } from "react";
import { useCallback, useContext, useMemo } from "react";
import { FaSpinner } from "react-icons/fa";
import { Separator } from "@/components/ui/separator";
import { useToast } from "@/components/ui/use-toast";
import { IconKey, IconUser } from "@/components/ui/icons";
import { LogOutIcon, Trash2Icon } from "lucide-react";
import { providerIcons } from "@/components/integrations/credentials-input";
import {
CredentialsProviderName,
CredentialsProvidersContext,
} from "@/components/integrations/credentials-provider";
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
import {
Table,
TableBody,
Expand All @@ -23,6 +20,7 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import { CredentialsProviderName } from "@/lib/autogpt-server-api";

export default function PrivatePage() {
const { user, isLoading, error } = useUser();
Expand Down Expand Up @@ -62,7 +60,22 @@ export default function PrivatePage() {
[providers, toast],
);

if (isLoading || !providers || !providers) {
//TODO: remove when the way system credentials are handled is updated
// This contains ids for built-in "Use Credits for X" credentials
const hiddenCredentials = useMemo(
() => [
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
"53c25cb8-e3ee-465c-a4d1-e75a4c899c2a", // OpenAI
"24e5d942-d9e3-4798-8151-90143ee55629", // Anthropic
"4ec22295-8f97-4dd1-b42b-2c6957a02545", // Groq
"7f7b0654-c36b-4565-8fa7-9a52575dfae2", // D-ID
],
[],
);
Pwuts marked this conversation as resolved.
Show resolved Hide resolved

if (isLoading || !providers) {
return (
<div className="flex h-[80vh] items-center justify-center">
<FaSpinner className="mr-2 h-16 w-16 animate-spin" />
Expand All @@ -76,15 +89,15 @@ export default function PrivatePage() {
}

const allCredentials = Object.values(providers).flatMap((provider) =>
[...provider.savedOAuthCredentials, ...provider.savedApiKeys].map(
(credentials) => ({
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
.filter((cred) => !hiddenCredentials.includes(cred.id))
.map((credentials) => ({
...credentials,
provider: provider.provider,
providerName: provider.providerName,
ProviderIcon: providerIcons[provider.provider],
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
}),
),
})),
);

return (
Expand Down
58 changes: 10 additions & 48 deletions autogpt_platform/frontend/src/components/CustomNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ import {
BlockUIType,
BlockCost,
} from "@/lib/autogpt-server-api/types";
import { beautifyString, cn, setNestedProperty } from "@/lib/utils";
import {
beautifyString,
cn,
getValue,
parseKeys,
setNestedProperty,
} from "@/lib/utils";
import { Button } from "@/components/ui/button";
import { Switch } from "@/components/ui/switch";
import { history } from "./history";
Expand All @@ -36,8 +42,6 @@ import * as Separator from "@radix-ui/react-separator";
import * as ContextMenu from "@radix-ui/react-context-menu";
import { DotsVerticalIcon, TrashIcon, CopyIcon } from "@radix-ui/react-icons";

type ParsedKey = { key: string; index?: number };

export type ConnectionData = Array<{
edge_id: string;
source: string;
Expand Down Expand Up @@ -178,7 +182,7 @@ export function CustomNode({
className=""
selfKey={noteKey}
schema={noteSchema as BlockIOStringSubSchema}
value={getValue(noteKey)}
value={getValue(noteKey, data.hardcodedValues)}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
error={data.errors?.[noteKey] ?? ""}
Expand Down Expand Up @@ -228,7 +232,7 @@ export function CustomNode({
nodeId={id}
propKey={getInputPropKey(propKey)}
propSchema={propSchema}
currentValue={getValue(getInputPropKey(propKey))}
currentValue={getValue(propKey, data.hardcodedValues)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
Expand Down Expand Up @@ -283,48 +287,6 @@ export function CustomNode({
setErrors({ ...errors });
};

// Helper function to parse keys with array indices
//TODO move to utils
const parseKeys = (key: string): ParsedKey[] => {
const splits = key.split(/_@_|_#_|_\$_|\./);
const keys: ParsedKey[] = [];
let currentKey: string | null = null;

splits.forEach((split) => {
const isInteger = /^\d+$/.test(split);
if (!isInteger) {
if (currentKey !== null) {
keys.push({ key: currentKey });
}
currentKey = split;
} else {
if (currentKey !== null) {
keys.push({ key: currentKey, index: parseInt(split, 10) });
currentKey = null;
} else {
throw new Error("Invalid key format: array index without a key");
}
}
});

if (currentKey !== null) {
keys.push({ key: currentKey });
}

return keys;
};

const getValue = (key: string) => {
const keys = parseKeys(key);
return keys.reduce((acc, k) => {
if (acc === undefined) return undefined;
if (k.index !== undefined) {
return Array.isArray(acc[k.key]) ? acc[k.key][k.index] : undefined;
}
return acc[k.key];
}, data.hardcodedValues as any);
};

const isHandleConnected = (key: string) => {
return (
data.connections &&
Expand All @@ -347,7 +309,7 @@ export function CustomNode({
const handleInputClick = (key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,18 @@ export const providerIcons: Record<
CredentialsProviderName,
React.FC<{ className?: string }>
> = {
anthropic: fallbackIcon,
github: FaGithub,
google: FaGoogle,
groq: fallbackIcon,
notion: NotionLogoIcon,
discord: FaDiscord,
d_id: fallbackIcon,
google_maps: FaGoogle,
jina: fallbackIcon,
ideogram: fallbackIcon,
llm: fallbackIcon,
medium: FaMedium,
ollama: fallbackIcon,
openai: fallbackIcon,
openweathermap: fallbackIcon,
pinecone: fallbackIcon,
Expand All @@ -80,7 +82,7 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
export const CredentialsInput: FC<{
className?: string;
selectedCredentials?: CredentialsMetaInput;
onSelectCredentials: (newValue: CredentialsMetaInput) => void;
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
}> = ({ className, selectedCredentials, onSelectCredentials }) => {
const api = useMemo(() => new AutoGPTServerAPI(), []);
const credentials = useCredentials();
Expand All @@ -91,14 +93,10 @@ export const CredentialsInput: FC<{
useState<AbortController | null>(null);
const [oAuthError, setOAuthError] = useState<string | null>(null);

if (!credentials) {
if (!credentials || credentials.isLoading) {
return null;
}

if (credentials.isLoading) {
return <div>Loading...</div>;
}

ntindle marked this conversation as resolved.
Show resolved Hide resolved
const {
schema,
provider,
Expand Down Expand Up @@ -222,10 +220,21 @@ export const CredentialsInput: FC<{
</>
);

// Deselect credentials if they do not exist (e.g. provider was changed)
if (
selectedCredentials &&
!savedApiKeys
.concat(savedOAuthCredentials)
.some((c) => c.id === selectedCredentials.id)
) {
onSelectCredentials(undefined);
}

// No saved credentials yet
if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
return (
<>
<span className="text-m green mb-0 text-gray-900">Credentials</span>
<div className={cn("flex flex-row space-x-2", className)}>
{supportsOAuth2 && (
<Button onClick={handleOAuthLogin}>
Expand All @@ -248,6 +257,25 @@ export const CredentialsInput: FC<{
);
}

const singleCredential =
savedApiKeys.length === 1 && savedOAuthCredentials.length === 0
? savedApiKeys[0]
: savedOAuthCredentials.length === 1 && savedApiKeys.length === 0
? savedOAuthCredentials[0]
: null;

if (singleCredential) {
if (!selectedCredentials) {
onSelectCredentials({
id: singleCredential.id,
type: singleCredential.type,
provider,
title: singleCredential.title,
});
}
return null;
}

function handleValueChange(newValue: string) {
if (newValue === "sign-in") {
// Trigger OAuth2 sign in flow
Expand All @@ -263,7 +291,7 @@ export const CredentialsInput: FC<{
onSelectCredentials({
id: selectedCreds.id,
type: selectedCreds.type,
provider: schema.credentials_provider,
provider: provider,
// title: customTitle, // TODO: add input for title
});
}
Expand All @@ -272,6 +300,7 @@ export const CredentialsInput: FC<{
// Saved credentials exist
return (
<>
<span className="text-m green mb-0 text-gray-900">Credentials</span>
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
<SelectTrigger>
<SelectValue placeholder={schema.placeholder} />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ const CREDENTIALS_PROVIDER_NAMES = Object.values(

// --8<-- [start:CredentialsProviderNames]
const providerDisplayNames: Record<CredentialsProviderName, string> = {
anthropic: "Anthropic",
discord: "Discord",
d_id: "D-ID",
github: "GitHub",
google: "Google",
google_maps: "Google Maps",
groq: "Groq",
ideogram: "Ideogram",
jina: "Jina",
medium: "Medium",
llm: "LLM",
notion: "Notion",
ollama: "Ollama",
openai: "OpenAI",
openweathermap: "OpenWeatherMap",
pinecone: "Pinecone",
Expand Down
Loading
Loading