Skip to content

Commit

Permalink
[KED-3002] Live update runs list from subscription (#703)
Browse files Browse the repository at this point in the history
  • Loading branch information
limdauto authored Jan 19, 2022
1 parent f49d174 commit d02270d
Show file tree
Hide file tree
Showing 14 changed files with 275 additions and 17 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Please follow the established format:
## Major features and improvements

- Create the toggle-bookmark journey that allows bookmarking runs and displaying them as a separate list. (#689)
- Set up a subscription to automatically update the experiment runs list on new Kedro runs. (#703)

## Bug fixes and other changes

Expand Down
37 changes: 35 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"reselect": "^4.1.5",
"seedrandom": "^3.0.5",
"sinon": "^12.0.1",
"subscriptions-transport-ws": "^0.11.0",
"svg-crowbar": "^0.6.5",
"what-input": "^5.2.10"
},
Expand Down
43 changes: 37 additions & 6 deletions package/kedro_viz/api/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,21 @@

from __future__ import annotations

import asyncio
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterable, List, NewType, Optional, cast
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Dict,
Iterable,
List,
NewType,
Optional,
cast,
)

import strawberry
from fastapi import APIRouter
Expand Down Expand Up @@ -115,7 +125,13 @@ def get_all_runs() -> List[Run]:
Returns:
list of Run objects
"""
return format_runs(data_access_manager.runs.get_all_runs())
all_runs = data_access_manager.runs.get_all_runs()
if not all_runs:
return []
all_run_ids = [run.id for run in all_runs]
return format_runs(
all_runs, data_access_manager.runs.get_user_run_details_by_run_ids(all_run_ids)
)


def format_run_tracking_data(
Expand Down Expand Up @@ -256,8 +272,21 @@ class Subscription:
"""Subscription object to track runs added in real time"""

@strawberry.subscription
def run_added(self, run_id: ID) -> Run:
"""Subscription to add runs in real-time"""
async def runs_added(self) -> AsyncGenerator[List[Run], None]:
    """Stream batches of runs that show up while the subscription is open.

    Polls the runs repository in an endless loop. Whenever the repository
    reports new runs, the id of the first returned run is stored as the
    repository's ``last_run_id`` and the runs are yielded as one list, each
    converted with ``format_run``. The loop then sleeps for three seconds
    before polling again.

    Yields:
        A list with one formatted entry per newly found run.
    """
    while True:
        runs_repo = data_access_manager.runs
        fresh_runs = runs_repo.get_new_runs()
        if fresh_runs:
            # Move the repository cursor forward before handing the batch to
            # the subscriber, so the next poll only asks for later runs.
            # NOTE(review): last_run_id lives on the shared repository, so
            # concurrent subscribers presumably share this cursor -- confirm.
            runs_repo.last_run_id = fresh_runs[0].id
            batch = []
            for new_run in fresh_runs:
                payload = json.loads(new_run.blob)
                details = runs_repo.get_user_run_details(new_run.id)
                batch.append(format_run(new_run.id, payload, details))
            yield batch
        await asyncio.sleep(3)  # pragma: no cover


@strawberry.type
Expand Down Expand Up @@ -348,8 +377,10 @@ def update_run_details(self, run_id: ID, run_input: RunInput) -> Response:
return UpdateRunDetailsSuccess(updated_run)


schema = strawberry.Schema(query=Query, mutation=Mutation)
schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription)

router = APIRouter()

router.add_route("/graphql", GraphQL(schema))
graphql_app = GraphQL(schema)
router.add_route("/graphql", graphql_app)
router.add_websocket_route("/graphql", graphql_app)
18 changes: 16 additions & 2 deletions package/kedro_viz/data_access/repositories/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ def func(self: "RunsRepository", *method_args, **method_kwargs):

class RunsRepository:
_db_session_class: Optional[sessionmaker]
last_run_id: Optional[str]

def __init__(self, db_session_class: Optional[sessionmaker] = None):
self._db_session_class = db_session_class
self.last_run_id = None

def set_db_session(self, db_session_class: sessionmaker):
"""Sqlite db connection session"""
Expand All @@ -42,12 +44,15 @@ def add_run(self, run: RunModel):

@check_db_session
def get_all_runs(self) -> Optional[Iterable[RunModel]]:
return (
all_runs = (
self._db_session_class() # type: ignore
.query(RunModel)
.order_by(RunModel.id.desc())
.all()
)
if all_runs:
self.last_run_id = all_runs[0].id
return all_runs

@check_db_session
def get_run_by_id(self, run_id: str) -> Optional[RunModel]:
Expand All @@ -71,6 +76,16 @@ def get_user_run_details(self, run_id: str) -> Optional[UserRunDetailsModel]:
.first()
)

@check_db_session
def get_new_runs(self) -> Optional[Iterable[RunModel]]:
    """Query the runs whose id is greater than ``last_run_id``.

    When ``last_run_id`` is unset (or otherwise falsy) no filter is applied
    and every run is returned.

    Returns:
        The matching runs, ordered by descending id.
    """
    session = self._db_session_class()  # type: ignore
    newest_first = RunModel.id.desc()

    if not self.last_run_id:
        return session.query(RunModel).order_by(newest_first).all()

    # TODO: change this query to use timestamp once we have added that column
    return (
        session.query(RunModel)
        .filter(RunModel.id > self.last_run_id)
        .order_by(newest_first)
        .all()
    )

@check_db_session
def get_user_run_details_by_run_ids(
self, run_ids: List[str]
Expand Down Expand Up @@ -102,5 +117,4 @@ def create_or_update_user_run_details(
user_run_details.title = title
user_run_details.bookmark = bookmark
user_run_details.notes = notes

return user_run_details
1 change: 1 addition & 0 deletions package/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ addopts=--verbose -ra
--ignore package/tests
--no-cov-on-fail
-ra
--asyncio-mode auto
1 change: 1 addition & 0 deletions package/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mypy~=0.930
psutil==5.6.6 # same as Kedro for now
pylint~=2.8.2
pytest~=6.2.0
pytest-asyncio~=0.17.2
pytest-cov~=2.11.1
pytest-mock~=3.6.1
sqlalchemy-stubs~=0.4
Expand Down
102 changes: 102 additions & 0 deletions package/tests/test_api/test_graphql/test_subscriptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json

from kedro_viz.api.graphql import schema
from kedro_viz.models.experiments_tracking import RunModel


class TestRunsAddedSubscription:
    """Checks for the ``runsAdded`` GraphQL subscription."""

    # Subscription document shared by both tests.
    RUNS_ADDED_QUERY = """
subscription {
runsAdded {
id
bookmark
gitSha
timestamp
title
}
}
"""

    @staticmethod
    def _build_run(run_id):
        """Create a ``RunModel`` whose blob holds the session id and a
        ``kedro run`` command path."""
        blob = {
            "session_id": run_id,
            "cli": {"command_path": "kedro run"},
        }
        return RunModel(id=run_id, blob=json.dumps(blob))

    @staticmethod
    def _expected_payload(run_id):
        """Return the subscription data expected for a single new run."""
        return {
            "runsAdded": [
                {
                    "id": run_id,
                    "bookmark": False,
                    "gitSha": None,
                    "timestamp": run_id,
                    "title": run_id,
                }
            ]
        }

    async def test_runs_added_subscription_with_no_existing_run(
        self, data_access_manager_with_no_run
    ):
        # start subscription
        subscription = await schema.subscribe(self.RUNS_ADDED_QUERY)

        new_run_id = "test_id"
        data_access_manager_with_no_run.runs.add_run(self._build_run(new_run_id))

        # only the first message matters; stop listening after checking it
        async for result in subscription:
            assert not result.errors
            assert result.data == self._expected_payload(new_run_id)
            break

    async def test_runs_added_subscription_with_existing_runs(
        self, data_access_manager_with_runs
    ):
        existing_runs = data_access_manager_with_runs.runs.get_all_runs()
        assert existing_runs

        # start subscription
        subscription = await schema.subscribe(self.RUNS_ADDED_QUERY)

        # add a new run
        new_run_id = "new_run"
        data_access_manager_with_runs.runs.add_run(self._build_run(new_run_id))

        # only the first message matters; stop listening after checking it
        async for result in subscription:
            assert not result.errors
            assert result.data == self._expected_payload(new_run_id)
            assert data_access_manager_with_runs.runs.last_run_id == new_run_id
            break
38 changes: 35 additions & 3 deletions src/apollo/config.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,47 @@
import fetch from 'cross-fetch';
import { ApolloClient, InMemoryCache, createHttpLink } from '@apollo/client';
import {
ApolloClient,
InMemoryCache,
createHttpLink,
split,
} from '@apollo/client';
import { getMainDefinition } from '@apollo/client/utilities';
import { WebSocketLink } from '@apollo/client/link/ws';

const link = createHttpLink({
// Host for the websocket endpoint: a fixed local address in development,
// otherwise the host that served the page.
const wsHost =
  process.env.NODE_ENV === 'development'
    ? 'localhost:4142'
    : window.location.host;

// Browsers refuse insecure ws:// connections from pages loaded over https,
// so mirror the page's scheme instead of always using ws://.
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws';

// Websocket transport used for GraphQL subscriptions; reconnects
// automatically when the connection drops.
const wsLink = new WebSocketLink({
  uri: `${wsProtocol}://${wsHost}/graphql`,
  options: {
    reconnect: true,
  },
});

const httpLink = createHttpLink({
// our graphql endpoint, normally here: http://localhost:4141/graphql
uri: '/graphql',
fetch,
});

// True when the operation's main definition is a GraphQL subscription.
const isSubscriptionOperation = ({ query }) => {
  const { kind, operation } = getMainDefinition(query);

  return kind === 'OperationDefinition' && operation === 'subscription';
};

// Route subscriptions over the websocket link and everything else over HTTP.
const splitLink = split(isSubscriptionOperation, wsLink, httpLink);

export const client = new ApolloClient({
connectToDevTools: true,
link,
link: splitLink,
cache: new InMemoryCache(),
defaultOptions: {
query: {
Expand Down
2 changes: 1 addition & 1 deletion src/apollo/queries.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ export const GET_RUNS = gql`
export const GET_RUN_METADATA = gql`
query getRunMetadata($runIds: [ID!]!) {
runMetadata(runIds: $runIds) {
id
author
bookmark
gitBranch
gitSha
id
notes
runCommand
timestamp
Expand Down
4 changes: 4 additions & 0 deletions src/apollo/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ input RunInput {
notes: String = null
}

type Subscription {
runsAdded: [Run!]!
}

type TrackingDataset {
datasetName: String
datasetType: String
Expand Down
Loading

0 comments on commit d02270d

Please sign in to comment.