From d02270d5bbf8e123f4ad91c7cfffec39b68d8f56 Mon Sep 17 00:00:00 2001 From: Lim Hoang Date: Wed, 19 Jan 2022 17:25:27 +0000 Subject: [PATCH] [KED-3002] Live update runs list from subscription (#703) --- RELEASE.md | 1 + package-lock.json | 37 ++++++- package.json | 1 + package/kedro_viz/api/graphql.py | 43 ++++++-- .../data_access/repositories/runs.py | 18 +++- package/setup.cfg | 1 + package/test_requirements.txt | 1 + .../test_graphql/test_subscriptions.py | 102 ++++++++++++++++++ src/apollo/config.js | 38 ++++++- src/apollo/queries.js | 2 +- src/apollo/schema.graphql | 4 + src/apollo/subscriptions.js | 14 +++ src/apollo/utils.js | 7 +- .../experiment-wrapper/experiment-wrapper.js | 23 +++- 14 files changed, 275 insertions(+), 17 deletions(-) create mode 100644 package/tests/test_api/test_graphql/test_subscriptions.py create mode 100644 src/apollo/subscriptions.js diff --git a/RELEASE.md b/RELEASE.md index ed8c7ca582..0fe582d990 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,7 @@ Please follow the established format: ## Major features and improvements - Create the toggle-bookmark journey that allows bookmarking runs and displaying them as a separate list. (#689) +- Set up a subscription to automatically update the experiment runs list on new Kedro runs.
(#703) ## Bug fixes and other changes diff --git a/package-lock.json b/package-lock.json index 863f6214a1..2e66e55b60 100644 --- a/package-lock.json +++ b/package-lock.json @@ -5036,6 +5036,11 @@ "integrity": "sha512-q/UEjfGJ2Cm3oKV71DJz9d25TPnq5rhBVL2Q4fA5wcC3jcrdn7+SssEybFIxwAvvP+YCsCYNKughoF33GxgycQ==", "dev": true }, + "backo2": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/backo2/-/backo2-1.0.2.tgz", + "integrity": "sha1-MasayLEpNjRj41s+u2n038+6eUc=" + }, "bail": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/bail/-/bail-1.0.5.tgz", @@ -11463,6 +11468,11 @@ "istanbul-lib-report": "^3.0.0" } }, + "iterall": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/iterall/-/iterall-1.3.0.tgz", + "integrity": "sha512-QZ9qOMdF+QLHxy1QIpUHUU1D5pS2CG2P69LF6L6CPjPYA/XMOmKV3PZpawHoAjHNyB0swdVTRxdYT4tbBbxqwg==" + }, "jest": { "version": "26.6.0", "resolved": "https://registry.npmjs.org/jest/-/jest-26.6.0.tgz", @@ -21777,6 +21787,30 @@ "prettier-linter-helpers": "^1.0.0" } }, + "subscriptions-transport-ws": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/subscriptions-transport-ws/-/subscriptions-transport-ws-0.11.0.tgz", + "integrity": "sha512-8D4C6DIH5tGiAIpp5I0wD/xRlNiZAPGHygzCe7VzyzUoxHtawzjNAY9SUTXU05/EY2NMY9/9GF0ycizkXr1CWQ==", + "requires": { + "backo2": "^1.0.2", + "eventemitter3": "^3.1.0", + "iterall": "^1.2.1", + "symbol-observable": "^1.0.4", + "ws": "^5.2.0 || ^6.0.0 || ^7.0.0" + }, + "dependencies": { + "eventemitter3": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-3.1.2.tgz", + "integrity": "sha512-tvtQIeLVHjDkJYnzf2dgVMxfuSGJeM/7UCG17TT4EumTfNtF+0nebF/4zWOIkCreAbtNqhGEboB6BWrwqNaw4Q==" + }, + "symbol-observable": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/symbol-observable/-/symbol-observable-1.2.0.tgz", + "integrity": "sha512-e900nM8RRtGhlV36KGEU9k65K3mPb1WV70OdjfxlG2EAuM1noi/E/BaW/uMhL7bPEssK8QV57vN3esixjUvcXQ==" + } + 
} + }, "sugarss": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/sugarss/-/sugarss-2.0.0.tgz", @@ -24600,8 +24634,7 @@ "ws": { "version": "7.5.5", "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.5.tgz", - "integrity": "sha512-BAkMFcAzl8as1G/hArkxOxq3G7pjUqQ3gzYbLL0/5zNkph70e+lCoxBGnm6AW1+/aiNeV4fnKqZ8m4GZewmH2w==", - "dev": true + "integrity": "sha512-BAkMFcAzl8as1G/hArkxOxq3G7pjUqQ3gzYbLL0/5zNkph70e+lCoxBGnm6AW1+/aiNeV4fnKqZ8m4GZewmH2w==" }, "xml-name-validator": { "version": "3.0.0", diff --git a/package.json b/package.json index 53fcd459b8..06a9070b08 100644 --- a/package.json +++ b/package.json @@ -79,6 +79,7 @@ "reselect": "^4.1.5", "seedrandom": "^3.0.5", "sinon": "^12.0.1", + "subscriptions-transport-ws": "^0.11.0", "svg-crowbar": "^0.6.5", "what-input": "^5.2.10" }, diff --git a/package/kedro_viz/api/graphql.py b/package/kedro_viz/api/graphql.py index 8ed5475794..681a5df412 100644 --- a/package/kedro_viz/api/graphql.py +++ b/package/kedro_viz/api/graphql.py @@ -3,11 +3,21 @@ from __future__ import annotations +import asyncio import json import logging from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Dict, Iterable, List, NewType, Optional, cast +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + Dict, + Iterable, + List, + NewType, + Optional, + cast, +) import strawberry from fastapi import APIRouter @@ -115,7 +125,13 @@ def get_all_runs() -> List[Run]: Returns: list of Run objects """ - return format_runs(data_access_manager.runs.get_all_runs()) + all_runs = data_access_manager.runs.get_all_runs() + if not all_runs: + return [] + all_run_ids = [run.id for run in all_runs] + return format_runs( + all_runs, data_access_manager.runs.get_user_run_details_by_run_ids(all_run_ids) + ) def format_run_tracking_data( @@ -256,8 +272,21 @@ class Subscription: """Subscription object to track runs added in real time""" @strawberry.subscription - def run_added(self, run_id: ID) -> Run: - 
"""Subscription to add runs in real-time""" + async def runs_added(self) -> AsyncGenerator[List[Run], None]: + """Subscription to new runs in real-time""" + while True: + new_runs = data_access_manager.runs.get_new_runs() + if new_runs: + data_access_manager.runs.last_run_id = new_runs[0].id + yield [ + format_run( + run.id, + json.loads(run.blob), + data_access_manager.runs.get_user_run_details(run.id), + ) + for run in new_runs + ] + await asyncio.sleep(3) # pragma: no cover @strawberry.type @@ -348,8 +377,10 @@ def update_run_details(self, run_id: ID, run_input: RunInput) -> Response: return UpdateRunDetailsSuccess(updated_run) -schema = strawberry.Schema(query=Query, mutation=Mutation) +schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription) router = APIRouter() -router.add_route("/graphql", GraphQL(schema)) +graphql_app = GraphQL(schema) +router.add_route("/graphql", graphql_app) +router.add_websocket_route("/graphql", graphql_app) diff --git a/package/kedro_viz/data_access/repositories/runs.py b/package/kedro_viz/data_access/repositories/runs.py index ba688b6882..0876bf23c9 100644 --- a/package/kedro_viz/data_access/repositories/runs.py +++ b/package/kedro_viz/data_access/repositories/runs.py @@ -27,9 +27,11 @@ def func(self: "RunsRepository", *method_args, **method_kwargs): class RunsRepository: _db_session_class: Optional[sessionmaker] + last_run_id: Optional[str] def __init__(self, db_session_class: Optional[sessionmaker] = None): self._db_session_class = db_session_class + self.last_run_id = None def set_db_session(self, db_session_class: sessionmaker): """Sqlite db connection session""" @@ -42,12 +44,15 @@ def add_run(self, run: RunModel): @check_db_session def get_all_runs(self) -> Optional[Iterable[RunModel]]: - return ( + all_runs = ( self._db_session_class() # type: ignore .query(RunModel) .order_by(RunModel.id.desc()) .all() ) + if all_runs: + self.last_run_id = all_runs[0].id + return all_runs @check_db_session def 
get_run_by_id(self, run_id: str) -> Optional[RunModel]: @@ -71,6 +76,16 @@ def get_user_run_details(self, run_id: str) -> Optional[UserRunDetailsModel]: .first() ) + @check_db_session + def get_new_runs(self) -> Optional[Iterable[RunModel]]: + query = self._db_session_class().query(RunModel) # type: ignore + + if self.last_run_id: + # TODO: change this query to use timestamp once we have added that column + query = query.filter(RunModel.id > self.last_run_id) + + return query.order_by(RunModel.id.desc()).all() + @check_db_session def get_user_run_details_by_run_ids( self, run_ids: List[str] @@ -102,5 +117,4 @@ def create_or_update_user_run_details( user_run_details.title = title user_run_details.bookmark = bookmark user_run_details.notes = notes - return user_run_details diff --git a/package/setup.cfg b/package/setup.cfg index 1cef0d5db5..41b22260f8 100644 --- a/package/setup.cfg +++ b/package/setup.cfg @@ -9,3 +9,4 @@ addopts=--verbose -ra --ignore package/tests --no-cov-on-fail -ra + --asyncio-mode auto diff --git a/package/test_requirements.txt b/package/test_requirements.txt index 4db7e918bf..7dc9013555 100644 --- a/package/test_requirements.txt +++ b/package/test_requirements.txt @@ -10,6 +10,7 @@ mypy~=0.930 psutil==5.6.6 # same as Kedro for now pylint~=2.8.2 pytest~=6.2.0 +pytest-asyncio~=0.17.2 pytest-cov~=2.11.1 pytest-mock~=3.6.1 sqlalchemy-stubs~=0.4 diff --git a/package/tests/test_api/test_graphql/test_subscriptions.py b/package/tests/test_api/test_graphql/test_subscriptions.py new file mode 100644 index 0000000000..09193c2953 --- /dev/null +++ b/package/tests/test_api/test_graphql/test_subscriptions.py @@ -0,0 +1,102 @@ +import json + +from kedro_viz.api.graphql import schema +from kedro_viz.models.experiments_tracking import RunModel + + +class TestRunsAddedSubscription: + async def test_runs_added_subscription_with_no_existing_run( + self, data_access_manager_with_no_run + ): + query = """ + subscription { + runsAdded { + id + bookmark + gitSha + 
timestamp + title + } + } + """ + + # start subscription + subscription = await schema.subscribe(query) + + example_new_run_id = "test_id" + run = RunModel( + id=example_new_run_id, + blob=json.dumps( + { + "session_id": example_new_run_id, + "cli": {"command_path": "kedro run"}, + } + ), + ) + data_access_manager_with_no_run.runs.add_run(run) + + # assert subscription result + async for result in subscription: + assert not result.errors + assert result.data == { + "runsAdded": [ + { + "id": example_new_run_id, + "bookmark": False, + "gitSha": None, + "timestamp": example_new_run_id, + "title": example_new_run_id, + } + ] + } + break + + async def test_runs_added_subscription_with_existing_runs( + self, data_access_manager_with_runs + ): + query = """ + subscription { + runsAdded { + id + bookmark + gitSha + timestamp + title + } + } + """ + all_runs = data_access_manager_with_runs.runs.get_all_runs() + assert all_runs + + # start subscription + subscription = await schema.subscribe(query) + + # add a new run + example_new_run_id = "new_run" + run = RunModel( + id=example_new_run_id, + blob=json.dumps( + { + "session_id": example_new_run_id, + "cli": {"command_path": "kedro run"}, + } + ), + ) + data_access_manager_with_runs.runs.add_run(run) + + # assert subscription result + async for result in subscription: + assert not result.errors + assert result.data == { + "runsAdded": [ + { + "id": example_new_run_id, + "bookmark": False, + "gitSha": None, + "timestamp": example_new_run_id, + "title": example_new_run_id, + } + ] + } + assert data_access_manager_with_runs.runs.last_run_id == example_new_run_id + break diff --git a/src/apollo/config.js b/src/apollo/config.js index 90a94c4e32..9cbe73c32a 100644 --- a/src/apollo/config.js +++ b/src/apollo/config.js @@ -1,15 +1,47 @@ import fetch from 'cross-fetch'; -import { ApolloClient, InMemoryCache, createHttpLink } from '@apollo/client'; +import { + ApolloClient, + InMemoryCache, + createHttpLink, + split, +} from 
'@apollo/client'; +import { getMainDefinition } from '@apollo/client/utilities'; +import { WebSocketLink } from '@apollo/client/link/ws'; -const link = createHttpLink({ +const wsHost = + process.env.NODE_ENV === 'development' + ? 'localhost:4142' + : window.location.host; + +const wsLink = new WebSocketLink({ + uri: `ws://${wsHost}/graphql`, + options: { + reconnect: true, + }, +}); + +const httpLink = createHttpLink({ // our graphql endpoint, normally here: http://localhost:4141/graphql uri: '/graphql', fetch, }); +const splitLink = split( + ({ query }) => { + const definition = getMainDefinition(query); + + return ( + definition.kind === 'OperationDefinition' && + definition.operation === 'subscription' + ); + }, + wsLink, + httpLink +); + export const client = new ApolloClient({ connectToDevTools: true, - link, + link: splitLink, cache: new InMemoryCache(), defaultOptions: { query: { diff --git a/src/apollo/queries.js b/src/apollo/queries.js index f59229838c..850b0918e8 100644 --- a/src/apollo/queries.js +++ b/src/apollo/queries.js @@ -17,11 +17,11 @@ export const GET_RUNS = gql` export const GET_RUN_METADATA = gql` query getRunMetadata($runIds: [ID!]!) { runMetadata(runIds: $runIds) { + id author bookmark gitBranch gitSha - id notes runCommand timestamp diff --git a/src/apollo/schema.graphql b/src/apollo/schema.graphql index 9397a7c916..d0472778e4 100644 --- a/src/apollo/schema.graphql +++ b/src/apollo/schema.graphql @@ -29,6 +29,10 @@ input RunInput { notes: String = null } +type Subscription { + runsAdded: [Run!]! 
+} + type TrackingDataset { datasetName: String datasetType: String diff --git a/src/apollo/subscriptions.js b/src/apollo/subscriptions.js new file mode 100644 index 0000000000..7948036a11 --- /dev/null +++ b/src/apollo/subscriptions.js @@ -0,0 +1,14 @@ +import gql from 'graphql-tag'; + +/** subscribe to receive new runs */ +export const NEW_RUN_SUBSCRIPTION = gql` + subscription { + runsAdded { + id + bookmark + gitSha + timestamp + title + } + } +`; diff --git a/src/apollo/utils.js b/src/apollo/utils.js index b9cb006948..fedb3804f4 100644 --- a/src/apollo/utils.js +++ b/src/apollo/utils.js @@ -8,7 +8,10 @@ import { useQuery } from '@apollo/client'; */ export const useApolloQuery = (query, options) => { const [data, setData] = useState(undefined); - const { data: queryData, error, loading } = useQuery(query, options); + const { subscribeToMore, data: queryData, error, loading } = useQuery( + query, + options + ); useEffect(() => { if (queryData !== undefined) { @@ -16,5 +19,5 @@ export const useApolloQuery = (query, options) => { } }, [queryData]); - return { data, error, loading }; + return { subscribeToMore, data, error, loading }; }; diff --git a/src/components/experiment-wrapper/experiment-wrapper.js b/src/components/experiment-wrapper/experiment-wrapper.js index 296597b5eb..9681962f9a 100644 --- a/src/components/experiment-wrapper/experiment-wrapper.js +++ b/src/components/experiment-wrapper/experiment-wrapper.js @@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react'; import { useApolloQuery } from '../../apollo/utils'; import { connect } from 'react-redux'; import { GET_RUNS } from '../../apollo/queries'; +import { NEW_RUN_SUBSCRIPTION } from '../../apollo/subscriptions'; import { sortRunByTime } from '../../utils/date-utils'; import Button from '@quantumblack/kedro-ui/lib/components/button'; import Details from '../experiment-tracking/details'; @@ -21,7 +22,7 @@ const ExperimentWrapper = ({ theme }) => { const [selectedRunData, 
setSelectedRunData] = useState(null); const [showRunDetailsModal, setShowRunDetailsModal] = useState(false); - const { data, loading } = useApolloQuery(GET_RUNS); + const { subscribeToMore, data, loading } = useApolloQuery(GET_RUNS); const onRunSelection = (id) => { if (enableComparisonView) { @@ -104,6 +105,26 @@ const ExperimentWrapper = ({ theme }) => { } }, [selectedRunIds, pinnedRun]); + useEffect(() => { + if (!data?.runsList || data.runsList.length === 0) { + return; + } + + subscribeToMore({ + document: NEW_RUN_SUBSCRIPTION, + updateQuery: (prev, { subscriptionData }) => { + if (!subscriptionData.data || !prev?.runsList) { + return prev; + } + const newRuns = subscriptionData.data.runsAdded; + + return Object.assign({}, prev, { + runsList: [...newRuns, ...prev.runsList], + }); + }, + }); + }, [data, subscribeToMore]); + if (loading) { return (