From 626eb30e77318a5e2163df939745a36c95803b8b Mon Sep 17 00:00:00 2001 From: Saxon Fletcher Date: Mon, 29 Sep 2025 13:57:36 +1000 Subject: [PATCH] Assistant action orientated approach (#38806) * update onboarding * update model and fix part issue * action orientated assistant * fix tool * lock * remove unused filter * fix tests * fix again * update package * update container * fix tests * refactor(ai assistant): break out message markdown and profile picture * wip * refactor(ai assistant): break up message component * refactor: break ai assistant message down into multiple files * refactor: simplify ReportBlock state * fix: styling of draggable report block header When the drag handle is showing, it overlaps with the block header. Decrease the opacity of the header so the handle can be seen and the two can be distinguished. * fix: minor tweaks to tool ui * refactor: simplify DisplayBlockRenderer state * fix: remove double deploy button in edge function block When the confirm footer is shown, the deploy button on the top right should be hidden (not just disabled) to avoid confusion. * refactor, test: message sanitization by opt-in level Refactor the message sanitization to have more type safety and be more testable. Add tests to ensure: - Message sanitization always runs on generate-v4 - Message sanitization correctly works by opt-in level * Fix conflicts in pnpm lock * Couple of nits and refactors * Revert casing for report block snippet * adjust sanitised prompt * Fix tests --------- Co-authored-by: Charis Lam <26616127+charislam@users.noreply.github.com> Co-authored-by: Joshen Lim --- .../Replication/DeleteDestination.tsx | 2 +- .../interfaces/HomeNew/SnippetDropdown.tsx | 2 +- .../Reports/ReportBlock/ReportBlock.tsx | 166 ++++--- .../ReportBlock/ReportBlockContainer.tsx | 58 ++- .../components/interfaces/UserDropdown.tsx | 38 +- .../ui/AIAssistantPanel/AIAssistant.tsx | 50 +- .../ui/AIAssistantPanel/AIOnboarding.tsx | 22 +- .../ui/AIAssistantPanel/ConfirmFooter.tsx | 41 ++ .../AIAssistantPanel/DisplayBlockRenderer.tsx | 218 ++++++--- .../AIAssistantPanel/EdgeFunctionRenderer.tsx | 158 +++++++ .../ui/AIAssistantPanel/Message.Actions.tsx | 46 ++ .../ui/AIAssistantPanel/Message.Context.tsx | 62 +++ .../ui/AIAssistantPanel/Message.Display.tsx | 90 ++++ .../ui/AIAssistantPanel/Message.Parts.tsx | 327 +++++++++++++ .../ui/AIAssistantPanel/Message.tsx | 442 ++++-------------- .../ui/AIAssistantPanel/Message.utils.ts | 124 +++-- .../ui/AIAssistantPanel/MessageMarkdown.tsx | 209 ++------- .../AIAssistantPanel/elements/Reasoning.tsx | 57 --- .../ui/AIAssistantPanel/elements/Tool.tsx | 54 +++ .../EdgeFunctionBlock/EdgeFunctionBlock.tsx | 170 +++---- .../components/ui/EditorPanel/EditorPanel.tsx | 2 +- .../components/ui/QueryBlock/QueryBlock.tsx | 421 ++++++----------- apps/studio/components/ui/SchemaSelector.tsx | 2 +- .../components/ui/SqlWarningAdmonition.tsx | 36 +- apps/studio/hooks/misc/useChanged.ts | 8 + apps/studio/lib/ai/message-utils.ts | 25 + apps/studio/lib/ai/prompts.ts | 46 +- apps/studio/lib/ai/test-fixtures.ts | 114 +++++ apps/studio/lib/ai/tool-filter.test.ts | 44 +- apps/studio/lib/ai/tool-filter.ts | 15 +- apps/studio/lib/ai/tools/mcp-tools.ts | 25 +- apps/studio/lib/ai/tools/rendering-tools.ts | 44 +- .../lib/ai/tools/tool-sanitizer.test.ts | 175 +++++++ apps/studio/lib/ai/tools/tool-sanitizer.ts | 54 +++ apps/studio/lib/api/generate-v4.test.ts | 77 +++ apps/studio/lib/profile.tsx | 27 ++ apps/studio/pages/api/ai/sql/generate-v4.ts | 50 +- apps/studio/state/ai-assistant-state.tsx | 65 +-- 38 files changed, 2175 insertions(+), 1391 deletions(-) create mode 100644 apps/studio/components/ui/AIAssistantPanel/ConfirmFooter.tsx create mode 100644 apps/studio/components/ui/AIAssistantPanel/EdgeFunctionRenderer.tsx create mode 100644 apps/studio/components/ui/AIAssistantPanel/Message.Actions.tsx create mode 100644 apps/studio/components/ui/AIAssistantPanel/Message.Context.tsx create mode 100644 apps/studio/components/ui/AIAssistantPanel/Message.Display.tsx create mode 100644 apps/studio/components/ui/AIAssistantPanel/Message.Parts.tsx delete mode 100644 apps/studio/components/ui/AIAssistantPanel/elements/Reasoning.tsx create mode 100644 apps/studio/components/ui/AIAssistantPanel/elements/Tool.tsx create mode 100644 apps/studio/lib/ai/message-utils.ts create mode 100644 apps/studio/lib/ai/test-fixtures.ts create mode 100644 apps/studio/lib/ai/tools/tool-sanitizer.test.ts create mode 100644 apps/studio/lib/ai/tools/tool-sanitizer.ts create mode 100644 apps/studio/lib/api/generate-v4.test.ts diff --git a/apps/studio/components/interfaces/Database/Replication/DeleteDestination.tsx b/apps/studio/components/interfaces/Database/Replication/DeleteDestination.tsx index 1c1f633505..71c0702381 100644 --- a/apps/studio/components/interfaces/Database/Replication/DeleteDestination.tsx +++ b/apps/studio/components/interfaces/Database/Replication/DeleteDestination.tsx @@ -21,7 +21,7 @@ export const DeleteDestination = ({ visible={visible} loading={isLoading} title="Delete this destination" - confirmLabel={isLoading ? 'Deleting…' : `Delete destination`} + confirmLabel={isLoading ? 'Deleting...' : `Delete destination`} confirmPlaceholder="Type in name of destination" confirmString={name ?? 'Unknown'} text={`This will delete the destination "${name}"`} diff --git a/apps/studio/components/interfaces/HomeNew/SnippetDropdown.tsx b/apps/studio/components/interfaces/HomeNew/SnippetDropdown.tsx index b9f8fc56ab..70b475dc94 100644 --- a/apps/studio/components/interfaces/HomeNew/SnippetDropdown.tsx +++ b/apps/studio/components/interfaces/HomeNew/SnippetDropdown.tsx @@ -85,7 +85,7 @@ export const SnippetDropdown = ({ /> {isLoading ? ( - Loading… + Loading... ) : snippets.length === 0 ? ( No snippets found ) : null} diff --git a/apps/studio/components/interfaces/Reports/ReportBlock/ReportBlock.tsx b/apps/studio/components/interfaces/Reports/ReportBlock/ReportBlock.tsx index 103060a2bf..e70d6e2ae3 100644 --- a/apps/studio/components/interfaces/Reports/ReportBlock/ReportBlock.tsx +++ b/apps/studio/components/interfaces/Reports/ReportBlock/ReportBlock.tsx @@ -1,4 +1,6 @@ import { X } from 'lucide-react' +import { useCallback, useState } from 'react' +import { toast } from 'sonner' import { useParams } from 'common' import { ChartConfig } from 'components/interfaces/SQLEditor/UtilityPanel/ChartConfig' @@ -6,10 +8,11 @@ import { ButtonTooltip } from 'components/ui/ButtonTooltip' import { DEFAULT_CHART_CONFIG, QueryBlock } from 'components/ui/QueryBlock/QueryBlock' import { AnalyticsInterval } from 'data/analytics/constants' import { useContentIdQuery } from 'data/content/content-id-query' +import { usePrimaryDatabase } from 'data/read-replicas/replicas-query' +import { useExecuteSqlMutation } from 'data/sql/execute-sql-mutation' +import { useChangedSync } from 'hooks/misc/useChanged' import { useDatabaseSelectorStateSnapshot } from 'state/database-selector' -import { Dashboards, SqlSnippets } from 'types' -import { Button, cn } from 'ui' -import ShimmeringLoader from 'ui-patterns/ShimmeringLoader' +import type { Dashboards, SqlSnippets } from 'types' import { DEPRECATED_REPORTS } from '../Reports.constants' import { ChartBlock } from './ChartBlock' import { DeprecatedChartBlock } from './DeprecatedChartBlock' @@ -46,7 +49,7 @@ export const ReportBlock = ({ const isSnippet = item.attribute.startsWith('snippet_') - const { data, error, isLoading, isError } = useContentIdQuery( + const { data, error, isLoading } = useContentIdQuery( { projectRef, id: item.id }, { enabled: isSnippet && !!item.id, @@ -57,6 +60,11 @@ export const ReportBlock = ({ if (failureCount >= 2) return false return true }, + onSuccess: (contentData) => { + if (!isSnippet) return + const fetchedSql = (contentData?.content as SqlSnippets.Content | undefined)?.sql + if (fetchedSql) runQuery('select', fetchedSql) + }, } ) const sql = isSnippet ? (data?.content as SqlSnippets.Content)?.sql : undefined @@ -64,82 +72,102 @@ export const ReportBlock = ({ const isDeprecatedChart = DEPRECATED_REPORTS.includes(item.attribute) const snippetMissing = error?.message.includes('Content not found') + const { database: primaryDatabase } = usePrimaryDatabase({ projectRef }) + const readOnlyConnectionString = primaryDatabase?.connection_string_read_only + const postgresConnectionString = primaryDatabase?.connectionString + + const [rows, setRows] = useState(undefined) + const [isWriteQuery, setIsWriteQuery] = useState(false) + + const { + mutate: executeSql, + error: executeSqlError, + isLoading: executeSqlLoading, + } = useExecuteSqlMutation({ + onError: () => { + // Silence the error toast because the error will be displayed inline + }, + }) + + const runQuery = useCallback( + (queryType: 'select' | 'mutation' = 'select', sqlToRun?: string) => { + if (!projectRef || !sqlToRun) return false + + const connectionString = + queryType === 'mutation' + ? postgresConnectionString + : readOnlyConnectionString ?? postgresConnectionString + + if (!connectionString) { + toast.error('Unable to establish a database connection for this project.') + return false + } + + if (queryType === 'mutation') { + setIsWriteQuery(true) + } + executeSql( + { projectRef, connectionString, sql: sqlToRun }, + { + onSuccess: (data) => { + setRows(data.result) + setIsWriteQuery(queryType === 'mutation') + }, + onError: (mutationError) => { + const lowerMessage = mutationError.message.toLowerCase() + const isReadOnlyError = + lowerMessage.includes('read-only transaction') || + lowerMessage.includes('permission denied') || + lowerMessage.includes('must be owner') + + if (queryType === 'select' && isReadOnlyError) { + setIsWriteQuery(true) + } + }, + } + ) + return true + }, + [projectRef, readOnlyConnectionString, postgresConnectionString, executeSql] + ) + + const sqlHasChanged = useChangedSync(sql) + const isRefreshingChanged = useChangedSync(isRefreshing) + if (sqlHasChanged || (isRefreshingChanged && isRefreshing)) { + runQuery('select', sql) + } + return ( <> {isSnippet ? ( } - className="w-7 h-7" - onClick={() => onRemoveChart({ metric: { key: item.attribute } })} - tooltip={{ content: { side: 'bottom', text: 'Remove chart' } }} - /> - } - onUpdateChartConfig={onUpdateChart} - noResultPlaceholder={ -
- {isLoading ? ( - <> - - - - - ) : isError ? ( - <> -

- {snippetMissing ? 'SQL snippet cannot be found' : 'Error fetching SQL snippet'} -

-

- {snippetMissing ? 'Please remove this block from your report' : error.message} -

- - ) : ( - <> -

- No results returned from query -

-

- Results from the SQL query can be viewed as a table or chart here -

- - )} -
- } - readOnlyErrorPlaceholder={ -
-

- SQL query is not read-only and cannot be rendered -

-

- Queries that involve any mutation will not be run in reports -

- -
+ tooltip={{ content: { side: 'bottom', text: 'Remove chart' } }} + /> + ) } + onExecute={(queryType) => { + runQuery(queryType, sql) + }} + onUpdateChartConfig={onUpdateChart} + onRemoveChart={() => onRemoveChart({ metric: { key: item.attribute } })} + disabled={isLoading || snippetMissing || !sql} /> ) : isDeprecatedChart ? (
-
- {loading ? ( - - ) : ( - icon - )} -
- {showDragHandle && ( + {showDragHandle ? (
+ ) : icon ? ( + icon + ) : ( + )} -

- {label} -

+

{label}

+ {badge &&
{badge}
} +
{actions}
@@ -77,8 +73,20 @@ export const ReportBlockContainer = ({ )} -
- {children} +
+
+ {children} +
) diff --git a/apps/studio/components/interfaces/UserDropdown.tsx b/apps/studio/components/interfaces/UserDropdown.tsx index 86dafac63b..1fe53c346d 100644 --- a/apps/studio/components/interfaces/UserDropdown.tsx +++ b/apps/studio/components/interfaces/UserDropdown.tsx @@ -4,12 +4,10 @@ import Link from 'next/link' import { useRouter } from 'next/router' import { ProfileImage } from 'components/ui/ProfileImage' -import { useProfileIdentitiesQuery } from 'data/profile/profile-identities-query' import { useIsFeatureEnabled } from 'hooks/misc/useIsFeatureEnabled' import { useSignOut } from 'lib/auth' import { IS_PLATFORM } from 'lib/constants' -import { getGitHubProfileImgUrl } from 'lib/github' -import { useProfile } from 'lib/profile' +import { useProfileNameAndPicture } from 'lib/profile' import { useAppStateSnapshot } from 'state/app-state' import { Button, @@ -30,40 +28,28 @@ import { useFeaturePreviewModal } from './App/FeaturePreview/FeaturePreviewConte export function UserDropdown() { const router = useRouter() - const signOut = useSignOut() - const { profile, isLoading: isLoadingProfile } = useProfile() const { theme, setTheme } = useTheme() const appStateSnapshot = useAppStateSnapshot() + const profileShowEmailEnabled = useIsFeatureEnabled('profile:show_email') + const { username, avatarUrl, primaryEmail, isLoading } = useProfileNameAndPicture() + + const signOut = useSignOut() const setCommandMenuOpen = useSetCommandMenuOpen() const { openFeaturePreviewModal } = useFeaturePreviewModal() - const profileShowEmailEnabled = useIsFeatureEnabled('profile:show_email') - - const { username, primary_email } = profile ?? {} - - const { data, isLoading: isLoadingIdentities } = useProfileIdentitiesQuery() - const isGitHubProfile = profile?.auth0_id.startsWith('github') - const gitHubUsername = isGitHubProfile - ? (data?.identities ?? []).find((x) => x.provider === 'github')?.identity_data?.user_name - : undefined - const profileImageUrl = isGitHubProfile ? getGitHubProfileImgUrl(gitHubUsername) : undefined return ( - + @@ -72,17 +58,17 @@ export function UserDropdown() { {IS_PLATFORM && ( <>
- {profile && ( + {!!username && ( <> {username} - {primary_email !== username && profileShowEmailEnabled && ( + {primaryEmail !== username && profileShowEmailEnabled && ( - {primary_email} + {primaryEmail} )} diff --git a/apps/studio/components/ui/AIAssistantPanel/AIAssistant.tsx b/apps/studio/components/ui/AIAssistantPanel/AIAssistant.tsx index fc93abf135..757ee95ff9 100644 --- a/apps/studio/components/ui/AIAssistantPanel/AIAssistant.tsx +++ b/apps/studio/components/ui/AIAssistantPanel/AIAssistant.tsx @@ -18,16 +18,16 @@ import { useOrgAiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' import { useSelectedOrganizationQuery } from 'hooks/misc/useSelectedOrganization' import { useSelectedProjectQuery } from 'hooks/misc/useSelectedProject' import { useHotKey } from 'hooks/ui/useHotKey' +import { prepareMessagesForAPI } from 'lib/ai/message-utils' import { BASE_PATH, IS_PLATFORM } from 'lib/constants' import uuidv4 from 'lib/uuid' -import type { AssistantMessageType } from 'state/ai-assistant-state' import { useAiAssistantStateSnapshot } from 'state/ai-assistant-state' import { useSqlEditorV2StateSnapshot } from 'state/sql-editor-v2' import { Button, cn, KeyboardShortcut } from 'ui' import { Admonition } from 'ui-patterns' import { ButtonTooltip } from '../ButtonTooltip' import { ErrorBoundary } from '../ErrorBoundary' -import { type SqlSnippet } from './AIAssistant.types' +import type { SqlSnippet } from './AIAssistant.types' import { onErrorChat } from './AIAssistant.utils' import { AIAssistantHeader } from './AIAssistantHeader' import { AIOnboarding } from './AIOnboarding' @@ -37,7 +37,7 @@ import { ConversationContent, ConversationScrollButton, } from './elements/Conversation' -import { MemoizedMessage } from './Message' +import { Message } from './Message' interface AIAssistantProps { initialMessages?: MessageType[] | undefined @@ -107,16 +107,8 @@ export const AIAssistant = ({ className }: AIAssistantProps) => { const { mutate: sendEvent } = useSendEventMutation() const updateMessage = useCallback( - ({ - messageId, - resultId, - results, - }: { - messageId: string - resultId?: string - results: any[] - }) => { - snap.updateMessage({ id: messageId, resultId, results }) + (updatedMessage: MessageType) => { + snap.updateMessage(updatedMessage) }, [snap] ) @@ -128,10 +120,10 @@ export const AIAssistant = ({ className }: AIAssistantProps) => { snap.saveMessage([lastUserMessageRef.current, message]) lastUserMessageRef.current = null } else { - snap.saveMessage(message) + updateMessage(message) } }, - [snap] + [snap, updateMessage] ) // TODO(refactor): This useChat hook should be moved down into each chat session. @@ -189,21 +181,7 @@ export const AIAssistant = ({ className }: AIAssistantProps) => { transport: new DefaultChatTransport({ api: `${BASE_PATH}/api/ai/sql/generate-v4`, async prepareSendMessagesRequest({ messages, ...options }) { - // [Joshen] Specifically limiting the chat history that get's sent to reduce the - // size of the context that goes into the model. This should always be an odd number - // as much as possible so that the first message is always the user's - const MAX_CHAT_HISTORY = 7 - - const slicedMessages = messages.slice(-MAX_CHAT_HISTORY) - - // Filter out results from messages before sending to the model - const cleanedMessages = slicedMessages.map((message: any) => { - const cleanedMessage = { ...message } as AssistantMessageType - if (message.role === 'assistant' && (message as AssistantMessageType).results) { - delete cleanedMessage.results - } - return cleanedMessage - }) + const cleanedMessages = prepareMessagesForAPI(messages) const headerData = await constructHeaders() const authorizationHeader = headerData.get('Authorization') @@ -289,29 +267,33 @@ export const AIAssistant = ({ className }: AIAssistantProps) => { const isAfterEditedMessage = editingMessageId ? chatMessages.findIndex((m) => m.id === editingMessageId) < index : false + const isLastMessage = index === chatMessages.length - 1 return ( - ) }), [ chatMessages, - updateMessage, deleteMessageFromHere, editMessage, cancelEdit, editingMessageId, chatStatus, + addToolResult, ] ) diff --git a/apps/studio/components/ui/AIAssistantPanel/AIOnboarding.tsx b/apps/studio/components/ui/AIAssistantPanel/AIOnboarding.tsx index e8e8f69a41..d267fdd661 100644 --- a/apps/studio/components/ui/AIAssistantPanel/AIOnboarding.tsx +++ b/apps/studio/components/ui/AIAssistantPanel/AIOnboarding.tsx @@ -1,14 +1,13 @@ import { motion } from 'framer-motion' -import { partition } from 'lodash' import { BarChart, FileText, Shield } from 'lucide-react' -import { Button, Skeleton } from 'ui' import { useParams } from 'common' import { LINTER_LEVELS } from 'components/interfaces/Linter/Linter.constants' import { createLintSummaryPrompt } from 'components/interfaces/Linter/Linter.utils' -import { useProjectLintsQuery } from 'data/lint/lint-query' -import { type SqlSnippet } from './AIAssistant.types' +import { type Lint, useProjectLintsQuery } from 'data/lint/lint-query' +import { Button, Skeleton } from 'ui' import { codeSnippetPrompts, defaultPrompts } from './AIAssistant.prompts' +import { type SqlSnippet } from './AIAssistant.types' interface AIOnboardingProps { sqlSnippets?: SqlSnippet[] @@ -44,11 +43,10 @@ export const AIOnboarding = ({ } = useProjectLintsQuery({ projectRef }) const isLintsLoading = isLoadingLints || isFetchingLints - const errorLints = lints?.filter((lint) => lint.level === LINTER_LEVELS.ERROR) ?? [] - const [securityErrorLints, performanceErrorLints] = partition( - errorLints, - (lint) => lint.categories?.[0] === 'SECURITY' - ) + const errorLints: Lint[] = (lints?.filter((lint) => lint.level === LINTER_LEVELS.ERROR) ?? + []) as Lint[] + const securityErrorLints = errorLints.filter((lint) => lint.categories?.[0] === 'SECURITY') + const performanceErrorLints = errorLints.filter((lint) => lint.categories?.[0] !== 'SECURITY') return (
@@ -56,7 +54,7 @@ export const AIOnboarding = ({

How can I assist you?

{suggestions?.prompts?.length ? ( - <> +

Suggestions

{prompts.map((item, index) => ( ))} - +
) : ( <> {isLintsLoading ? ( @@ -139,7 +137,7 @@ export const AIOnboarding = ({ onFocusInput?.() }} > - {lint.detail ? lint.detail.replace(/\\`/g, '') : lint.title} + {lint.detail ? lint.detail.replace(/`/g, '') : lint.title} ) })} diff --git a/apps/studio/components/ui/AIAssistantPanel/ConfirmFooter.tsx b/apps/studio/components/ui/AIAssistantPanel/ConfirmFooter.tsx new file mode 100644 index 0000000000..8f79bddcd1 --- /dev/null +++ b/apps/studio/components/ui/AIAssistantPanel/ConfirmFooter.tsx @@ -0,0 +1,41 @@ +import { PropsWithChildren } from 'react' + +import { Button, cn } from 'ui' + +interface ConfirmFooterProps { + message: string + cancelLabel?: string + confirmLabel?: string + isLoading?: boolean + onCancel?: () => void | Promise + onConfirm?: () => void | Promise +} + +export const ConfirmFooter = ({ + message, + cancelLabel = 'Cancel', + confirmLabel = 'Confirm', + isLoading = false, + onCancel, + onConfirm, +}: PropsWithChildren) => { + return ( +
+
{message}
+
+ + +
+
+ ) +} diff --git a/apps/studio/components/ui/AIAssistantPanel/DisplayBlockRenderer.tsx b/apps/studio/components/ui/AIAssistantPanel/DisplayBlockRenderer.tsx index a36502f0bf..9d360738a2 100644 --- a/apps/studio/components/ui/AIAssistantPanel/DisplayBlockRenderer.tsx +++ b/apps/studio/components/ui/AIAssistantPanel/DisplayBlockRenderer.tsx @@ -1,51 +1,69 @@ import { PermissionAction } from '@supabase/shared-types/out/constants' -import type { UIDataTypes, UIMessagePart, UITools } from 'ai' import { useRouter } from 'next/router' -import { DragEvent, PropsWithChildren, useMemo, useState } from 'react' +import { type DragEvent, type PropsWithChildren, useRef, useState } from 'react' import { useParams } from 'common' import { ChartConfig } from 'components/interfaces/SQLEditor/UtilityPanel/ChartConfig' +import { usePrimaryDatabase } from 'data/read-replicas/replicas-query' +import { useExecuteSqlMutation } from 'data/sql/execute-sql-mutation' import { useSendEventMutation } from 'data/telemetry/send-event-mutation' +import { useChangedSync } from 'hooks/misc/useChanged' import { useAsyncCheckPermissions } from 'hooks/misc/useCheckPermissions' import { useSelectedOrganizationQuery } from 'hooks/misc/useSelectedOrganization' import { useProfile } from 'lib/profile' -import { useAiAssistantStateSnapshot } from 'state/ai-assistant-state' -import { Badge } from 'ui' import { DEFAULT_CHART_CONFIG, QueryBlock } from '../QueryBlock/QueryBlock' import { identifyQueryType } from './AIAssistant.utils' -import { findResultForManualId } from './Message.utils' +import { ConfirmFooter } from './ConfirmFooter' interface DisplayBlockRendererProps { messageId: string toolCallId: string - manualId?: string initialArgs: { sql: string label?: string + isWriteQuery?: boolean view?: 'table' | 'chart' xAxis?: string yAxis?: string - runQuery?: boolean } - messageParts: UIMessagePart[] | undefined - isLoading: boolean - onResults: (args: { messageId: string; resultId?: string; results: any[] }) => void + initialResults?: unknown + onResults?: (args: { messageId: string; results: unknown }) => void + onError?: (args: { messageId: string; errorText: string }) => void + toolState?: 'input-streaming' | 'input-available' | 'output-available' | 'output-error' + isLastPart?: boolean + isLastMessage?: boolean + showConfirmFooter?: boolean + onChartConfigChange?: (chartConfig: ChartConfig) => void + onQueryRun?: (queryType: 'select' | 'mutation') => void } export const DisplayBlockRenderer = ({ messageId, toolCallId, - manualId, initialArgs, - messageParts, - isLoading, + initialResults, onResults, + onError, + toolState, + isLastPart = false, + isLastMessage = false, + showConfirmFooter = true, + onChartConfigChange, + onQueryRun, }: PropsWithChildren) => { + const savedInitialArgs = useRef(initialArgs) + const savedInitialResults = useRef(initialResults) + const savedInitialConfig = useRef({ + ...DEFAULT_CHART_CONFIG, + view: initialArgs.view === 'chart' ? 'chart' : 'table', + xKey: initialArgs.xAxis ?? '', + yKey: initialArgs.yAxis ?? '', + }) + const router = useRouter() const { ref } = useParams() const { profile } = useProfile() const { data: org } = useSelectedOrganizationQuery() - const snap = useAiAssistantStateSnapshot() const { mutate: sendEvent } = useSendEventMutation() const { can: canCreateSQLSnippet } = useAsyncCheckPermissions( @@ -64,22 +82,49 @@ export const DisplayBlockRenderer = ({ yKey: initialArgs.yAxis ?? '', })) - const isChart = initialArgs.view === 'chart' - const resultId = manualId || toolCallId - const liveResultData = useMemo( - () => (manualId ? findResultForManualId(messageParts, manualId) : undefined), - [messageParts, manualId] + const [rows, setRows] = useState( + Array.isArray(initialResults) ? initialResults : undefined ) - const cachedResults = useMemo( - () => snap.getCachedSQLResults({ messageId, snippetId: resultId }), - [snap, messageId, resultId] - ) - const displayData = liveResultData ?? cachedResults const isDraggableToReports = canCreateSQLSnippet && router.pathname.endsWith('/reports/[id]') const label = initialArgs.label || 'SQL Results' + const [isWriteQuery, setIsWriteQuery] = useState(initialArgs.isWriteQuery || false) const sqlQuery = initialArgs.sql + const { database: primaryDatabase } = usePrimaryDatabase({ projectRef: ref }) + + const readOnlyConnectionString = primaryDatabase?.connection_string_read_only + const postgresConnectionString = primaryDatabase?.connectionString + + const { + mutate: executeSql, + error: executeSqlError, + isLoading: executeSqlLoading, + } = useExecuteSqlMutation({ + onError: () => { + // Suppress toast because error message is displayed inline + }, + }) + + const toolCallIdChanged = useChangedSync(toolCallId) + if (toolCallIdChanged) { + setChartConfig(savedInitialConfig.current) + onChartConfigChange?.(savedInitialConfig.current) + setIsWriteQuery(savedInitialArgs.current.isWriteQuery || false) + setRows(Array.isArray(savedInitialResults.current) ? savedInitialResults.current : undefined) + } + + const initialResultsChanged = useChangedSync(initialResults) + if (initialResultsChanged) { + const normalized = Array.isArray(initialResults) ? initialResults : undefined + if (!normalized || normalized === rows) return + setRows(normalized) + } + const handleRunQuery = (queryType: 'select' | 'mutation') => { + if (!sqlQuery) return + + onQueryRun?.(queryType) + sendEvent({ action: 'assistant_suggestion_run_query_clicked', properties: { @@ -93,12 +138,66 @@ export const DisplayBlockRenderer = ({ }) } + const runQuery = (queryType: 'select' | 'mutation') => { + if (!ref || !sqlQuery) return + + const connectionString = + queryType === 'mutation' + ? postgresConnectionString + : readOnlyConnectionString ?? postgresConnectionString + + if (!connectionString) { + const fallbackMessage = 'Unable to find a database connection to execute this query.' + onError?.({ messageId, errorText: fallbackMessage }) + return + } + + if (queryType === 'mutation') { + setIsWriteQuery(true) + } + executeSql( + { projectRef: ref, connectionString, sql: sqlQuery }, + { + onSuccess: (data) => { + setRows(Array.isArray(data.result) ? data.result : undefined) + setIsWriteQuery(queryType === 'mutation' || initialArgs.isWriteQuery || false) + onResults?.({ + messageId, + results: Array.isArray(data.result) ? data.result : undefined, + }) + }, + onError: (error) => { + const lowerMessage = error.message.toLowerCase() + const isReadOnlyError = + lowerMessage.includes('read-only transaction') || + lowerMessage.includes('permission denied') || + lowerMessage.includes('must be owner') + + if (queryType === 'select' && isReadOnlyError) { + setIsWriteQuery(true) + } + + onError?.({ messageId, errorText: error.message }) + }, + } + ) + } + + const handleExecute = (queryType: 'select' | 'mutation') => { + handleRunQuery(queryType) + runQuery(queryType) + } + const handleUpdateChartConfig = ({ chartConfig: updatedValues, }: { chartConfig: Partial }) => { - setChartConfig((prev) => ({ ...prev, ...updatedValues })) + setChartConfig((prev) => { + const next = { ...prev, ...updatedValues } + onChartConfigChange?.(next) + return next + }) } const handleDragStart = (e: DragEvent) => { @@ -108,35 +207,48 @@ export const DisplayBlockRenderer = ({ ) } + const resolvedHasDecision = initialResults !== undefined || rows !== undefined + const shouldShowConfirmFooter = + showConfirmFooter && + !resolvedHasDecision && + toolState === 'input-available' && + isLastPart && + isLastMessage + return ( -
- - - NEW - -

Drag to add this chart into your custom report

-
- ) : undefined - } - onResults={(results) => onResults({ messageId, resultId, results })} - onRunQuery={handleRunQuery} - onUpdateChartConfig={handleUpdateChartConfig} - onDragStart={handleDragStart} - /> +
+
+ +
+ {shouldShowConfirmFooter && ( +
+ { + onResults?.({ messageId, results: 'User skipped running the query' }) + }} + onConfirm={() => { + handleExecute(isWriteQuery ? 'mutation' : 'select') + }} + /> +
+ )}
) } diff --git a/apps/studio/components/ui/AIAssistantPanel/EdgeFunctionRenderer.tsx b/apps/studio/components/ui/AIAssistantPanel/EdgeFunctionRenderer.tsx new file mode 100644 index 0000000000..bf9ddabd1e --- /dev/null +++ b/apps/studio/components/ui/AIAssistantPanel/EdgeFunctionRenderer.tsx @@ -0,0 +1,158 @@ +import { type PropsWithChildren, useMemo, useState } from 'react' +import { toast } from 'sonner' + +import { useParams } from 'common' +import { useProjectSettingsV2Query } from 'data/config/project-settings-v2-query' +import { useEdgeFunctionQuery } from 'data/edge-functions/edge-function-query' +import { useEdgeFunctionDeployMutation } from 'data/edge-functions/edge-functions-deploy-mutation' +import { useSendEventMutation } from 'data/telemetry/send-event-mutation' +import { useSelectedOrganizationQuery } from 'hooks/misc/useSelectedOrganization' +import { EdgeFunctionBlock } from '../EdgeFunctionBlock/EdgeFunctionBlock' +import { ConfirmFooter } from './ConfirmFooter' + +interface EdgeFunctionRendererProps { + label: string + code: string + functionName: string + onDeployed?: (result: { success: true } | { success: false; errorText: string }) => void + initialIsDeployed?: boolean + showConfirmFooter?: boolean +} + +export const EdgeFunctionRenderer = ({ + label, + code, + functionName, + onDeployed, + initialIsDeployed, + showConfirmFooter = true, +}: PropsWithChildren) => { + const { ref } = useParams() + const { data: org } = useSelectedOrganizationQuery() + const { mutate: sendEvent } = useSendEventMutation() + const [isDeployed, setIsDeployed] = useState(!!initialIsDeployed) + const [showReplaceWarning, setShowReplaceWarning] = useState(false) + + const { data: settings } = useProjectSettingsV2Query({ projectRef: ref }, { enabled: !!ref }) + const { data: existingFunction } = useEdgeFunctionQuery( + { projectRef: ref, slug: functionName }, + { enabled: !!ref && !!functionName } + ) + + const { + mutate: deployFunction, + error: deployError, + isLoading: isDeploying, + } = useEdgeFunctionDeployMutation({ + onSuccess: () => { + setIsDeployed(true) + toast.success('Successfully deployed edge function') + onDeployed?.({ success: true }) + }, + onError: (error) => { + const errMsg = error?.message ?? 'Unknown error' + const message = `Failed to deploy function: ${errMsg}` + toast.error(message) + setIsDeployed(false) + onDeployed?.({ success: false, errorText: errMsg }) + }, + }) + + const functionUrl = useMemo(() => { + const endpoint = settings?.app_config?.endpoint + if (!endpoint || !ref || !functionName) return undefined + + try { + const url = new URL(`https://${endpoint}`) + const restUrlTld = url.hostname.split('.').pop() + return restUrlTld + ? `https://${ref}.supabase.${restUrlTld}/functions/v1/${functionName}` + : undefined + } catch (error) { + return undefined + } + }, [settings?.app_config?.endpoint, ref, functionName]) + + const deploymentDetailsUrl = useMemo(() => { + if (!ref || !functionName) return undefined + return `/project/${ref}/functions/${functionName}/details` + }, [ref, functionName]) + + const downloadCommand = useMemo(() => { + if (!functionName) return undefined + return `supabase functions download ${functionName}` + }, [functionName]) + + const performDeploy = async () => { + if (!ref || !functionName || !code) return + + deployFunction({ + projectRef: ref, + slug: functionName, + metadata: { + entrypoint_path: 'index.ts', + name: functionName, + verify_jwt: true, + }, + files: [{ name: 'index.ts', content: code }], + }) + + sendEvent({ + action: 'edge_function_deploy_button_clicked', + properties: { origin: 'functions_ai_assistant' }, + groups: { + project: ref ?? 'Unknown', + organization: org?.slug ?? 'Unknown', + }, + }) + + setShowReplaceWarning(false) + } + + const handleDeploy = () => { + if (!code || isDeploying || !ref) return + + if (existingFunction) { + setShowReplaceWarning(true) + return + } + + void performDeploy() + } + + return ( +
+ setShowReplaceWarning(false)} + onConfirmReplace={() => void performDeploy()} + onDeploy={handleDeploy} + hideDeployButton={showConfirmFooter} + /> + {showConfirmFooter && ( +
+ { + onDeployed?.({ success: false, errorText: 'Skipped' }) + }} + onConfirm={() => handleDeploy()} + /> +
+ )} +
+ ) +} diff --git a/apps/studio/components/ui/AIAssistantPanel/Message.Actions.tsx b/apps/studio/components/ui/AIAssistantPanel/Message.Actions.tsx new file mode 100644 index 0000000000..140f8c0880 --- /dev/null +++ b/apps/studio/components/ui/AIAssistantPanel/Message.Actions.tsx @@ -0,0 +1,46 @@ +import { Pencil, Trash2 } from 'lucide-react' +import { type PropsWithChildren } from 'react' + +import { ButtonTooltip } from '../ButtonTooltip' + +export function MessageActions({ children }: PropsWithChildren<{}>) { + return ( +
+ +
{children}
+
+ ) +} +function MessageActionsEdit({ onClick, tooltip }: { onClick: () => void; tooltip: string }) { + return ( + } + onClick={onClick} + className="text-foreground-light hover:text-foreground p-1 rounded" + aria-label={tooltip} + tooltip={{ + content: { + side: 'bottom', + text: tooltip, + }, + }} + /> + ) +} +MessageActions.Edit = MessageActionsEdit + +function MessageActionsDelete({ onClick }: { onClick: () => void }) { + return ( + } + tooltip={{ content: { side: 'bottom', text: 'Delete message' } }} + onClick={onClick} + className="text-foreground-light hover:text-foreground p-1 rounded" + title="Delete message" + aria-label="Delete message" + /> + ) +} +MessageActions.Delete = MessageActionsDelete diff --git a/apps/studio/components/ui/AIAssistantPanel/Message.Context.tsx b/apps/studio/components/ui/AIAssistantPanel/Message.Context.tsx new file mode 100644 index 0000000000..70a2ff87c8 --- /dev/null +++ b/apps/studio/components/ui/AIAssistantPanel/Message.Context.tsx @@ -0,0 +1,62 @@ +import { createContext, type PropsWithChildren, useContext } from 'react' + +export type AddToolResult = (args: { + tool: string + toolCallId: string + output: unknown +}) => Promise + +export interface MessageInfo { + id: string + + variant?: 'default' | 'warning' + + isLoading: boolean + readOnly?: boolean + + isUserMessage?: boolean + isLastMessage?: boolean + + state: 'idle' | 'editing' | 'predecessor-editing' +} + +export interface MessageActions { + addToolResult?: AddToolResult + + onDelete: (id: string) => void + onEdit: (id: string) => void + onCancelEdit: () => void +} + +const MessageInfoContext = createContext(null) +const MessageActionsContext = createContext(null) + +export function useMessageInfoContext() { + const ctx = useContext(MessageInfoContext) + if (!ctx) { + throw Error('useMessageInfoContext must be used within a MessageProvider') + } + return ctx +} + +export function useMessageActionsContext() { + const ctx = useContext(MessageActionsContext) + if (!ctx) { + throw Error('useMessageActionsContext must be used within a MessageProvider') + } + return ctx +} + +export function MessageProvider({ + messageInfo, + messageActions, + children, +}: PropsWithChildren<{ messageInfo: MessageInfo; messageActions: MessageActions }>) { + return ( + + + {children} + + + ) +} diff --git a/apps/studio/components/ui/AIAssistantPanel/Message.Display.tsx b/apps/studio/components/ui/AIAssistantPanel/Message.Display.tsx new file mode 100644 index 0000000000..5f4a95e88d --- /dev/null +++ b/apps/studio/components/ui/AIAssistantPanel/Message.Display.tsx @@ -0,0 +1,90 @@ +import { UIMessage as VercelMessage } from '@ai-sdk/react' +import { type PropsWithChildren } from 'react' + +import { ProfileImage as ProfileImageDisplay } from 'components/ui/ProfileImage' +import { useProfileNameAndPicture } from 'lib/profile' +import { cn } from 'ui' +import { useMessageInfoContext } from './Message.Context' +import { MessageMarkdown, MessagePartSwitcher } from './Message.Parts' + +function MessageDisplayProfileImage() { + const { username, avatarUrl } = useProfileNameAndPicture() + return ( + + ) +} + +function MessageDisplayContainer({ + children, + onClick, + className, +}: PropsWithChildren<{ onClick?: () => void; className?: string }>) { + return ( +
+ {children} +
+ ) +} + +function MessageDisplayMainArea({ + children, + className, +}: PropsWithChildren<{ className?: string }>) { + return
{children}
+} + +function MessageDisplayContent({ message }: { message: VercelMessage }) { + const { id, isLoading, readOnly } = useMessageInfoContext() + + const messageParts = message.parts + const content = + ('content' in message && typeof message.content === 'string' && message.content.trim()) || + undefined + + return ( +
+ {messageParts?.length > 0 + ? messageParts.map((part: NonNullable, idx) => { + const isLastPart = idx === messageParts.length - 1 + return + }) + : content && ( + + {content} + + )} +
+ ) +} + +function MessageDisplayTextMessage({ + id, + isLoading, + readOnly, + children, +}: PropsWithChildren<{ id: string; isLoading: boolean; readOnly?: boolean }>) { + return ( + + {children} + + ) +} + +export const MessageDisplay = { + Container: MessageDisplayContainer, + Content: MessageDisplayContent, + MainArea: MessageDisplayMainArea, + ProfileImage: MessageDisplayProfileImage, +} diff --git a/apps/studio/components/ui/AIAssistantPanel/Message.Parts.tsx b/apps/studio/components/ui/AIAssistantPanel/Message.Parts.tsx new file mode 100644 index 0000000000..994d15b224 --- /dev/null +++ b/apps/studio/components/ui/AIAssistantPanel/Message.Parts.tsx @@ -0,0 +1,327 @@ +import { UIMessage as VercelMessage } from '@ai-sdk/react' +import { type DynamicToolUIPart, type ReasoningUIPart, type TextUIPart, type ToolUIPart } from 'ai' +import { BrainIcon, CheckIcon, Loader2 } from 'lucide-react' +import { useMemo, type PropsWithChildren } from 'react' +import ReactMarkdown from 'react-markdown' +import { type Components } from 'react-markdown/lib/ast-to-react' +import remarkGfm from 'remark-gfm' + +import { cn, markdownComponents } from 'ui' +import { DisplayBlockRenderer } from './DisplayBlockRenderer' +import { EdgeFunctionRenderer } from './EdgeFunctionRenderer' +import { Tool } from './elements/Tool' +import { useMessageActionsContext, useMessageInfoContext } from './Message.Context' +import { + deployEdgeFunctionInputSchema, + deployEdgeFunctionOutputSchema, + parseExecuteSqlChartResult, +} from './Message.utils' +import { + Heading3, + Hyperlink, + InlineCode, + ListItem, + MarkdownPre, + OrderedList, +} from './MessageMarkdown' + +const baseMarkdownComponents: Partial = { + ol: OrderedList, + li: ListItem, + h3: Heading3, + code: InlineCode, + a: Hyperlink, + img: ({ src }) => [Image: {src}], +} + +export function MessageMarkdown({ + id, + isLoading, + readOnly, + className, + children, +}: PropsWithChildren<{ + id: string + isLoading: boolean + readOnly?: boolean + className?: string +}>) { + const markdownSource = useMemo(() => { + if (typeof children === 'string') { + return children + } + + if (Array.isArray(children)) { + return children.filter((child): child is string => typeof child === 'string').join('') + } + + return '' + }, [children]) + + const allMarkdownComponents: Partial = useMemo( + () => ({ + ...markdownComponents, + ...baseMarkdownComponents, + pre: ({ children }) => ( + + {children} + + ), + }), + [id, isLoading, readOnly] + ) + + return ( + + {markdownSource} + + ) +} + +function MessagePartText({ textPart }: { textPart: TextUIPart }) { + const { id, isLoading, readOnly, isUserMessage, state } = useMessageInfoContext() + + return ( + div]:my-4 prose-h1:text-xl prose-h1:mt-6 prose-h2:text-lg prose-h3:no-underline prose-h3:text-base prose-h3:mb-4 prose-strong:font-medium prose-strong:text-foreground prose-ol:space-y-3 prose-ul:space-y-3 prose-li:my-0 break-words [&>p:not(:last-child)]:!mb-2 [&>*>p:first-child]:!mt-0 [&>*>p:last-child]:!mb-0 [&>*>*>p:first-child]:!mt-0 [&>*>*>p:last-child]:!mb-0 [&>ol>li]:!pl-4', + isUserMessage && 'text-foreground [&>p]:font-medium', + state === 'editing' && 'animate-pulse' + )} + > + {textPart.text} + + ) +} + +function MessagePartDynamicTool({ toolPart }: { toolPart: DynamicToolUIPart }) { + return ( + + ) : ( + + ) + } + label={ +
+ {toolPart.state === 'input-streaming' ? 'Running ' : 'Ran '} + {`${toolPart.toolName}`} +
+ } + /> + ) +} + +function MessagePartTool({ toolPart }: { toolPart: ToolUIPart }) { + return ( + + ) : ( + + ) + } + label={ +
+ {toolPart.state === 'input-streaming' ? 'Running ' : 'Ran '} + {`${toolPart.type.replace('tool-', '')}`} +
+ } + /> + ) +} + +function MessagePartReasoning({ reasoningPart }: { reasoningPart: ReasoningUIPart }) { + return ( + + ) : ( + + ) + } + label={reasoningPart.state === 'streaming' ? 'Thinking...' : 'Reasoned'} + > + {reasoningPart.text} + + ) +} + +function ToolDisplayExecuteSqlLoading() { + return ( +
+ + Writing SQL... +
+ ) +} + +function ToolDisplayExecuteSqlFailure() { + return
Failed to execute SQL.
+} + +function MessagePartExecuteSql({ + toolPart, + isLastPart, +}: { + toolPart: ToolUIPart + isLastPart?: boolean +}) { + const { id, isLastMessage } = useMessageInfoContext() + const { addToolResult } = useMessageActionsContext() + + const { toolCallId, state, input, output } = toolPart + + if (state === 'input-streaming') { + return + } + + if (state === 'output-error') { + return + } + + const { data: chart, success } = parseExecuteSqlChartResult(input) + if (!success) return null + + if (state === 'input-available' || state === 'output-available') { + return ( +
+ { + const results = args.results as any[] + + addToolResult?.({ + tool: 'execute_sql', + toolCallId: String(toolCallId), + output: results, + }) + }} + onError={({ errorText }) => { + addToolResult?.({ + tool: 'execute_sql', + toolCallId: String(toolCallId), + output: `Error: ${errorText}`, + }) + }} + /> +
+ ) + } + + return null +} + +const TOOL_DEPLOY_EDGE_FUNCTION_STATES_WITH_INPUT = new Set(['input-available', 'output-available']) + +function MessagePartDeployEdgeFunction({ toolPart }: { toolPart: ToolUIPart }) { + const { toolCallId, state, input, output } = toolPart + const { addToolResult } = useMessageActionsContext() + + if (state === 'input-streaming') { + return ( +
+ + Writing Edge Function... +
+ ) + } + + if (state === 'output-error') { + return

Failed to deploy Edge Function.

+ } + + if (!TOOL_DEPLOY_EDGE_FUNCTION_STATES_WITH_INPUT.has(state)) return null + + const parsedInput = deployEdgeFunctionInputSchema.safeParse(input) + if (!parsedInput.success) return null + + const parsedOutput = deployEdgeFunctionOutputSchema.safeParse(output) + const isInitiallyDeployed = + state === 'output-available' && parsedOutput.success && parsedOutput.data.success === true + + return ( + { + addToolResult?.({ + tool: 'deploy_edge_function', + toolCallId: String(toolCallId), + output: result, + }) + }} + /> + ) +} + +const MessagePart = { + Text: MessagePartText, + Dynamic: MessagePartDynamicTool, + Tool: MessagePartTool, + Reasoning: MessagePartReasoning, + ExecuteSql: MessagePartExecuteSql, + DeployEdgeFunction: MessagePartDeployEdgeFunction, +} as const + +export function MessagePartSwitcher({ + part, + isLastPart, +}: { + part: NonNullable[number] + isLastPart?: boolean +}) { + switch (part.type) { + case 'dynamic-tool': { + return + } + case 'tool-list_policies': + case 'tool-search_docs': { + return + } + case 'reasoning': + return + case 'text': + return + + case 'tool-execute_sql': { + return + } + case 'tool-deploy_edge_function': { + return + } + + case 'source-url': + case 'source-document': + case 'file': + default: + return null + } +} diff --git a/apps/studio/components/ui/AIAssistantPanel/Message.tsx b/apps/studio/components/ui/AIAssistantPanel/Message.tsx index b06015320b..61c3f6f2fc 100644 --- a/apps/studio/components/ui/AIAssistantPanel/Message.tsx +++ b/apps/studio/components/ui/AIAssistantPanel/Message.tsx @@ -1,317 +1,60 @@ import { UIMessage as VercelMessage } from '@ai-sdk/react' -import { CheckIcon, Loader2, Pencil, Trash2 } from 'lucide-react' -import { createContext, memo, PropsWithChildren, ReactNode, useMemo, useState } from 'react' -import ReactMarkdown from 'react-markdown' -import { Components } from 'react-markdown/lib/ast-to-react' -import remarkGfm from 'remark-gfm' +import { useState } from 'react' import { toast } from 'sonner' -import { ProfileImage } from 'components/ui/ProfileImage' -import { useProfile } from 'lib/profile' -import { cn, markdownComponents, WarningIcon } from 'ui' -import { ButtonTooltip } from '../ButtonTooltip' -import { EdgeFunctionBlock } from '../EdgeFunctionBlock/EdgeFunctionBlock' +import { cn } from 'ui' import { DeleteMessageConfirmModal } from './DeleteMessageConfirmModal' -import { DisplayBlockRenderer } from './DisplayBlockRenderer' -import { - Heading3, - Hyperlink, - InlineCode, - ListItem, - MarkdownPre, - OrderedList, -} from './MessageMarkdown' -import { Reasoning } from './elements/Reasoning' +import { MessageActions } from './Message.Actions' +import type { AddToolResult, MessageInfo } from './Message.Context' +import { MessageDisplay } from './Message.Display' +import { MessageProvider, useMessageActionsContext, useMessageInfoContext } from './Message.Context' -interface MessageContextType { - isLoading: boolean - readOnly?: boolean -} -export const MessageContext = createContext({ isLoading: false }) - -const baseMarkdownComponents: Partial = { - ol: OrderedList, - li: ListItem, - h3: Heading3, - code: InlineCode, - a: Hyperlink, - img: ({ src }) => [Image: {src}], -} - -interface MessageProps { - id: string - message: VercelMessage - isLoading: boolean - readOnly?: boolean - status?: string - action?: ReactNode - variant?: 'default' | 'warning' - onResults: ({ - messageId, - resultId, - results, - }: { - messageId: string - resultId?: string - results: any[] - }) => void - onDelete: (id: string) => void - onEdit: (id: string) => void - isAfterEditedMessage: boolean - isBeingEdited: boolean - onCancelEdit: () => void -} - -const Message = function Message({ - id, - message, - isLoading, - readOnly, - action = null, - variant = 'default', - onResults, - onDelete, - onEdit, - isAfterEditedMessage = false, - isBeingEdited = false, - status, - onCancelEdit, -}: PropsWithChildren) { - const { profile } = useProfile() - const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false) - const allMarkdownComponents: Partial = useMemo( - () => ({ - ...markdownComponents, - ...baseMarkdownComponents, - pre: ({ children }) => ( - - {children} - - ), - }), - [id, onResults] - ) - - if (!message) { - console.error(`Message component received undefined message prop for id: ${id}`) - return null - } - - // For backwards compatibility: some stored messages may have a 'content' property - const { role, parts } = message - const hasContent = (msg: VercelMessage): msg is VercelMessage & { content: string } => - 'content' in msg && typeof msg.content === 'string' - const content = hasContent(message) ? message.content : undefined - const isUser = role === 'user' - - const shouldUsePartsRendering = parts && parts.length > 0 - - const hasTextContent = content && content.trim().length > 0 +function AssistantMessage({ message }: { message: VercelMessage }) { + const { variant, state } = useMessageInfoContext() + const { onCancelEdit } = useMessageActionsContext() return ( - -
+ + + + + ) +} + +function UserMessage({ message }: { message: VercelMessage }) { + const { id, variant, state } = useMessageInfoContext() + const { onCancelEdit, onEdit, onDelete } = useMessageActionsContext() + const [showDeleteConfirmModal, setShowDeleteConfirmModal] = useState(false) + + return ( + <> + - {variant === 'warning' && } - - {action} - -
- {isUser && ( - - )} - -
- {shouldUsePartsRendering ? ( - (() => { - return parts.map( - (part: NonNullable[number], index: number) => { - switch (part.type) { - case 'dynamic-tool': { - return ( -
- {part.state === 'input-streaming' ? ( - - ) : ( - - )} - {`${part.toolName}`} -
- ) - } - case 'reasoning': - return ( - - {part.text} - - ) - case 'text': - return ( - div]:my-4 prose-h1:text-xl prose-h1:mt-6 prose-h2:text-lg prose-h3:no-underline prose-h3:text-base prose-h3:mb-4 prose-strong:font-medium prose-strong:text-foreground prose-ol:space-y-3 prose-ul:space-y-3 prose-li:my-0 break-words [&>p:not(:last-child)]:!mb-2 [&>*>p:first-child]:!mt-0 [&>*>p:last-child]:!mb-0 [&>*>*>p:first-child]:!mt-0 [&>*>*>p:last-child]:!mb-0 [&>ol>li]:!pl-4', - isUser && 'text-foreground [&>p]:font-medium', - isBeingEdited && 'animate-pulse' - )} - remarkPlugins={[remarkGfm]} - components={allMarkdownComponents} - > - {part.text} - - ) - - case 'tool-display_query': { - const { toolCallId, state, input } = part - if (state === 'input-streaming' || state === 'input-available') { - return ( -
- - {`Calling display_query...`} -
- ) - } - if (state === 'output-available') { - return ( - - ) - } - return null - } - case 'tool-display_edge_function': { - const { toolCallId, state, input } = part - if (state === 'input-streaming' || state === 'input-available') { - return ( -
- - {`Calling display_edge_function...`} -
- ) - } - if (state === 'output-available') { - return ( -
- -
- ) - } - return null - } - case 'source-url': - case 'source-document': - case 'file': - return null - default: - return null - } - } - ) - })() - ) : hasTextContent ? ( - - {content} - - ) : ( - Assistant is thinking... - )} - - {/* Action button - only show for user messages on hover */} -
- {message.role === 'user' && ( - <> - } - onClick={ - isBeingEdited || isAfterEditedMessage ? onCancelEdit : () => onEdit(id) - } - className="text-foreground-light hover:text-foreground p-1 rounded" - aria-label={ - isBeingEdited || isAfterEditedMessage ? 'Cancel editing' : 'Edit message' - } - tooltip={{ - content: { - side: 'bottom', - text: - isBeingEdited || isAfterEditedMessage ? 'Cancel editing' : 'Edit message', - }, - }} - /> - - } - tooltip={{ content: { side: 'bottom', text: 'Delete message' } }} - onClick={() => setShowDeleteConfirmModal(true)} - className="text-foreground-light hover:text-foreground p-1 rounded" - title="Delete message" - aria-label="Delete message" - /> - - )} -
-
-
-
- + + + + + + onEdit(id) : onCancelEdit} + tooltip={state === 'idle' ? 'Edit message' : 'Cancel editing'} + /> + setShowDeleteConfirmModal(true)} /> + + { @@ -321,54 +64,53 @@ const Message = function Message({ }} onCancel={() => setShowDeleteConfirmModal(false)} /> -
+ ) } -export const MemoizedMessage = memo( - ({ - message, - status, - onResults, - onDelete, - onEdit, - isAfterEditedMessage, - isBeingEdited, - onCancelEdit, - }: { - message: VercelMessage - status: string - onResults: ({ - messageId, - resultId, - results, - }: { - messageId: string - resultId?: string - results: any[] - }) => void - onDelete: (id: string) => void - onEdit: (id: string) => void - isAfterEditedMessage: boolean - isBeingEdited: boolean - onCancelEdit: () => void - }) => { - return ( - - ) - } -) +interface MessageProps { + id: string + message: VercelMessage + isLoading: boolean + readOnly?: boolean + variant?: 'default' | 'warning' + addToolResult?: AddToolResult + onDelete: (id: string) => void + onEdit: (id: string) => void + isAfterEditedMessage: boolean + isBeingEdited: boolean + onCancelEdit: () => void + isLastMessage?: boolean +} -MemoizedMessage.displayName = 'MemoizedMessage' +export function Message(props: MessageProps) { + const message = props.message + const { role } = message + const isUserMessage = role === 'user' + + const messageInfo = { + id: props.id, + isLoading: props.isLoading, + readOnly: props.readOnly, + variant: props.variant, + state: props.isBeingEdited + ? 'editing' + : props.isAfterEditedMessage + ? 'predecessor-editing' + : 'idle', + isLastMessage: props.isLastMessage, + } satisfies MessageInfo + + const messageActions = { + addToolResult: props.addToolResult, + onDelete: props.onDelete, + onEdit: props.onEdit, + onCancelEdit: props.onCancelEdit, + } + + return ( + + {isUserMessage ? : } + + ) +} diff --git a/apps/studio/components/ui/AIAssistantPanel/Message.utils.ts b/apps/studio/components/ui/AIAssistantPanel/Message.utils.ts index b574b6b804..cf80eb9572 100644 --- a/apps/studio/components/ui/AIAssistantPanel/Message.utils.ts +++ b/apps/studio/components/ui/AIAssistantPanel/Message.utils.ts @@ -1,53 +1,4 @@ -const extractDataFromSafetyMessage = (text: string): string | null => { - const openingTags = [...text.matchAll(//gi)] - if (openingTags.length < 2) return null - - const closingTagMatch = text.match(/<\/untrusted-data-[a-z0-9-]+>/i) - if (!closingTagMatch) return null - - const secondOpeningEnd = openingTags[1].index! + openingTags[1][0].length - const closingStart = text.indexOf(closingTagMatch[0]) - const content = text.substring(secondOpeningEnd, closingStart) - - return content.replace(/\\n/g, '').replace(/\\"/g, '"').replace(/\n/g, '').trim() -} - -// Helper function to find result data directly from parts array -export const findResultForManualId = ( - parts: any[] | undefined, - manualId: string -): any[] | undefined => { - if (!parts) return undefined - - const invocationPart = parts.find( - (part) => - part.type === 'tool-invocation' && - 'toolInvocation' in part && - part.toolInvocation.state === 'result' && - 'result' in part.toolInvocation && - part.toolInvocation.result?.manualToolCallId === manualId - ) - - if ( - invocationPart && - 'toolInvocation' in invocationPart && - 'result' in invocationPart.toolInvocation && - invocationPart.toolInvocation.result?.content?.[0]?.text - ) { - try { - const rawText = invocationPart.toolInvocation.result.content[0].text - - const extractedData = extractDataFromSafetyMessage(rawText) || rawText - - let parsedData = JSON.parse(extractedData.trim()) - return Array.isArray(parsedData) ? parsedData : undefined - } catch (error) { - console.error('Failed to parse tool invocation result data for manualId:', manualId, error) - return undefined - } - } - return undefined -} +import { type SafeParseReturnType, z } from 'zod' // [Joshen] From https://github.com/remarkjs/react-markdown/blob/fda7fa560bec901a6103e195f9b1979dab543b17/lib/index.js#L425 export function defaultUrlTransform(value: string) { @@ -72,3 +23,76 @@ export function defaultUrlTransform(value: string) { return '' } + +const chartArgsSchema = z + .object({ + view: z.enum(['table', 'chart']).optional(), + xKey: z.string().optional(), + xAxis: z.string().optional(), + yKey: z.string().optional(), + yAxis: z.string().optional(), + }) + .passthrough() + +const chartArgsFieldSchema = z.preprocess((value) => { + if (!value || typeof value !== 'object') return undefined + if (Array.isArray(value)) return value[0] + return value +}, chartArgsSchema.optional()) + +const executeSqlChartResultSchema = z + .object({ + sql: z.string().optional(), + label: z.string().optional(), + isWriteQuery: z.boolean().optional(), + chartConfig: chartArgsFieldSchema, + config: chartArgsFieldSchema, + }) + .passthrough() + .transform(({ sql, label, isWriteQuery, chartConfig, config }) => { + const chartArgs = chartConfig ?? config + + return { + sql: sql ?? '', + label, + isWriteQuery, + view: chartArgs?.view, + xAxis: chartArgs?.xKey ?? chartArgs?.xAxis, + yAxis: chartArgs?.yKey ?? chartArgs?.yAxis, + } + }) + +export function parseExecuteSqlChartResult( + input: unknown +): SafeParseReturnType> { + return executeSqlChartResultSchema.safeParse(input) +} + +export const deployEdgeFunctionInputSchema = z + .object({ + code: z.string().min(1), + name: z.string().trim().optional(), + slug: z.string().trim().optional(), + functionName: z.string().trim().optional(), + label: z.string().optional(), + }) + .passthrough() + .transform((data) => { + const rawName = data.functionName ?? data.name ?? data.slug + const trimmedName = rawName?.trim() + const functionName = trimmedName && trimmedName.length > 0 ? trimmedName : 'my-function' + + const rawLabel = data.label ?? rawName + const trimmedLabel = rawLabel?.trim() + const label = trimmedLabel && trimmedLabel.length > 0 ? trimmedLabel : 'Edge Function' + + return { + code: data.code, + functionName, + label, + } + }) + +export const deployEdgeFunctionOutputSchema = z + .object({ success: z.boolean().optional() }) + .passthrough() diff --git a/apps/studio/components/ui/AIAssistantPanel/MessageMarkdown.tsx b/apps/studio/components/ui/AIAssistantPanel/MessageMarkdown.tsx index 7764136b47..c929a27441 100644 --- a/apps/studio/components/ui/AIAssistantPanel/MessageMarkdown.tsx +++ b/apps/studio/components/ui/AIAssistantPanel/MessageMarkdown.tsx @@ -1,27 +1,9 @@ -import { PermissionAction } from '@supabase/shared-types/out/constants' -import { useRouter } from 'next/router' -import { - DragEvent, - memo, - ReactNode, - useCallback, - useContext, - useEffect, - useMemo, - useRef, -} from 'react' +import { Loader2 } from 'lucide-react' +import Link from 'next/link' +import { memo, ReactNode, useEffect, useMemo, useRef } from 'react' import { ChartConfig } from 'components/interfaces/SQLEditor/UtilityPanel/ChartConfig' -import { useSendEventMutation } from 'data/telemetry/send-event-mutation' -import { useAsyncCheckPermissions } from 'hooks/misc/useCheckPermissions' -import { useSelectedOrganizationQuery } from 'hooks/misc/useSelectedOrganization' -import { useSelectedProjectQuery } from 'hooks/misc/useSelectedProject' -import { useProfile } from 'lib/profile' -import Link from 'next/link' -import { useAiAssistantStateSnapshot } from 'state/ai-assistant-state' -import { Dashboards } from 'types' import { - Badge, Button, cn, CodeBlock, @@ -35,13 +17,10 @@ import { DialogTitle, DialogTrigger, } from 'ui' -import { DebouncedComponent } from '../DebouncedComponent' import { EdgeFunctionBlock } from '../EdgeFunctionBlock/EdgeFunctionBlock' -import { QueryBlock } from '../QueryBlock/QueryBlock' import { AssistantSnippetProps } from './AIAssistant.types' -import { identifyQueryType } from './AIAssistant.utils' import { CollapsibleCodeBlock } from './CollapsibleCodeBlock' -import { MessageContext } from './Message' +import { DisplayBlockRenderer } from './DisplayBlockRenderer' import { defaultUrlTransform } from './Message.utils' export const OrderedList = memo(({ children }: { children: ReactNode }) => ( @@ -124,123 +103,17 @@ export const Hyperlink = memo(({ href, children }: { href?: string; children: Re }) Hyperlink.displayName = 'Hyperlink' -const MemoizedQueryBlock = memo( - ({ - sql, - title, - xAxis, - yAxis, - isChart, - isLoading, - isDraggable, - runQuery, - results, - onRunQuery, - onResults, - onDragStart, - onUpdateChartConfig, - }: { - sql: string - title: string - xAxis?: string - yAxis?: string - isChart: boolean - isLoading: boolean - isDraggable: boolean - runQuery: boolean - results?: any[] - onRunQuery: (queryType: 'select' | 'mutation') => void - onResults: (results: any[]) => void - onDragStart: (e: DragEvent) => void - onUpdateChartConfig?: ({ - chart, - chartConfig, - }: { - chart?: Partial - chartConfig: Partial - }) => void - }) => ( - - Writing SQL... -
- } - > - - - NEW - -

Drag to add this chart into your custom report

-
- ) : undefined - } - showSql={!isChart} - isChart={isChart} - isLoading={isLoading} - draggable={isDraggable} - runQuery={runQuery} - results={results} - onRunQuery={onRunQuery} - onResults={onResults} - onDragStart={onDragStart} - onUpdateChartConfig={onUpdateChartConfig} - /> - - ) -) -MemoizedQueryBlock.displayName = 'MemoizedQueryBlock' - export const MarkdownPre = ({ children, id, - onResults, + isLoading, + readOnly, }: { children: any id: string - onResults: ({ - messageId, - resultId, - results, - }: { - messageId: string - resultId?: string - results: any[] - }) => void + isLoading: boolean + readOnly?: boolean }) => { - const router = useRouter() - const { profile } = useProfile() - const { isLoading, readOnly } = useContext(MessageContext) - const { mutate: sendEvent } = useSendEventMutation() - const snap = useAiAssistantStateSnapshot() - const { data: project } = useSelectedProjectQuery() - const { data: org } = useSelectedOrganizationQuery() - - const { can: canCreateSQLSnippet } = useAsyncCheckPermissions( - PermissionAction.CREATE, - 'user_content', - { - resource: { type: 'sql', owner_id: profile?.id }, - subject: { id: profile?.id }, - } - ) - // [Joshen] Using a ref as this data doesn't need to trigger a re-render const chartConfig = useRef({ view: 'table', @@ -267,13 +140,10 @@ export const MarkdownPre = ({ const snippetId = snippetProps.id const title = snippetProps.title || (language === 'edge' ? 'Edge Function' : 'SQL Query') const isChart = snippetProps.isChart === 'true' - const runQuery = snippetProps.runQuery === 'true' - const results = snap.getCachedSQLResults({ messageId: id, snippetId }) - // Strip props from the content for both SQL and edge functions const cleanContent = rawContent.replace(/(?:--|\/\/)\s*props:\s*\{[^}]+\}/, '').trim() - const isDraggableToReports = canCreateSQLSnippet && router.pathname.endsWith('/reports/[id]') + const toolCallId = String(snippetId ?? id) useEffect(() => { chartConfig.current = { @@ -285,29 +155,6 @@ export const MarkdownPre = ({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [snippetProps]) - const onResultsReturned = useCallback( - (results: any[]) => { - onResults({ messageId: id, resultId: snippetProps.id, results }) - }, - [onResults, snippetProps.id] - ) - - const onRunQuery = async (queryType: 'select' | 'mutation') => { - sendEvent({ - action: 'assistant_suggestion_run_query_clicked', - properties: { - queryType, - ...(queryType === 'mutation' - ? { category: identifyQueryType(cleanContent) ?? 'unknown' } - : {}), - }, - groups: { - project: project?.ref ?? 'Unknown', - organization: org?.slug ?? 'Unknown', - }, - }) - } - return (
{language === 'edge' ? ( @@ -320,27 +167,27 @@ export const MarkdownPre = ({ ) : language === 'sql' ? ( readOnly ? ( + ) : isLoading ? ( +
+ + Writing SQL... +
) : ( - { - chartConfig.current = { ...chartConfig.current, ...config } + ) => { - e.dataTransfer.setData( - 'application/json', - JSON.stringify({ label: title, sql: cleanContent, config: chartConfig.current }) - ) + onError={() => {}} + showConfirmFooter={false} + onChartConfigChange={(config) => { + chartConfig.current = { ...config } }} /> ) diff --git a/apps/studio/components/ui/AIAssistantPanel/elements/Reasoning.tsx b/apps/studio/components/ui/AIAssistantPanel/elements/Reasoning.tsx deleted file mode 100644 index efd7f80be9..0000000000 --- a/apps/studio/components/ui/AIAssistantPanel/elements/Reasoning.tsx +++ /dev/null @@ -1,57 +0,0 @@ -import { BrainIcon, ChevronDownIcon, Loader2 } from 'lucide-react' -import type { ComponentProps } from 'react' -import { memo } from 'react' -import ReactMarkdown from 'react-markdown' - -import { - cn, - Collapsible, - CollapsibleContent_Shadcn_ as CollapsibleContent, - CollapsibleTrigger_Shadcn_ as CollapsibleTrigger, -} from 'ui' - -type ReasoningProps = Omit, 'children'> & { - isStreaming?: boolean - children: string - showReasoning?: boolean -} - -export const Reasoning = memo( - ({ className, isStreaming, showReasoning, children, ...props }: ReasoningProps) => ( - - - {isStreaming ? ( - <> - -

Thinking...

- - ) : ( - <> - -

Reasoned

- - )} - {showReasoning && ( - - )} -
- - - {children} - -
- ) -) - -Reasoning.displayName = 'Reasoning' diff --git a/apps/studio/components/ui/AIAssistantPanel/elements/Tool.tsx b/apps/studio/components/ui/AIAssistantPanel/elements/Tool.tsx new file mode 100644 index 0000000000..458557ccc4 --- /dev/null +++ b/apps/studio/components/ui/AIAssistantPanel/elements/Tool.tsx @@ -0,0 +1,54 @@ +import type { PropsWithChildren } from 'react' + +import { + cn, + Collapsible, + CollapsibleContent_Shadcn_ as CollapsibleContent, + CollapsibleTrigger_Shadcn_ as CollapsibleTrigger, +} from 'ui' + +type ToolProps = PropsWithChildren<{ + className?: string + label: string | JSX.Element + icon?: JSX.Element +}> + +export function Tool({ className, label, icon, children }: ToolProps) { + const isCollapsible = !!children + + return ( +
+ + + {icon} + {typeof label === 'string' ? ( + {label} + ) : ( + label + )} + + + {isCollapsible && ( + + {children} + + )} + +
+ ) +} + +Tool.displayName = 'Tool' diff --git a/apps/studio/components/ui/EdgeFunctionBlock/EdgeFunctionBlock.tsx b/apps/studio/components/ui/EdgeFunctionBlock/EdgeFunctionBlock.tsx index 56ed55caff..b262a2f64d 100644 --- a/apps/studio/components/ui/EdgeFunctionBlock/EdgeFunctionBlock.tsx +++ b/apps/studio/components/ui/EdgeFunctionBlock/EdgeFunctionBlock.tsx @@ -1,16 +1,9 @@ import { Code } from 'lucide-react' import Link from 'next/link' -import { DragEvent, ReactNode, useState } from 'react' -import { toast } from 'sonner' +import type { DragEvent, ReactNode } from 'react' -import { useParams } from 'common' import { ReportBlockContainer } from 'components/interfaces/Reports/ReportBlock/ReportBlockContainer' -import { useProjectSettingsV2Query } from 'data/config/project-settings-v2-query' -import { useEdgeFunctionQuery } from 'data/edge-functions/edge-function-query' -import { useEdgeFunctionDeployMutation } from 'data/edge-functions/edge-functions-deploy-mutation' -import { useSendEventMutation } from 'data/telemetry/send-event-mutation' -import { useSelectedOrganizationQuery } from 'hooks/misc/useSelectedOrganization' -import { Button, cn, CodeBlock, CodeBlockLang } from 'ui' +import { Button, CodeBlock, type CodeBlockLang, cn } from 'ui' import { Admonition } from 'ui-patterns' interface EdgeFunctionBlockProps { @@ -29,7 +22,31 @@ interface EdgeFunctionBlockProps { /** Tooltip when hovering over the header of the block */ tooltip?: ReactNode /** Optional callback on drag start */ - onDragStart?: (e: DragEvent) => void + onDragStart?: (e: DragEvent) => void + /** Hide the header deploy button (used when an external confirm footer is shown) */ + hideDeployButton?: boolean + /** Disable interactive actions */ + disabled?: boolean + /** Whether a deploy action is currently running */ + isDeploying?: boolean + /** Whether a deploy action has completed */ + isDeployed?: boolean + /** Optional message to show when deployment fails */ + errorText?: string + /** URL to the deployed function */ + functionUrl?: string + /** Link to the function details page */ + deploymentDetailsUrl?: string + /** CLI command to download the function */ + downloadCommand?: string + /** Show warning UI when replacing an existing function */ + showReplaceWarning?: boolean + /** Cancel handler when replacing an existing function */ + onCancelReplace?: () => void + /** Confirm handler when replacing an existing function */ + onConfirmReplace?: () => void + /** Handler for triggering a deploy */ + onDeploy?: () => void } export const EdgeFunctionBlock = ({ @@ -37,89 +54,56 @@ export const EdgeFunctionBlock = ({ code, functionName, actions, - showCode: _showCode = false, tooltip, + hideDeployButton = false, + disabled = false, + isDeploying = false, + isDeployed = false, + errorText, + functionUrl, + deploymentDetailsUrl, + downloadCommand, + showReplaceWarning = false, + onCancelReplace, + onConfirmReplace, + onDeploy, + draggable = false, + onDragStart, }: EdgeFunctionBlockProps) => { - const { ref } = useParams() - const [isDeployed, setIsDeployed] = useState(false) - const [showWarning, setShowWarning] = useState(false) - const { data: settings } = useProjectSettingsV2Query({ projectRef: ref }) - const { data: existingFunction } = useEdgeFunctionQuery({ projectRef: ref, slug: functionName }) + const resolvedFunctionUrl = functionUrl ?? 'Function URL will be available after deployment' + const resolvedDownloadCommand = downloadCommand ?? `supabase functions download ${functionName}` - const { mutate: sendEvent } = useSendEventMutation() - const { data: org } = useSelectedOrganizationQuery() + const hasStatusMessage = isDeploying || isDeployed || !!errorText - const { mutateAsync: deployFunction, isLoading: isDeploying } = useEdgeFunctionDeployMutation({ - onSuccess: () => { - setIsDeployed(true) - toast.success('Successfully deployed edge function') - }, - }) - - const handleDeploy = async () => { - if (!code || isDeploying || !ref) return - - if (existingFunction) { - return setShowWarning(true) - } - - try { - await deployFunction({ - projectRef: ref, - slug: functionName, - metadata: { - entrypoint_path: 'index.ts', - name: functionName, - verify_jwt: true, - }, - files: [{ name: 'index.ts', content: code }], - }) - sendEvent({ - action: 'edge_function_deploy_button_clicked', - properties: { origin: 'functions_ai_assistant' }, - groups: { project: ref ?? 'Unknown', organization: org?.slug ?? 'Unknown' }, - }) - } catch (error) { - toast.error( - `Failed to deploy function: ${error instanceof Error ? error.message : 'Unknown error'}` - ) - } - } - - let functionUrl = 'Function URL not available' - const endpoint = settings?.app_config?.endpoint - if (endpoint) { - const restUrl = `https://${endpoint}` - const restUrlTld = restUrl ? new URL(restUrl).hostname.split('.').pop() : 'co' - functionUrl = - ref && functionName && restUrlTld - ? `https://${ref}.supabase.${restUrlTld}/functions/v1/${functionName}` - : 'Function URL will be available after deployment' - } return ( } label={label} + loading={isDeploying} + draggable={draggable} + onDragStart={onDragStart} actions={ - ref && functionName ? ( + hideDeployButton || !onDeploy ? ( + actions ?? null + ) : ( <> {actions} - ) : null + ) } > - {showWarning && ref && functionName && ( + {showReplaceWarning && ( setShowWarning(false)} + disabled={isDeploying} + onClick={onCancelReplace} > Cancel @@ -141,25 +126,9 @@ export const EdgeFunctionBlock = ({ type="danger" size="tiny" className="w-full flex-1" - onClick={async () => { - setShowWarning(false) - try { - await deployFunction({ - projectRef: ref, - slug: functionName, - metadata: { - entrypoint_path: 'index.ts', - name: functionName, - verify_jwt: true, - }, - files: [{ name: 'index.ts', content: code }], - }) - } catch (error) { - toast.error( - `Failed to deploy function: ${error instanceof Error ? error.message : 'Unknown error'}` - ) - } - }} + loading={isDeploying} + disabled={isDeploying} + onClick={onConfirmReplace} > Replace function @@ -180,26 +149,29 @@ export const EdgeFunctionBlock = ({ />
- {(isDeploying || isDeployed) && ( + {hasStatusMessage && (
{isDeploying ? (

Deploying function...

+ ) : errorText ? ( +

{errorText}

) : ( <>

The{' '} - - new function - {' '} + {deploymentDetailsUrl ? ( + + new function + + ) : ( + new function + )}{' '} is now live at:

@@ -208,7 +180,7 @@ export const EdgeFunctionBlock = ({ diff --git a/apps/studio/components/ui/EditorPanel/EditorPanel.tsx b/apps/studio/components/ui/EditorPanel/EditorPanel.tsx index a5f8331aee..7dfa1aafbc 100644 --- a/apps/studio/components/ui/EditorPanel/EditorPanel.tsx +++ b/apps/studio/components/ui/EditorPanel/EditorPanel.tsx @@ -50,7 +50,7 @@ import { containsUnknownFunction, isReadOnlySelect } from '../AIAssistantPanel/A import AIEditor from '../AIEditor' import { ButtonTooltip } from '../ButtonTooltip' import { InlineLink } from '../InlineLink' -import SqlWarningAdmonition from '../SqlWarningAdmonition' +import { SqlWarningAdmonition } from '../SqlWarningAdmonition' type Template = { name: string diff --git a/apps/studio/components/ui/QueryBlock/QueryBlock.tsx b/apps/studio/components/ui/QueryBlock/QueryBlock.tsx index f917e0391a..ba211604c0 100644 --- a/apps/studio/components/ui/QueryBlock/QueryBlock.tsx +++ b/apps/studio/components/ui/QueryBlock/QueryBlock.tsx @@ -1,25 +1,19 @@ import dayjs from 'dayjs' import { Code, Play } from 'lucide-react' -import { DragEvent, ReactNode, useEffect, useMemo, useState } from 'react' +import { DragEvent, ReactNode, useEffect, useMemo, useRef, useState } from 'react' import { Bar, BarChart, CartesianGrid, Cell, Tooltip, XAxis, YAxis } from 'recharts' -import { toast } from 'sonner' -import { useParams } from 'common' import { ReportBlockContainer } from 'components/interfaces/Reports/ReportBlock/ReportBlockContainer' import { ChartConfig } from 'components/interfaces/SQLEditor/UtilityPanel/ChartConfig' import Results from 'components/interfaces/SQLEditor/UtilityPanel/Results' -import { usePrimaryDatabase } from 'data/read-replicas/replicas-query' -import { type QueryResponseError, useExecuteSqlMutation } from 'data/sql/execute-sql-mutation' -import { type Parameter, parseParameters } from 'lib/sql-parameters' -import type { Dashboards } from 'types' -import { ChartContainer, ChartTooltipContent, cn, CodeBlock, SQL_ICON } from 'ui' + +import { Badge, Button, ChartContainer, ChartTooltipContent, cn, CodeBlock } from 'ui' import ShimmeringLoader from 'ui-patterns/ShimmeringLoader' import { ButtonTooltip } from '../ButtonTooltip' import { CHART_COLORS } from '../Charts/Charts.constants' -import SqlWarningAdmonition from '../SqlWarningAdmonition' +import { SqlWarningAdmonition } from '../SqlWarningAdmonition' import { BlockViewConfiguration } from './BlockViewConfiguration' import { EditQueryButton } from './EditQueryButton' -import { ParametersPopover } from './ParametersPopover' import { getCumulativeResults } from './QueryBlock.utils' export const DEFAULT_CHART_CONFIG: ChartConfig = { @@ -32,65 +26,24 @@ export const DEFAULT_CHART_CONFIG: ChartConfig = { view: 'table', } -interface QueryBlockProps { - /** Applicable if SQL is a snippet that's already saved (Used in Reports) */ +export interface QueryBlockProps { id?: string - /** Title of the QueryBlock */ label: string - /** SQL query to render/run in the QueryBlock */ sql?: string - /** Configuration of the output chart based on the query result */ + isWriteQuery?: boolean chartConfig?: ChartConfig - /** Not implemented yet: Will be the next part of ReportsV2 */ - parameterValues?: Record - /** Any other actions specific to the parent to be rendered in the header */ actions?: ReactNode - /** Toggle visiblity of SQL query on render */ - showSql?: boolean - /** Indicate if SQL query can be rendered as a chart */ - isChart?: boolean - /** For Assistant as QueryBlock is rendered while streaming response */ - isLoading?: boolean - /** Override to prevent running the SQL query provided */ - runQuery?: boolean - /** Prevent updating of columns for X and Y axes in the chart view */ - lockColumns?: boolean - /** Max height set to render results / charts (Defaults to 250) */ - maxHeight?: number - /** Whether query block is draggable */ - draggable?: boolean - /** Tooltip when hovering over the header of the block (Used in Assistant Panel) */ - tooltip?: ReactNode - /** Optional: Any initial results to render as part of the query*/ results?: any[] - /** Opt to show run button if query is not read only */ - showRunButtonIfNotReadOnly?: boolean - /** Not implemented yet: Will be the next part of ReportsV2 */ - onSetParameter?: (params: Parameter[]) => void - /** Optional callback the SQL query is run */ - onRunQuery?: (queryType: 'select' | 'mutation') => void - /** Optional callback on drag start */ + errorText?: string + isExecuting?: boolean + initialHideSql?: boolean + draggable?: boolean + disabled?: boolean + blockWriteQueries?: boolean + onExecute?: (queryType: 'select' | 'mutation') => void + onRemoveChart?: () => void + onUpdateChartConfig?: ({ chartConfig }: { chartConfig: Partial }) => void onDragStart?: (e: DragEvent) => void - /** Optional: callback when the results are returned from running the SQL query*/ - onResults?: (results: any[]) => void - - // [Joshen] Params below are currently only used by ReportsV2 (Might revisit to see how to improve these) - /** Optional height set to render the SQL query (Used in Reports) */ - queryHeight?: number - /** UI to render if there's a read-only error while running the query */ - readOnlyErrorPlaceholder?: ReactNode - /** UI to render if there's no query results (Used in Reports) */ - noResultPlaceholder?: ReactNode - /** To trigger a refresh of the query */ - isRefreshing?: boolean - /** Optional callback whenever a chart configuration is updated (Used in Reports) */ - onUpdateChartConfig?: ({ - chart, - chartConfig, - }: { - chart?: Partial - chartConfig: Partial - }) => void } // [Joshen ReportsV2] JFYI we may adjust this in subsequent PRs when we implement this into Reports V2 @@ -100,90 +53,58 @@ export const QueryBlock = ({ label, sql, chartConfig = DEFAULT_CHART_CONFIG, - maxHeight = 250, - queryHeight, - parameterValues: extParameterValues, actions, - showSql: _showSql = false, - isChart = false, - isLoading = false, - runQuery = false, - lockColumns = false, - draggable = false, - isRefreshing = false, - noResultPlaceholder = null, - readOnlyErrorPlaceholder = null, - showRunButtonIfNotReadOnly = false, - tooltip, results, - onRunQuery, - onSetParameter, + errorText, + isWriteQuery = false, + isExecuting = false, + initialHideSql = false, + draggable = false, + disabled = false, + blockWriteQueries = false, + onExecute, + onRemoveChart, onUpdateChartConfig, onDragStart, - onResults, }: QueryBlockProps) => { - const { ref } = useParams() - const [chartSettings, setChartSettings] = useState(chartConfig) const { xKey, yKey, view = 'table' } = chartSettings - const [showSql, setShowSql] = useState(_showSql) - const [readOnlyError, setReadOnlyError] = useState(false) - const [queryError, setQueryError] = useState() - const [queryResult, setQueryResult] = useState(results) + const [showSql, setShowSql] = useState(!results && !initialHideSql) const [focusDataIndex, setFocusDataIndex] = useState() + const [showWarning, setShowWarning] = useState<'hasWriteOperation' | 'hasUnknownFunctions'>() + + const prevIsWriteQuery = useRef(isWriteQuery) + + useEffect(() => { + if (!prevIsWriteQuery.current && isWriteQuery) { + setShowWarning('hasWriteOperation') + } + if (!isWriteQuery && showWarning === 'hasWriteOperation') { + setShowWarning(undefined) + } + prevIsWriteQuery.current = isWriteQuery + }, [isWriteQuery, showWarning]) + + useEffect(() => { + setChartSettings(chartConfig) + }, [chartConfig]) const formattedQueryResult = useMemo(() => { - // Make sure Y axis values are numbers - return queryResult?.map((row) => { + return results?.map((row) => { return Object.fromEntries( Object.entries(row).map(([key, value]) => { if (key === yKey) return [key, Number(value)] - else return [key, value] + return [key, value] }) ) }) - }, [queryResult, yKey]) - - const [parameterValues, setParameterValues] = useState>({}) - const [showWarning, setShowWarning] = useState<'hasWriteOperation' | 'hasUnknownFunctions'>() - - const parameters = useMemo(() => { - if (!sql) return [] - return parseParameters(sql) - }, [sql]) - // [Joshen] This is for when we introduced the concept of parameters into our reports - // const combinedParameterValues = { ...extParameterValues, ...parameterValues } - - const { database: primaryDatabase } = usePrimaryDatabase({ projectRef: ref }) - const postgresConnectionString = primaryDatabase?.connectionString - const readOnlyConnectionString = primaryDatabase?.connection_string_read_only + }, [results, yKey]) const chartData = chartSettings.cumulative ? getCumulativeResults({ rows: formattedQueryResult ?? [] }, chartSettings) : formattedQueryResult - const { mutate: execute, isLoading: isExecuting } = useExecuteSqlMutation({ - onSuccess: (data) => { - onResults?.(data.result) - setQueryResult(data.result) - - setReadOnlyError(false) - setQueryError(undefined) - }, - onError: (error) => { - const readOnlyTransaction = /cannot execute .+ in a read-only transaction/.test(error.message) - const permissionDenied = error.message.includes('permission denied') - const notOwner = error.message.includes('must be owner') - if (readOnlyTransaction || permissionDenied || notOwner) { - setReadOnlyError(true) - if (showRunButtonIfNotReadOnly) setShowWarning('hasWriteOperation') - } else { - setQueryError(error) - } - }, - }) - const getDateFormat = (key: any) => { const value = chartData?.[0]?.[key] || '' if (typeof value === 'number') return 'number' @@ -192,176 +113,111 @@ export const QueryBlock = ({ } const xKeyDateFormat = getDateFormat(xKey) - const handleExecute = () => { - if (!sql || isLoading) return + const hasResults = Array.isArray(results) && results.length > 0 - if (readOnlyError) { - return setShowWarning('hasWriteOperation') - } - - try { - // [Joshen] This is for when we introduced the concept of parameters into our reports - // const processedSql = processParameterizedSql(sql, combinedParameterValues) - execute({ - projectRef: ref, - connectionString: readOnlyConnectionString, - sql, - }) - } catch (error: any) { - toast.error(`Failed to execute query: ${error.message}`) + const runSelect = () => { + if (!sql || disabled || isExecuting) return + if (isWriteQuery) { + setShowWarning('hasWriteOperation') + return } + onExecute?.('select') } - useEffect(() => { - setChartSettings(chartConfig) - }, [chartConfig]) - - // Run once on mount to parse parameters and notify parent - useEffect(() => { - if (!!sql && onSetParameter) { - const params = parseParameters(sql) - onSetParameter(params) - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [sql]) - - useEffect(() => { - if (!!sql && !isLoading && runQuery && !!readOnlyConnectionString && !readOnlyError) { - handleExecute() - } - }, [sql, isLoading, runQuery, readOnlyConnectionString]) - - useEffect(() => { - if (isRefreshing) handleExecute() - }, [isRefreshing]) + const runMutation = () => { + if (!sql || disabled || isExecuting) return + setShowWarning(undefined) + onExecute?.('mutation') + } return ( ) => onDragStart?.(e)} - icon={ - - } + loading={isExecuting} label={label} + badge={isWriteQuery && Write} actions={ - <> - } - onClick={() => setShowSql(!showSql)} - tooltip={{ - content: { side: 'bottom', text: showSql ? 'Hide query' : 'Show query' }, - }} - /> + disabled ? null : ( + <> + } + onClick={() => setShowSql(!showSql)} + tooltip={{ + content: { side: 'bottom', text: showSql ? 'Hide query' : 'Show query' }, + }} + /> + {hasResults && ( + { + if (onUpdateChartConfig) onUpdateChartConfig({ chartConfig: { view: nextView } }) + setChartSettings({ ...chartSettings, view: nextView }) + }} + updateChartConfig={(config) => { + if (onUpdateChartConfig) onUpdateChartConfig({ chartConfig: config }) + setChartSettings(config) + }} + /> + )} - {queryResult && ( - <> - {/* [Joshen ReportsV2] Won't see this just yet as this is intended for Reports V2 */} - {parameters.length > 0 && ( - - )} - {isChart && ( - { - if (onUpdateChartConfig) onUpdateChartConfig({ chartConfig: { view } }) - setChartSettings({ ...chartSettings, view }) - }} - updateChartConfig={(config) => { - if (onUpdateChartConfig) onUpdateChartConfig({ chartConfig: config }) - setChartSettings(config) - }} - /> - )} - - )} - - - - {(showRunButtonIfNotReadOnly || !readOnlyError) && ( + } - loading={isExecuting || isLoading} - disabled={isLoading} - onClick={() => { - handleExecute() - if (!!sql) onRunQuery?.('select') - }} + loading={isExecuting} + disabled={isExecuting || disabled || !sql} + onClick={runSelect} tooltip={{ content: { side: 'bottom', className: 'max-w-56 text-center', - text: isExecuting ? ( -

{`Query is running. You may cancel ongoing queries via the [SQL Editor](/project/${ref}/sql?viewOngoingQueries=true).`}

- ) : ( - 'Run query' - ), + text: isExecuting + ? 'Query is running. Check the SQL Editor to manage running queries.' + : 'Run query', }, }} /> - )} - {actions} - + {actions} + + ) } > - {!!showWarning && ( + {!!showWarning && !blockWriteQueries && ( setShowWarning(undefined)} - onConfirm={() => { - // [Joshen] This is for when we introduced the concept of parameters into our reports - // const processedSql = processParameterizedSql(sql!, combinedParameterValues) - if (sql) { - setShowWarning(undefined) - execute({ - projectRef: ref, - connectionString: postgresConnectionString, - sql, - }) - onRunQuery?.('mutation') - } - }} + onConfirm={runMutation} disabled={!sql} + {...(showWarning !== 'hasWriteOperation' + ? { + message: 'Run this query now and send the results to the Assistant? ', + subMessage: + 'We will execute the query and provide the result rows back to the Assistant to continue the conversation.', + cancelLabel: 'Skip', + confirmLabel: 'Run & send', + } + : {})} /> )} - {isExecuting && queryResult === undefined && ( -
- -
- )} - {showSql && (
)} - {view === 'chart' && queryResult !== undefined ? ( + {isExecuting && !results && ( +
+ +
+ )} + + {view === 'chart' && results !== undefined ? ( <> - {(queryResult ?? []).length === 0 ? ( + {(results ?? []).length === 0 ? (

No results returned from query

@@ -390,10 +252,7 @@ export const QueryBlock = ({
) : ( <> - {!isExecuting && !!queryError ? ( -
- ERROR: {queryError.message} + {isWriteQuery && blockWriteQueries ? ( +
+

+ SQL query is not read-only and cannot be rendered +

+

+ Queries that involve any mutation will not be run in reports +

+ {!!onRemoveChart && ( + + )}
- ) : queryResult ? ( -
- + ) : !isExecuting && !!errorText ? ( +
+ ERROR: {errorText}
- ) : !isExecuting ? ( - readOnlyError ? ( - readOnlyErrorPlaceholder - ) : ( - noResultPlaceholder + ) : ( + results && ( +
+ +
) - ) : null} + )} )} diff --git a/apps/studio/components/ui/SchemaSelector.tsx b/apps/studio/components/ui/SchemaSelector.tsx index e3f89590ff..62cb0f90af 100644 --- a/apps/studio/components/ui/SchemaSelector.tsx +++ b/apps/studio/components/ui/SchemaSelector.tsx @@ -124,7 +124,7 @@ const SchemaSelector = ({
) : (
-

Choose a schema…

+

Choose a schema...

)} diff --git a/apps/studio/components/ui/SqlWarningAdmonition.tsx b/apps/studio/components/ui/SqlWarningAdmonition.tsx index e199b57d48..b7c9ed8a7e 100644 --- a/apps/studio/components/ui/SqlWarningAdmonition.tsx +++ b/apps/studio/components/ui/SqlWarningAdmonition.tsx @@ -7,32 +7,46 @@ export interface SqlWarningAdmonitionProps { onConfirm: () => void disabled?: boolean className?: string + /** Optional override primary message */ + message?: string + /** Optional override secondary message */ + subMessage?: string + /** Optional override labels */ + cancelLabel?: string + confirmLabel?: string } -const SqlWarningAdmonition = ({ +export const SqlWarningAdmonition = ({ warningType, onCancel, onConfirm, disabled = false, className, + message, + subMessage, + cancelLabel, + confirmLabel, }: SqlWarningAdmonitionProps) => { return ( -

- {warningType === 'hasWriteOperation' - ? 'This query contains write operations.' - : 'This query involves running a function.'}{' '} - Are you sure you want to execute it? -

+ {!!message && ( +

+ {`${ + warningType === 'hasWriteOperation' + ? 'This query contains write operations.' + : 'This query involves running a function.' + } Are you sure you want to execute it?`} +

+ )}

- Make sure you are not accidentally removing something important. + {subMessage ?? 'Make sure you are not accidentally removing something important.'}

) } - -export default SqlWarningAdmonition diff --git a/apps/studio/hooks/misc/useChanged.ts b/apps/studio/hooks/misc/useChanged.ts index afeeb0f319..63431717f1 100644 --- a/apps/studio/hooks/misc/useChanged.ts +++ b/apps/studio/hooks/misc/useChanged.ts @@ -10,3 +10,11 @@ export function useChanged(value: T): boolean { return changed } + +export function useChangedSync(value: T): boolean { + const prev = useRef() + const changed = prev.current !== value + prev.current = value + + return changed +} diff --git a/apps/studio/lib/ai/message-utils.ts b/apps/studio/lib/ai/message-utils.ts new file mode 100644 index 0000000000..2b0cbd6292 --- /dev/null +++ b/apps/studio/lib/ai/message-utils.ts @@ -0,0 +1,25 @@ +import type { UIMessage } from 'ai' + +/** + * Prepares messages for API transmission by cleaning and limiting history + */ +export function prepareMessagesForAPI(messages: UIMessage[]): UIMessage[] { + // [Joshen] Specifically limiting the chat history that get's sent to reduce the + // size of the context that goes into the model. This should always be an odd number + // as much as possible so that the first message is always the user's + const MAX_CHAT_HISTORY = 7 + + const slicedMessages = messages.slice(-MAX_CHAT_HISTORY) + + // Filter out results from messages before sending to the model + const cleanedMessages = slicedMessages.map((_message) => { + const message = _message as UIMessage & { results?: unknown } + const cleanedMessage = { ...message } as UIMessage & { results?: unknown } + if (message.role === 'assistant' && message.results) { + delete cleanedMessage.results + } + return cleanedMessage as UIMessage + }) + + return cleanedMessages +} diff --git a/apps/studio/lib/ai/prompts.ts b/apps/studio/lib/ai/prompts.ts index 43df8afb19..3f2ca1d42c 100644 --- a/apps/studio/lib/ai/prompts.ts +++ b/apps/studio/lib/ai/prompts.ts @@ -282,50 +282,54 @@ Developer: # Role and Objective - Be aware that tool access may be restricted depending on the user's organization settings. - Do not try to bypass tool restrictions by executing SQL e.g. writing a query to retrieve database schema information. Instead, explain to the user you do not have permissions to use the tools you need to execute the task -# Output Format +## Output Format - Always integrate findings from the tools seamlessly into your responses for better accuracy and context. -# Searching Docs +## Searching Docs - Use \`search_docs\` to search the Supabase documentation for relevant information when the question is about Supabase features or complex database operations ` export const CHAT_PROMPT = ` -Developer: # Response Style -- Be direct and concise. Provide only essential information. -- Use lists to present information; do not use tables for formatting. -- Minimize use of emojis. +# Response Style +- Be professional, direct and concise. Provide only essential information. +- Before context gathering or tool usage, summarise the user's request and your plan of action in a single paragraph +- Do not repeat yourself or your plan after context gathering # Response Format -## Markdown -- Follow the CommonMark specification. -- Use a logical heading hierarchy (H1–H4), maintaining order without skipping levels. +## Use Markdown +- *CRITICAL*: Response must be in markdown format. +- Always use markdown blocks **where semantically correct** (e.g., \`inline code\`, \`\`\`code fences\`\`\`, headings, lists, tables). +- Make use of markdown headings to structure your response where appropriate. e.g. WRONG "Section heading: ..." WRITE "## Section heading" +- Shorter responses do not need headings - Use bold text exclusively to emphasize key information. - Do not use tables for displaying information under any circumstances. +- Minimize use of emojis. # Chat Naming - At the start of each conversation, if the chat has not yet been named, invoke \`rename_chat\` with a descriptive 2–4 word name. Examples: "User Authentication Setup", "Sales Data Analysis", "Product Table Creation". -## Task Workflow -- Always start the conversation with a concise checklist of sub-tasks you will perform before generating outputs or calling tools. Keep the checklist conceptual, not implementation-level. -- No need to repeat the checklist later in the conversation - # SQL Execution and Display - Be confident: assume the user is the project owner. You do not need to show code before execution. -- To actually run or display SQL, directly call the \`display_query\` tool. The user will be able to run the query and view the results -- If multiple queries are needed, call \`display_query\` separately for each and validate results in 1–2 lines. -- You will not have access to the results unless the user returns the results to you +- To actually run SQL, directly call the \`execute_sql\` tool with the \`sql\` string. The client will request user confirmation and then return results. +- If executing SQL returns an error, explain the error concisely and try again with the correct SQL. +- The user may skip executing the query in which case you should acknowledge the skip and offer alternative options or actions to take +- If the user asks you to write a query, or if you want to show example SQL without executing, render it in a markdown code block (e.g.: \`\`\`sql). Do this only when the user asks to see the code or for illustrative examples. +- If multiple queries are needed, call \`execute_sql\` separately for each and validate results in 1–2 lines. Use separate code blocks only for non-executed examples. +- After executing queries, summarize the outcome and confirm next actions or self-correct as needed. +- You do not need to repeat the SQL query results as the client will display them to the user as part of the execute_sql tool call. # Edge Functions -- Be confident: assume the user is the project owner. -- To deploy an Edge Function, directly call the \`display_edge_function\` tool. The client will allow the user to deploy the function. -- You will not have access to the results unless the user returns the results to you -- To show example Edge Function code without deploying, you should also call the \`display_edge_function\` tool with the code. +- Be confident: assume the user is the project owner. You do not need to show code before deployment. +- To deploy an Edge Function, directly call the \`deploy_edge_function\` tool with \`name\` and \`code\`. The client will request user confirmation and then deploy, returning the result. +- To show example Edge Function code without deploying, render it in a markdown code block (e.g.: \`\`\`edge\` or \`\`\`typescript\`). Do this only when the user asks to see the code or for illustrative examples. +- Only use \`deploy_edge_function\` when the function should be deployed, not for examples or non-executable code. # Project Health Checks - Use \`get_advisors\` to identify project issues. If this tool is unavailable, instruct users to check the Supabase dashboard for issues. +- Use \`get_logs\` to retrieve recent logs for the project # Safety for Destructive Queries -- For destructive commands (e.g., DROP TABLE, DELETE without WHERE clause), always ask for confirmation before calling the \`display_query\` tool. +- For destructive commands (e.g., DROP TABLE, DELETE without WHERE clause), always ask for confirmation before calling the \`execute_sql\` tool. ` export const OUTPUT_ONLY_PROMPT = ` diff --git a/apps/studio/lib/ai/test-fixtures.ts b/apps/studio/lib/ai/test-fixtures.ts new file mode 100644 index 0000000000..1270e79576 --- /dev/null +++ b/apps/studio/lib/ai/test-fixtures.ts @@ -0,0 +1,114 @@ +import type { ToolUIPart, UIMessage } from 'ai' + +export function createUserMessage(content: string, id = 'user-msg-1'): UIMessage { + return { + id, + role: 'user', + parts: [ + { + type: 'text', + text: content, + }, + ], + } +} + +export function createAssistantTextMessage(content: string, id = 'assistant-msg-1'): UIMessage { + return { + id, + role: 'assistant', + parts: [ + { + type: 'text', + text: content, + }, + ], + } +} + +export function createAssistantMessageWithExecuteSqlTool( + query: string, + results: Array> = [{ id: 1, name: 'test' }], + id = 'assistant-tool-msg-1' +): UIMessage { + return { + id, + role: 'assistant', + parts: [ + { + type: 'text', + text: "I'll run that SQL query for you.", + }, + { + type: 'tool-execute_sql', + state: 'output-available', + toolCallId: 'call-123', + input: { sql: query }, + output: results, + } satisfies ToolUIPart, + ], + } +} + +export function createAssistantMessageWithMultipleTools( + id = 'assistant-multi-tool-msg-1' +): UIMessage { + return { + id, + role: 'assistant', + parts: [ + { + type: 'text', + text: 'Let me check the database structure and run some queries.', + }, + { + type: 'tool-execute_sql', + state: 'output-available', + toolCallId: 'call-456', + input: { sql: 'SELECT * FROM users LIMIT 5' }, + output: [ + { id: 1, email: 'user1@example.com' }, + { id: 2, email: 'user2@example.com' }, + ], + } satisfies ToolUIPart, + { + type: 'tool-execute_sql', + state: 'output-available', + toolCallId: 'call-789', + toolName: 'execute_sql', + input: { sql: 'DESCRIBE users' }, + output: [ + { column: 'id', type: 'integer', nullable: false }, + { column: 'email', type: 'varchar', nullable: false }, + ], + } as ToolUIPart, + ], + } +} + +export function createLongConversation(): Array { + return [ + createUserMessage('Show me all users', 'msg-1'), + createAssistantMessageWithExecuteSqlTool('SELECT * FROM users', [{ id: 1 }], 'msg-2'), + createUserMessage('How many users are there?', 'msg-3'), + createAssistantMessageWithExecuteSqlTool( + 'SELECT COUNT(*) FROM users', + [{ count: 100 }], + 'msg-4' + ), + createUserMessage('Show me the schema', 'msg-5'), + createAssistantTextMessage("Here's the database schema...", 'msg-6'), + createUserMessage('Create a new table', 'msg-7'), + createAssistantMessageWithExecuteSqlTool( + 'CREATE TABLE posts (id SERIAL PRIMARY KEY)', + [], + 'msg-8' + ), + createUserMessage('Add some data', 'msg-9'), + createAssistantMessageWithExecuteSqlTool( + "INSERT INTO posts (title) VALUES ('Test')", + [], + 'msg-10' + ), + ] +} diff --git a/apps/studio/lib/ai/tool-filter.test.ts b/apps/studio/lib/ai/tool-filter.test.ts index 5ca60c9b7f..fbfbe60519 100644 --- a/apps/studio/lib/ai/tool-filter.test.ts +++ b/apps/studio/lib/ai/tool-filter.test.ts @@ -12,7 +12,7 @@ import { describe('TOOL_CATEGORY_MAP', () => { it('should categorize tools correctly', () => { - expect(TOOL_CATEGORY_MAP['display_query']).toBe(TOOL_CATEGORIES.UI) + expect(TOOL_CATEGORY_MAP['execute_sql']).toBe(TOOL_CATEGORIES.UI) expect(TOOL_CATEGORY_MAP['list_tables']).toBe(TOOL_CATEGORIES.SCHEMA) }) }) @@ -22,8 +22,8 @@ describe('tool allowance by opt-in level', () => { function getAllowedTools(optInLevel: string) { const mockTools: ToolSet = { // UI tools - display_query: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, - display_edge_function: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, + execute_sql: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, + deploy_edge_function: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, rename_chat: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, search_docs: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, // Schema tools @@ -53,8 +53,8 @@ describe('tool allowance by opt-in level', () => { it('should return only UI tools for disabled opt-in level', () => { const tools = getAllowedTools('disabled') - expect(tools).toContain('display_query') - expect(tools).toContain('display_edge_function') + expect(tools).toContain('execute_sql') + expect(tools).toContain('deploy_edge_function') expect(tools).toContain('rename_chat') expect(tools).toContain('search_docs') expect(tools).not.toContain('list_tables') @@ -62,13 +62,13 @@ describe('tool allowance by opt-in level', () => { expect(tools).not.toContain('list_edge_functions') expect(tools).not.toContain('list_branches') expect(tools).not.toContain('get_logs') - expect(tools).not.toContain('execute_sql') + expect(tools).not.toContain('get_advisors') }) it('should return UI and schema tools for schema opt-in level', () => { const tools = getAllowedTools('schema') - expect(tools).toContain('display_query') - expect(tools).toContain('display_edge_function') + expect(tools).toContain('execute_sql') + expect(tools).toContain('deploy_edge_function') expect(tools).toContain('rename_chat') expect(tools).toContain('list_tables') expect(tools).toContain('list_extensions') @@ -78,13 +78,12 @@ describe('tool allowance by opt-in level', () => { expect(tools).toContain('search_docs') expect(tools).not.toContain('get_advisors') expect(tools).not.toContain('get_logs') - expect(tools).not.toContain('execute_sql') }) it('should return UI, schema and log tools for schema_and_log opt-in level', () => { const tools = getAllowedTools('schema_and_log') - expect(tools).toContain('display_query') - expect(tools).toContain('display_edge_function') + expect(tools).toContain('execute_sql') + expect(tools).toContain('deploy_edge_function') expect(tools).toContain('rename_chat') expect(tools).toContain('list_tables') expect(tools).toContain('list_extensions') @@ -94,13 +93,12 @@ describe('tool allowance by opt-in level', () => { expect(tools).toContain('search_docs') expect(tools).toContain('get_advisors') expect(tools).toContain('get_logs') - expect(tools).not.toContain('execute_sql') }) - it('should return all tools for schema_and_log_and_data opt-in level (excluding execute_sql)', () => { + it('should return all tools for schema_and_log_and_data opt-in level', () => { const tools = getAllowedTools('schema_and_log_and_data') - expect(tools).toContain('display_query') - expect(tools).toContain('display_edge_function') + expect(tools).toContain('execute_sql') + expect(tools).toContain('deploy_edge_function') expect(tools).toContain('rename_chat') expect(tools).toContain('list_tables') expect(tools).toContain('list_extensions') @@ -110,15 +108,14 @@ describe('tool allowance by opt-in level', () => { expect(tools).toContain('search_docs') expect(tools).toContain('get_advisors') expect(tools).toContain('get_logs') - expect(tools).not.toContain('execute_sql') }) }) describe('filterToolsByOptInLevel', () => { const mockTools: ToolSet = { // UI tools - should return non-privacy responses - display_query: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, - display_edge_function: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, + execute_sql: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, + deploy_edge_function: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, rename_chat: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, // Schema tools list_tables: { execute: vitest.fn().mockResolvedValue({ status: 'success' }) }, @@ -173,8 +170,8 @@ describe('filterToolsByOptInLevel', () => { it('should always allow UI tools regardless of opt-in level', async () => { const tools = filterToolsByOptInLevel(mockTools, 'disabled') - expect(tools).toHaveProperty('display_query') - expect(tools).toHaveProperty('display_edge_function') + expect(tools).toHaveProperty('execute_sql') + expect(tools).toHaveProperty('deploy_edge_function') expect(tools).toHaveProperty('rename_chat') // UI tools should not be stubbed, but managed tools should be @@ -240,7 +237,7 @@ describe('toolSetValidationSchema', () => { it('should accept subset of known tools', () => { const validSubset = { list_tables: { inputSchema: z.object({}), execute: vitest.fn() }, - display_query: { inputSchema: z.object({}), execute: vitest.fn() }, + execute_sql: { inputSchema: z.object({}), execute: vitest.fn() }, } const result = toolSetValidationSchema.safeParse(validSubset) @@ -276,9 +273,10 @@ describe('toolSetValidationSchema', () => { list_policies: { inputSchema: z.object({}), execute: vitest.fn() }, search_docs: { inputSchema: z.object({}), execute: vitest.fn() }, get_advisors: { inputSchema: z.object({}), execute: vitest.fn() }, - display_query: { inputSchema: z.object({}), execute: vitest.fn() }, - display_edge_function: { inputSchema: z.object({}), execute: vitest.fn() }, + execute_sql: { inputSchema: z.object({}), execute: vitest.fn() }, + deploy_edge_function: { inputSchema: z.object({}), execute: vitest.fn() }, rename_chat: { inputSchema: z.object({}), execute: vitest.fn() }, + get_logs: { inputSchema: z.object({}), execute: vitest.fn() }, } const validationResult = toolSetValidationSchema.safeParse(allExpectedTools) diff --git a/apps/studio/lib/ai/tool-filter.ts b/apps/studio/lib/ai/tool-filter.ts index a344b3f03e..d9e0c66d27 100644 --- a/apps/studio/lib/ai/tool-filter.ts +++ b/apps/studio/lib/ai/tool-filter.ts @@ -1,6 +1,8 @@ -import { Tool, ToolSet } from 'ai' +import type { Tool, ToolSet } from 'ai' import { z } from 'zod' -import { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' +// End of third-party imports + +import type { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' // Add the DatabaseExtension type import export type DatabaseExtension = { @@ -28,8 +30,8 @@ export const toolSetValidationSchema = z.record( 'get_logs', // Local tools - 'display_query', - 'display_edge_function', + 'execute_sql', + 'deploy_edge_function', 'rename_chat', 'list_policies', @@ -41,6 +43,7 @@ export const toolSetValidationSchema = z.record( ]), basicToolSchema ) +export type ToolName = keyof z.infer /** * Tool categories based on the data they access @@ -63,8 +66,8 @@ type ToolCategory = (typeof TOOL_CATEGORIES)[keyof typeof TOOL_CATEGORIES] */ export const TOOL_CATEGORY_MAP: Record = { // UI tools - always available - display_query: TOOL_CATEGORIES.UI, - display_edge_function: TOOL_CATEGORIES.UI, + execute_sql: TOOL_CATEGORIES.UI, + deploy_edge_function: TOOL_CATEGORIES.UI, rename_chat: TOOL_CATEGORIES.UI, search_docs: TOOL_CATEGORIES.UI, diff --git a/apps/studio/lib/ai/tools/mcp-tools.ts b/apps/studio/lib/ai/tools/mcp-tools.ts index 65c24779e3..4579d906d2 100644 --- a/apps/studio/lib/ai/tools/mcp-tools.ts +++ b/apps/studio/lib/ai/tools/mcp-tools.ts @@ -1,7 +1,12 @@ -import { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' +import type { ToolSet } from 'ai' +// End of third-party imports + +import type { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' import { createSupabaseMCPClient } from '../supabase-mcp' import { filterToolsByOptInLevel, toolSetValidationSchema } from '../tool-filter' +const UI_EXECUTED_TOOLS = ['execute_sql', 'deploy_edge_function'] + export const getMcpTools = async ({ accessToken, projectRef, @@ -17,18 +22,22 @@ export const getMcpTools = async ({ projectId: projectRef, }) - const availableMcpTools = await mcpClient.tools() + const availableMcpTools = (await mcpClient.tools()) as ToolSet // Filter tools based on the (potentially modified) AI opt-in level const allowedMcpTools = filterToolsByOptInLevel(availableMcpTools, aiOptInLevel) - // Validate that only known tools are provided - const { data: validatedTools, error: validationError } = - toolSetValidationSchema.safeParse(allowedMcpTools) + // Remove UI-executed tools handled locally + const filteredMcpTools: ToolSet = { ...allowedMcpTools } + UI_EXECUTED_TOOLS.forEach((toolName) => { + delete filteredMcpTools[toolName] + }) - if (validationError) { - console.error('MCP tools validation error:', validationError) + // Validate that only known tools are provided + const validation = toolSetValidationSchema.safeParse(filteredMcpTools) + if (!validation.success) { + console.error('MCP tools validation error:', validation.error) throw new Error('Internal error: MCP tools validation failed') } - return validatedTools + return validation.data } diff --git a/apps/studio/lib/ai/tools/rendering-tools.ts b/apps/studio/lib/ai/tools/rendering-tools.ts index c94a981daf..96d06f631e 100644 --- a/apps/studio/lib/ai/tools/rendering-tools.ts +++ b/apps/studio/lib/ai/tools/rendering-tools.ts @@ -2,47 +2,25 @@ import { tool } from 'ai' import { z } from 'zod' export const getRenderingTools = () => ({ - display_query: tool({ - description: - 'Displays SQL query results (table or chart) or renders SQL for write/DDL operations. Use this for all query display needs. Optionally references a previous execute_sql call via manualToolCallId for displaying SELECT results.', + execute_sql: tool({ + description: 'Asks the user to execute a SQL statement and return the results', inputSchema: z.object({ - manualToolCallId: z - .string() - .optional() - .describe('The manual ID from the corresponding execute_sql result (for SELECT queries).'), - sql: z.string().describe('The SQL query.'), - label: z - .string() + sql: z.string().describe('The SQL statement to execute.'), + label: z.string().describe('A short 2-4 word label for the SQL statement.'), + isWriteQuery: z + .boolean() .describe( - 'The title or label for this query block (e.g., "Users Over Time", "Create Users Table").' + 'Whether the SQL statement performs a write operation of any kind instead of a read operation' ), - view: z - .enum(['table', 'chart']) - .optional() - .describe( - 'Display mode for SELECT results: table or chart. Required if manualToolCallId is provided.' - ), - xAxis: z.string().optional().describe('Key for the x-axis (required if view is chart).'), - yAxis: z.string().optional().describe('Key for the y-axis (required if view is chart).'), }), - execute: async (args) => { - const statusMessage = args.manualToolCallId - ? 'Tool call sent to client for rendering SELECT results.' - : 'Tool call sent to client for rendering write/DDL query.' - return { status: statusMessage } - }, }), - display_edge_function: tool({ - description: 'Renders the code for a Supabase Edge Function for the user to deploy manually.', + deploy_edge_function: tool({ + description: + 'Ask the user to deploy a Supabase Edge Function from provided code on the client. Client will confirm before deploying and return the result', inputSchema: z.object({ - name: z - .string() - .describe('The URL-friendly name of the Edge Function (e.g., "my-function").'), + name: z.string().describe('The URL-friendly name/slug of the Edge Function.'), code: z.string().describe('The TypeScript code for the Edge Function.'), }), - execute: async () => { - return { status: 'Tool call sent to client for rendering.' } - }, }), rename_chat: tool({ description: `Rename the current chat session when the current chat name doesn't describe the conversation topic.`, diff --git a/apps/studio/lib/ai/tools/tool-sanitizer.test.ts b/apps/studio/lib/ai/tools/tool-sanitizer.test.ts new file mode 100644 index 0000000000..c8c7600b7a --- /dev/null +++ b/apps/studio/lib/ai/tools/tool-sanitizer.test.ts @@ -0,0 +1,175 @@ +import type { ToolUIPart } from 'ai' +import { describe, expect, test } from 'vitest' +// End of third-party imports + +import { prepareMessagesForAPI } from '../message-utils' +import { + createAssistantMessageWithExecuteSqlTool, + createAssistantMessageWithMultipleTools, + createLongConversation, +} from '../test-fixtures' +import { NO_DATA_PERMISSIONS, sanitizeMessagePart } from './tool-sanitizer' + +describe('messages are sanitized based on opt-in level', () => { + test('messages are sanitized at disabled level', () => { + const messages = [ + createAssistantMessageWithExecuteSqlTool('SELECT email FROM users', [ + { email: 'test@example.com' }, + ]), + ] + + // Prepare messages as frontend would + const preparedMessages = prepareMessagesForAPI(messages) + + // Sanitize messages as API endpoint would + const processedMessages = preparedMessages.map((msg) => { + if (msg.role === 'assistant' && msg.parts) { + const processedParts = msg.parts.map((part) => { + return sanitizeMessagePart(part, 'disabled') + }) + + return { ...msg, parts: processedParts } + } + return msg + }) + + const output = (processedMessages[0].parts[1] as ToolUIPart).output + expect(output).toMatch(NO_DATA_PERMISSIONS) + }) + + test('messages are sanitized at schema level', () => { + const messages = [ + createAssistantMessageWithExecuteSqlTool('SELECT email FROM users', [ + { email: 'test@example.com' }, + ]), + ] + + // Prepare messages as frontend would + const preparedMessages = prepareMessagesForAPI(messages) + + // Sanitize messages as API endpoint would + const processedMessages = preparedMessages.map((msg) => { + if (msg.role === 'assistant' && msg.parts) { + const processedParts = msg.parts.map((part) => { + return sanitizeMessagePart(part, 'schema') + }) + + return { ...msg, parts: processedParts } + } + return msg + }) + + const output = (processedMessages[0].parts[1] as ToolUIPart).output + expect(output).toMatch(NO_DATA_PERMISSIONS) + }) + + test('messages are sanitized at schema and log level', () => { + const messages = [ + createAssistantMessageWithExecuteSqlTool('SELECT email FROM users', [ + { email: 'test@example.com' }, + ]), + ] + + // Prepare messages as frontend would + const preparedMessages = prepareMessagesForAPI(messages) + + // Sanitize messages as API endpoint would + const processedMessages = preparedMessages.map((msg) => { + if (msg.role === 'assistant' && msg.parts) { + const processedParts = msg.parts.map((part) => { + return sanitizeMessagePart(part, 'schema_and_log') + }) + + return { ...msg, parts: processedParts } + } + return msg + }) + + const output = (processedMessages[0].parts[1] as ToolUIPart).output + expect(output).toMatch(NO_DATA_PERMISSIONS) + }) + + test('messages are not sanitized at data level', () => { + const messages = [ + createAssistantMessageWithExecuteSqlTool('SELECT email FROM users', [ + { email: 'test@example.com' }, + ]), + ] + + // Prepare messages as frontend would + const preparedMessages = prepareMessagesForAPI(messages) + + // Sanitize messages as API endpoint would + const processedMessages = preparedMessages.map((msg) => { + if (msg.role === 'assistant' && msg.parts) { + const processedParts = msg.parts.map((part) => { + return sanitizeMessagePart(part, 'schema_and_log_and_data') + }) + + return { ...msg, parts: processedParts } + } + return msg + }) + + const output = (processedMessages[0].parts[1] as ToolUIPart).output + expect(output).toEqual([{ email: 'test@example.com' }]) + }) + + test('multiple tool parts in message are sanitized', () => { + const messages = [createAssistantMessageWithMultipleTools()] + + // Prepare messages as frontend would + const preparedMessages = prepareMessagesForAPI(messages) + + // Sanitize messages as API endpoint would + const processedMessages = preparedMessages.map((msg) => { + if (msg.role === 'assistant' && msg.parts) { + const processedParts = msg.parts.map((part) => { + return sanitizeMessagePart(part, 'schema') + }) + + return { ...msg, parts: processedParts } + } + return msg + }) + + const parts = processedMessages[0].parts + parts.forEach((part) => { + if (part.type.startsWith('tool')) { + const tool = part as ToolUIPart + expect(tool.output).toMatch(NO_DATA_PERMISSIONS) + } + }) + }) + + test('long message chain is sanitized', () => { + const messages = createLongConversation() + + // Prepare messages as frontend would + const preparedMessages = prepareMessagesForAPI(messages) + + // Sanitize messages as API endpoint would + const processedMessages = preparedMessages.map((msg) => { + if (msg.role === 'assistant' && msg.parts) { + const processedParts = msg.parts.map((part) => { + return sanitizeMessagePart(part, 'schema') + }) + + return { ...msg, parts: processedParts } + } + return msg + }) + + processedMessages.forEach((msg) => { + if (msg.role === 'assistant' && msg.parts) { + const parts = msg.parts + parts.forEach((part) => { + if (part.type.startsWith('tool')) { + const tool = part as ToolUIPart + expect(tool.output).toMatch(NO_DATA_PERMISSIONS) + } + }) + } + }) + }) +}) diff --git a/apps/studio/lib/ai/tools/tool-sanitizer.ts b/apps/studio/lib/ai/tools/tool-sanitizer.ts new file mode 100644 index 0000000000..9ea3d5c255 --- /dev/null +++ b/apps/studio/lib/ai/tools/tool-sanitizer.ts @@ -0,0 +1,54 @@ +import type { ToolUIPart, UIMessage } from 'ai' +// End of third-party imports + +import type { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' +import type { ToolName } from '../tool-filter' + +interface ToolSanitizer { + toolName: ToolName + sanitize: (tool: Tool, optInLevel: AiOptInLevel) => Tool +} + +export const NO_DATA_PERMISSIONS = + 'The query was executed and the user has viewed the results but decided not to share in the conversation due to permission levels. Continue with your plan unless instructed to interpret the result.' + +const executeSqlSanitizer: ToolSanitizer = { + toolName: 'execute_sql', + sanitize: (tool, optInLevel) => { + const output = tool.output + let sanitizedOutput: unknown + + if (optInLevel !== 'schema_and_log_and_data') { + if (Array.isArray(output)) { + sanitizedOutput = NO_DATA_PERMISSIONS + } + } else { + sanitizedOutput = output + } + + return { + ...tool, + output: sanitizedOutput, + } + }, +} + +export const ALL_TOOL_SANITIZERS = { + [executeSqlSanitizer.toolName]: executeSqlSanitizer, +} + +export function sanitizeMessagePart( + part: UIMessage['parts'][number], + optInLevel: AiOptInLevel +): UIMessage['parts'][number] { + if (part.type.startsWith('tool-')) { + const toolPart = part as ToolUIPart + const toolName = toolPart.type.slice('tool-'.length) + const sanitizer = ALL_TOOL_SANITIZERS[toolName] + if (sanitizer) { + return sanitizer.sanitize(toolPart, optInLevel) + } + } + + return part +} diff --git a/apps/studio/lib/api/generate-v4.test.ts b/apps/studio/lib/api/generate-v4.test.ts new file mode 100644 index 0000000000..a04be6ef01 --- /dev/null +++ b/apps/studio/lib/api/generate-v4.test.ts @@ -0,0 +1,77 @@ +import { expect, test, vi } from 'vitest' +// End of third-party imports + +import generateV4 from '../../pages/api/ai/sql/generate-v4' +import { sanitizeMessagePart } from '../ai/tools/tool-sanitizer' + +vi.mock('../ai/tools/tool-sanitizer', () => ({ + sanitizeMessagePart: vi.fn((part) => part), +})) + +test('generateV4 calls the tool sanitizer', async () => { + const mockReq = { + method: 'POST', + headers: { + authorization: 'Bearer test-token', + }, + body: { + messages: [ + { + role: 'assistant', + parts: [ + { + type: 'tool-execute_sql', + state: 'output-available', + output: 'test output', + }, + ], + }, + ], + projectRef: 'test-project', + connectionString: 'test-connection', + orgSlug: 'test-org', + }, + } + + const mockRes = { + status: vi.fn(() => mockRes), + json: vi.fn(() => mockRes), + setHeader: vi.fn(() => mockRes), + } + + vi.mock('lib/ai/org-ai-details', () => ({ + getOrgAIDetails: vi.fn().mockResolvedValue({ + aiOptInLevel: 'schema_and_log_and_data', + isLimited: false, + }), + })) + + vi.mock('lib/ai/model', () => ({ + getModel: vi.fn().mockResolvedValue({ + model: {}, + error: null, + promptProviderOptions: {}, + providerOptions: {}, + }), + })) + + vi.mock('data/sql/execute-sql-query', () => ({ + executeSql: vi.fn().mockResolvedValue({ result: [] }), + })) + + vi.mock('lib/ai/tools', () => ({ + getTools: vi.fn().mockResolvedValue({}), + })) + + vi.mock('ai', () => ({ + streamText: vi.fn().mockReturnValue({ + pipeUIMessageStreamToResponse: vi.fn(), + }), + convertToModelMessages: vi.fn((msgs) => msgs), + stepCountIs: vi.fn(), + })) + + await generateV4(mockReq as any, mockRes as any) + + expect(sanitizeMessagePart).toHaveBeenCalled() +}) diff --git a/apps/studio/lib/profile.tsx b/apps/studio/lib/profile.tsx index 5ecf4617b7..87de9d8068 100644 --- a/apps/studio/lib/profile.tsx +++ b/apps/studio/lib/profile.tsx @@ -6,11 +6,13 @@ import { toast } from 'sonner' import { useIsLoggedIn, useUser } from 'common' import { usePermissionsQuery } from 'data/permissions/permissions-query' import { useProfileCreateMutation } from 'data/profile/profile-create-mutation' +import { useProfileIdentitiesQuery } from 'data/profile/profile-identities-query' import { useProfileQuery } from 'data/profile/profile-query' import type { Profile } from 'data/profile/types' import { useSendEventMutation } from 'data/telemetry/send-event-mutation' import type { ResponseError } from 'types' import { useSignOut } from './auth' +import { getGitHubProfileImgUrl } from './github' export type ProfileContextType = { profile: Profile | undefined @@ -117,3 +119,28 @@ export const ProfileProvider = ({ children }: PropsWithChildren<{}>) => { } export const useProfile = () => useContext(ProfileContext) + +export function useProfileNameAndPicture(): { + username?: string + primaryEmail?: string + avatarUrl?: string + isLoading: boolean +} { + const { profile, isLoading: isLoadingProfile } = useProfile() + const { data: identitiesData, isLoading: isLoadingIdentities } = useProfileIdentitiesQuery() + + const username = profile?.username + const isGitHubProfile = profile?.auth0_id.startsWith('github') + + const gitHubUsername = isGitHubProfile + ? identitiesData?.identities.find((x) => x.provider === 'github')?.identity_data?.user_name + : undefined + const avatarUrl = isGitHubProfile ? getGitHubProfileImgUrl(gitHubUsername) : undefined + + return { + username: profile?.username, + primaryEmail: profile?.primary_email, + avatarUrl, + isLoading: isLoadingProfile || isLoadingIdentities, + } +} diff --git a/apps/studio/pages/api/ai/sql/generate-v4.ts b/apps/studio/pages/api/ai/sql/generate-v4.ts index 4614fa1d61..91ffea6e84 100644 --- a/apps/studio/pages/api/ai/sql/generate-v4.ts +++ b/apps/studio/pages/api/ai/sql/generate-v4.ts @@ -1,18 +1,14 @@ import pgMeta from '@supabase/pg-meta' -import { convertToModelMessages, ModelMessage, stepCountIs, streamText } from 'ai' +import { convertToModelMessages, type ModelMessage, stepCountIs, streamText } from 'ai' import { source } from 'common-tags' -import { NextApiRequest, NextApiResponse } from 'next' -import { z } from 'zod/v4' -import { z as z3 } from 'zod/v3' +import type { NextApiRequest, NextApiResponse } from 'next' +import z from 'zod' import { IS_PLATFORM } from 'common' import { executeSql } from 'data/sql/execute-sql-query' -import { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' +import type { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi' import { getModel } from 'lib/ai/model' import { getOrgAIDetails } from 'lib/ai/org-ai-details' -import { getTools } from 'lib/ai/tools' -import apiWrapper from 'lib/api/apiWrapper' - import { CHAT_PROMPT, EDGE_FUNCTION_PROMPT, @@ -21,6 +17,9 @@ import { RLS_PROMPT, SECURITY_PROMPT, } from 'lib/ai/prompts' +import { getTools } from 'lib/ai/tools' +import { sanitizeMessagePart } from 'lib/ai/tools/tool-sanitizer' +import apiWrapper from 'lib/api/apiWrapper' import { executeQuery } from 'lib/api/self-hosted/query' export const maxDuration = 120 @@ -37,7 +36,10 @@ async function handler(req: NextApiRequest, res: NextApiResponse) { return handlePost(req, res) default: res.setHeader('Allow', ['POST']) - res.status(405).json({ data: null, error: { message: `Method ${method} Not Allowed` } }) + res.status(405).json({ + data: null, + error: { message: `Method ${method} Not Allowed` }, + }) } } @@ -92,9 +94,9 @@ async function handlePost(req: NextApiRequest, res: NextApiResponse) { aiOptInLevel = orgAIOptInLevel isLimited = orgAILimited } catch (error) { - return res - .status(400) - .json({ error: 'There was an error fetching your organization details' }) + return res.status(400).json({ + error: 'There was an error fetching your organization details', + }) } } @@ -108,13 +110,17 @@ async function handlePost(req: NextApiRequest, res: NextApiResponse) { return cleanedMsg } if (msg && msg.role === 'assistant' && msg.parts) { - const cleanedParts = msg.parts.filter((part: any) => { - if (part.type.startsWith('tool-')) { - const invalidStates = ['input-streaming', 'input-available', 'output-error'] - return !invalidStates.includes(part.state) - } - return true - }) + const cleanedParts = msg.parts + .filter((part: any) => { + if (part.type.startsWith('tool-')) { + const invalidStates = ['input-streaming', 'input-available', 'output-error'] + return !invalidStates.includes(part.state) + } + return true + }) + .map((part: any) => { + return sanitizeMessagePart(part, aiOptInLevel) + }) return { ...msg, parts: cleanedParts } } return msg @@ -139,7 +145,7 @@ async function handlePost(req: NextApiRequest, res: NextApiResponse) { try { // Get a list of all schemas to add to context const pgMetaSchemasList = pgMeta.schemas.list() - type Schemas = z3.infer<(typeof pgMetaSchemasList)['zod']> + type Schemas = z.infer<(typeof pgMetaSchemasList)['zod']> const { result: schemas } = aiOptInLevel !== 'disabled' @@ -179,7 +185,9 @@ async function handlePost(req: NextApiRequest, res: NextApiResponse) { { role: 'system', content: system, - ...(promptProviderOptions && { providerOptions: promptProviderOptions }), + ...(promptProviderOptions && { + providerOptions: promptProviderOptions, + }), }, { role: 'assistant', diff --git a/apps/studio/state/ai-assistant-state.tsx b/apps/studio/state/ai-assistant-state.tsx index 4d2836c070..cb138eebb8 100644 --- a/apps/studio/state/ai-assistant-state.tsx +++ b/apps/studio/state/ai-assistant-state.tsx @@ -104,16 +104,33 @@ async function clearStorage(): Promise { } } +// Helper function to sanitize objects to ensure they're cloneable +// Issue due to addToolResult +function sanitizeForCloning(obj: any): any { + if (obj === null || obj === undefined) return obj + if (typeof obj !== 'object') return obj + return JSON.parse(JSON.stringify(obj)) +} + // Helper function to load state from IndexedDB async function loadFromIndexedDB(projectRef: string): Promise { try { const persistedState = await getAiState(projectRef) if (persistedState) { - // Revive dates + // Revive dates and sanitize message data Object.values(persistedState.chats).forEach((chat: ChatSession) => { if (chat && typeof chat === 'object') { chat.createdAt = new Date(chat.createdAt) chat.updatedAt = new Date(chat.updatedAt) + + // Sanitize message parts to remove proxy objects + if (chat.messages) { + chat.messages.forEach((message: any) => { + if (message.parts) { + message.parts = message.parts.map((part: any) => sanitizeForCloning(part)) + } + }) + } } }) return persistedState @@ -321,15 +338,19 @@ export const createAiAssistantState = (): AiAssistantState => { const chat = state.activeChat if (!chat) return - const existingMessages = chat.messages - const messagesToAdd = Array.isArray(message) - ? message.filter( - (msg) => - !existingMessages.some((existing: AssistantMessageType) => existing.id === msg.id) - ) - : !existingMessages.some((existing: AssistantMessageType) => existing.id === message.id) - ? [message] - : [] + const incomingMessages = Array.isArray(message) ? message : [message] + + const messagesToAdd: AssistantMessageType[] = [] + + incomingMessages.forEach((msg) => { + const index = chat.messages.findIndex((existing) => existing.id === msg.id) + + if (index !== -1) { + state.updateMessage(msg) + } else { + messagesToAdd.push(msg as AssistantMessageType) + } + }) if (messagesToAdd.length > 0) { chat.messages.push(...messagesToAdd) @@ -337,26 +358,14 @@ export const createAiAssistantState = (): AiAssistantState => { } }, - updateMessage: ({ - id, - resultId, - results, - }: { - id: string - resultId?: string - results: any[] - }) => { + updateMessage: (updatedMessage: MessageType) => { const chat = state.activeChat - if (!chat || !resultId) return - - const messageIndex = chat.messages.findIndex((msg) => msg.id === id) + if (!chat) return + const messageIndex = chat.messages.findIndex((msg) => msg.id === updatedMessage.id) if (messageIndex !== -1) { - const msg = chat.messages[messageIndex] - if (!msg.results) { - msg.results = {} - } - msg.results[resultId] = results + chat.messages[messageIndex] = updatedMessage as AssistantMessageType + chat.updatedAt = new Date() } }, @@ -435,7 +444,7 @@ export type AiAssistantState = AiAssistantData & { clearMessages: () => void deleteMessagesAfter: (id: string, options?: { includeSelf?: boolean }) => void saveMessage: (message: MessageType | MessageType[]) => void - updateMessage: (args: { id: string; resultId?: string; results: any[] }) => void + updateMessage: (message: MessageType) => void setSqlSnippets: (snippets: SqlSnippet[]) => void clearSqlSnippets: () => void getCachedSQLResults: (args: { messageId: string; snippetId?: string }) => any[] | undefined