Merge pull request #42 from VectorInstitute/bugfix_safety
Improve safety to handle cases when data is missing; a bug remains
amrit110 authored Oct 2, 2024
2 parents f9698dd + 3812dda commit d8f2205
Showing 8 changed files with 146 additions and 42 deletions.
66 changes: 62 additions & 4 deletions backend/api/models/data.py
@@ -3,11 +3,12 @@
import uuid
from datetime import datetime
from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field, validator

from api.models.config import EndpointConfig
+from api.models.constants import METRIC_DISPLAY_NAMES
from api.models.utils import deep_convert_numpy


@@ -486,15 +487,72 @@ class ModelSafety(BaseModel):
----------
metrics : List[Metric]
A list of individual metrics and their current status.
-    last_evaluated : datetime
-        A timestamp of when the model was last evaluated.
+    last_evaluated : Optional[str]
+        A timestamp of when the model was evaluated in ISO 8601 format.
is_recently_evaluated : bool
Whether the model was recently evaluated.
overall_status : str
The overall status of the model ('No warnings' or 'Warning').
"""

metrics: List[Metric]
-    last_evaluated: str = Field(..., description="ISO 8601 formatted date string")
+    last_evaluated: Optional[str] = Field(
+        ..., description="ISO 8601 formatted date string"
+    )
is_recently_evaluated: bool
overall_status: str


+def _default_data(
+    metric_name: str,
+) -> Tuple[ModelFacts, EvaluationCriterion, EvaluationFrequency]:
+    """Create default data for the model.
+    Parameters
+    ----------
+    metric_name : str
+        The name of the metric to create default data for.
+    Returns
+    -------
+    Tuple[ModelFacts, EvaluationCriterion, EvaluationFrequency]
+        The default data for the model.
+    """
+    model_facts = ModelFacts(
+        name="This is the name of the model",
+        version="This is the version of the model",
+        type="This is the type of the model",
+        intended_use="This is the intended use of the model",
+        target_population="This is the target population of the model",
+        input_data=["This is the input data of the model"],
+        output_data="This is the output data of the model",
+        summary="This is the summary of the model",
+        mechanism_of_action="This is the mechanism of action of the model",
+        validation_and_performance=ValidationAndPerformance(
+            internal_validation="This is the internal validation of the model",
+            external_validation="This is the external validation of the model",
+            performance_in_subgroups=[
+                "This is the performance in subgroups of the model"
+            ],
+        ),
+        uses_and_directions=["This is the uses and directions of the model"],
+        warnings=["This is the warnings of the model"],
+        other_information=OtherInformation(
+            approval_date="This is the approval date of the model",
+            license="This is the license of the model",
+            contact_information="This is the contact information of the model",
+            publication_link="This is the publication link of the model",
+        ),
+    )
+    evaluation_criterion = EvaluationCriterion(
+        metric_name=metric_name,
+        display_name=METRIC_DISPLAY_NAMES[metric_name],
+        operator=ComparisonOperator.GREATER_THAN_OR_EQUAL_TO,
+        threshold=0.5,
+    )
+    evaluation_frequency = EvaluationFrequency(
+        value=7,
+        unit="days",
+    )
+    return model_facts, evaluation_criterion, evaluation_frequency
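
A note on the `last_evaluated` change above: in Pydantic, annotating a field as `Optional[str]` while keeping `Field(...)` as its default makes the field nullable but still required, so callers must pass it explicitly even when its value is `None`. A minimal sketch of that behavior (the `SafetyStub` model is hypothetical, not part of this repo):

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class SafetyStub(BaseModel):
    """Hypothetical stand-in for ModelSafety's changed field."""

    last_evaluated: Optional[str] = Field(
        ..., description="ISO 8601 formatted date string"
    )


SafetyStub(last_evaluated=None)          # valid: None is now an accepted value
SafetyStub(last_evaluated="2024-10-02")  # valid: an ISO 8601 date string

try:
    SafetyStub()  # still invalid: the field itself remains required
except ValidationError as exc:
    print(exc)
```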
8 changes: 8 additions & 0 deletions backend/api/models/evaluate.py
@@ -21,6 +21,7 @@
EvaluationResult,
ModelBasicInfo,
ModelData,
+    _default_data,
)
from api.models.db import DATA_DIR, load_model_data, save_model_data
from api.models.utils import deep_convert_numpy
@@ -194,11 +195,18 @@ def add_model(self, model_info: ModelBasicInfo) -> str:
The unique ID of the newly added model.
"""
model_id = str(uuid.uuid4())
+        metric_name = f"{self.config.metrics[0].type}_{self.config.metrics[0].name}"
+        model_facts, evaluation_criterion, evaluation_frequency = _default_data(
+            metric_name
+        )
model_data = ModelData(
id=model_id,
endpoint_name=self.name,
basic_info=model_info,
endpoints=[self.name],
+            facts=model_facts,
+            evaluation_criterion=evaluation_criterion,
+            evaluation_frequency=evaluation_frequency,
)
save_model_data(model_id, model_data)
self.data.models.append(model_id)
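
For reference, the `metric_name` built in `add_model` follows a `"{type}_{name}"` convention and is used directly as a key into `METRIC_DISPLAY_NAMES` inside `_default_data`, so a name with no display entry fails fast with a `KeyError`. A small sketch under assumed example values (the `binary_accuracy` entry is illustrative, not taken from the repo):

```python
# Illustrative only: the real METRIC_DISPLAY_NAMES lives in
# api.models.constants; this assumed entry just mirrors its shape.
METRIC_DISPLAY_NAMES = {"binary_accuracy": "Accuracy"}

metric = {"type": "binary", "name": "accuracy"}  # assumed config values
metric_name = f"{metric['type']}_{metric['name']}"

# _default_data indexes the dict directly, so a missing key surfaces
# as a KeyError rather than a validation error.
print(METRIC_DISPLAY_NAMES[metric_name])  # -> "Accuracy"
```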
8 changes: 6 additions & 2 deletions backend/api/models/safety.py
@@ -94,7 +94,12 @@ async def get_model_safety(model_id: str) -> ModelSafety:  # noqa: PLR0912
"collection", []
)
if not collection:
-        raise ValueError("No metrics found in performance data")
+        return ModelSafety(
+            metrics=[],
+            last_evaluated=None,
+            overall_status="Not evaluated",
+            is_recently_evaluated=False,
+        )
current_date: datetime = datetime.now()
last_evaluated = datetime.fromisoformat(collection[0]["timestamps"][-1])

@@ -131,7 +136,6 @@ async def get_model_safety(model_id: str) -> ModelSafety:  # noqa: PLR0912
passed=status == "met",
)
)
-
all_criteria_met = all(metric.status == "met" for metric in metrics)
evaluation_frequency = model_data.evaluation_frequency or EvaluationFrequency(
value=30, unit="days"
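
For context, the 30-day `EvaluationFrequency` fallback above feeds the recency check behind `is_recently_evaluated`. A sketch of that kind of check, assuming ISO 8601 timestamps as stored in the performance data (this helper is illustrative, not the repo's actual implementation):

```python
from datetime import datetime, timedelta


def recently_evaluated(
    last_evaluated_iso: str, value: int = 30, unit: str = "days"
) -> bool:
    """Assumed recency rule: evaluated within one frequency window."""
    last_evaluated = datetime.fromisoformat(last_evaluated_iso)
    window = timedelta(**{unit: value})  # e.g. timedelta(days=30)
    return datetime.now() - last_evaluated <= window


print(recently_evaluated("2024-10-01T12:00:00"))
```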
2 changes: 1 addition & 1 deletion backend/tests/test_app.py
@@ -10,7 +10,7 @@
from api.models.data import ModelFacts


BASE_URL = "http://localhost:8000" # Adjust this to your API's base URL
BASE_URL = "http://localhost:8001" # Adjust this to your API's base URL


def api_request(method: str, endpoint: str, data: Dict = None) -> Dict:
21 changes: 3 additions & 18 deletions frontend/src/app/context/model.tsx
@@ -3,24 +3,10 @@
import React, { createContext, useState, useContext, ReactNode, useCallback, useMemo, useEffect } from 'react';
import { ModelFacts } from '../types/facts';
import { Criterion, EvaluationFrequency } from '../types/evaluation-criteria';
+import { ModelData } from '../types/model';
import { useAuth } from './auth';
import { debounce, DebouncedFunc } from 'lodash';

-interface ModelBasicInfo {
-  name: string;
-  version: string;
-}
-
-interface ModelData {
-  id: string;
-  endpoints: string[];
-  basic_info: ModelBasicInfo;
-  facts: ModelFacts | null;
-  evaluation_criteria: Criterion[];
-  evaluation_frequency: EvaluationFrequency | null;
-  overall_status: string;
-}

interface ModelContextType {
models: ModelData[];
fetchModels: () => Promise<void>;
@@ -77,7 +63,7 @@ export const ModelProvider: React.FC<{ children: ReactNode }> = ({ children }) =
id,
...modelInfo,
overall_status: safetyData.overall_status
-        };
+        } as ModelData;
}));
setModels(modelArray);
} catch (error) {
@@ -101,13 +87,12 @@ export const ModelProvider: React.FC<{ children: ReactNode }> = ({ children }) =
}

const [modelData, safetyData, factsData] = await Promise.all([
apiRequest<any>(`/api/models/${id}`),
apiRequest<ModelData>(`/api/models/${id}`),
apiRequest<{ overall_status: string }>(`/api/model/${id}/safety`),
apiRequest<ModelFacts>(`/api/models/${id}/facts`)
]);

const newModel: ModelData = {
id,
...modelData,
overall_status: safetyData.overall_status,
facts: factsData,
57 changes: 42 additions & 15 deletions frontend/src/app/model/[id]/tabs/safety.tsx
@@ -18,6 +18,8 @@ import {
HStack,
Skeleton,
SkeletonText,
+  Alert,
+  AlertIcon,
} from '@chakra-ui/react';
import { CheckCircleIcon, WarningIcon, InfoIcon } from '@chakra-ui/icons';
import { formatDistanceToNow, parseISO } from 'date-fns';
@@ -59,11 +61,28 @@ const ModelSafetyTab: React.FC<ModelSafetyTabProps> = ({ modelId }) => {
fetchModelSafety();
}, [fetchModelSafety]);

+  if (error) {
+    return (
+      <Alert status="error">
+        <AlertIcon />
+        {error}
+      </Alert>
+    );
+  }
+
+  const isNotEvaluated = safetyData?.overall_status === 'Not evaluated';
+
return (
<Box p={4}>
<Heading as="h2" size="xl" mb={6} color={textColor}>
Model Safety Dashboard
</Heading>
{isNotEvaluated && !isLoading && (
<Alert status="info" mb={6}>
<AlertIcon />
This model has not been evaluated yet.
</Alert>
)}
<SimpleGrid columns={{ base: 1, lg: 2 }} spacing={8}>
<VStack spacing={4}>
<SafetyStatusCard
@@ -108,10 +127,12 @@ interface SafetyStatusCardProps extends CardProps {
const SafetyStatusCard: React.FC<SafetyStatusCardProps> = ({ overallStatus, cardBgColor, borderColor, textColor, isLoading }) => {
const tooltipLabel = overallStatus === 'No warnings'
? "All safety criteria have been met and the model has been recently evaluated."
+    : overallStatus === 'Not evaluated'
+      ? "The model has not been evaluated yet."
: "One or more safety criteria have not been met or the model needs re-evaluation. Check the Safety Evaluation Checklist for details.";

-  const statusColor = overallStatus === 'No warnings' ? 'green' : 'red';
-  const StatusIcon = overallStatus === 'No warnings' ? CheckCircleIcon : WarningIcon;
+  const statusColor = overallStatus === 'No warnings' ? 'green' : overallStatus === 'Not evaluated' ? 'gray' : 'red';
+  const StatusIcon = overallStatus === 'No warnings' ? CheckCircleIcon : overallStatus === 'Not evaluated' ? InfoIcon : WarningIcon;

return (
<Box bg={cardBgColor} p={6} borderRadius="lg" boxShadow="md" borderColor={borderColor} borderWidth={1} width="100%">
@@ -138,6 +159,8 @@ interface LastEvaluatedCardProps extends CardProps {
const LastEvaluatedCard: React.FC<LastEvaluatedCardProps> = ({ lastEvaluated, isRecentlyEvaluated, cardBgColor, borderColor, textColor, isLoading }) => {
const tooltipLabel = isRecentlyEvaluated
? "The model has been evaluated within the specified evaluation frequency threshold."
+    : lastEvaluated === null
+      ? "The model has not been evaluated yet."
: "The model has not been evaluated recently and may need re-evaluation.";

return (
@@ -146,16 +169,16 @@
<Stat>
<StatLabel>Time since last evaluation</StatLabel>
<Skeleton isLoaded={!isLoading}>
<StatNumber>{lastEvaluated ? formatDistanceToNow(lastEvaluated) + ' ago' : 'N/A'}</StatNumber>
<StatNumber>{lastEvaluated ? formatDistanceToNow(lastEvaluated) + ' ago' : 'Not evaluated'}</StatNumber>
</Skeleton>
<Skeleton isLoaded={!isLoading}>
{isRecentlyEvaluated !== undefined && (
<Tooltip label={tooltipLabel} placement="top" hasArrow>
<Flex align="center" mt={2} cursor="help">
<Badge colorScheme={isRecentlyEvaluated ? 'green' : 'red'} mr={2}>
{isRecentlyEvaluated ? 'Recent' : 'Needs Re-evaluation'}
<Badge colorScheme={lastEvaluated === null ? 'gray' : isRecentlyEvaluated ? 'green' : 'red'} mr={2}>
{lastEvaluated === null ? 'Not Evaluated' : isRecentlyEvaluated ? 'Recent' : 'Needs Re-evaluation'}
</Badge>
-              {isRecentlyEvaluated ? <CheckCircleIcon color="green.500" /> : <WarningIcon color="red.500" />}
+              {lastEvaluated === null ? <InfoIcon color="gray.500" /> : isRecentlyEvaluated ? <CheckCircleIcon color="green.500" /> : <WarningIcon color="red.500" />}
</Flex>
</Tooltip>
)}
@@ -172,15 +195,19 @@ interface SafetyMetricsCardProps extends CardProps {
const SafetyMetricsCard: React.FC<SafetyMetricsCardProps> = ({ metrics, cardBgColor, borderColor, textColor, isLoading }) => (
<Box bg={cardBgColor} p={6} borderRadius="lg" boxShadow="md" borderColor={borderColor} borderWidth={1}>
<Heading as="h3" size="md" mb={4} color={textColor}>Evaluation Checklist</Heading>
<List spacing={3}>
{isLoading ? (
Array.from({ length: 3 }).map((_, index) => (
{isLoading ? (
<List spacing={3}>
{Array.from({ length: 3 }).map((_, index) => (
<ListItem key={index}>
<SkeletonText noOfLines={1} spacing="4" />
</ListItem>
-        ))
-      ) : (
-        metrics.map((metric, index) => (
+        ))}
+      </List>
+    ) : metrics.length === 0 ? (
+      <Text color={textColor}>No metrics available. The model has not been evaluated yet.</Text>
+    ) : (
+      <List spacing={3}>
+        {metrics.map((metric, index) => (
<ListItem key={index}>
<HStack spacing={2} align="center">
<ListIcon
Expand All @@ -199,9 +226,9 @@ const SafetyMetricsCard: React.FC<SafetyMetricsCardProps> = ({ metrics, cardBgCo
</Tooltip>
</HStack>
</ListItem>
-        ))
-      )}
-    </List>
+        ))}
+      </List>
+    )}
</Box>
);

22 changes: 22 additions & 0 deletions frontend/src/app/types/model.ts
@@ -0,0 +1,22 @@
import { z } from 'zod';

import { ModelFacts } from './facts';
import { CriterionSchema, EvaluationFrequencySchema } from './evaluation-criteria';

export const ModelBasicInfoSchema = z.object({
name: z.string(),
version: z.string(),
});

export const ModelDataSchema = z.object({
id: z.string(),
endpoints: z.array(z.string()),
basic_info: ModelBasicInfoSchema,
facts: z.custom<ModelFacts>().nullable(),
evaluation_criteria: z.array(CriterionSchema),
evaluation_frequency: EvaluationFrequencySchema.nullable(),
overall_status: z.string(),
});

export type ModelBasicInfo = z.infer<typeof ModelBasicInfoSchema>;
export type ModelData = z.infer<typeof ModelDataSchema>;
4 changes: 2 additions & 2 deletions frontend/src/app/types/safety.ts
@@ -2,8 +2,8 @@ import { z } from 'zod';
import { MetricSchema } from './performance-metrics';

export const ModelSafetySchema = z.object({
-  metrics: z.array(MetricSchema),
-  last_evaluated: z.string(),
+  metrics: z.array(MetricSchema).optional().default([]),
+  last_evaluated: z.string().nullable(),
is_recently_evaluated: z.boolean(),
overall_status: z.string()
});
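
The `.nullable()` change mirrors what the backend now emits: a Pydantic model whose `last_evaluated` is `None` serializes to JSON `null`, which a bare `z.string()` would reject at parse time. A quick sketch with a stripped-down stand-in model (hypothetical, using the Pydantic v1 `.json()` API, consistent with the v1-style `validator` import in `data.py`):

```python
from typing import List, Optional

from pydantic import BaseModel, Field


class ModelSafetyStub(BaseModel):
    """Stripped-down stand-in for ModelSafety (metric typing elided)."""

    metrics: List[dict] = []
    last_evaluated: Optional[str] = Field(...)
    is_recently_evaluated: bool
    overall_status: str


payload = ModelSafetyStub(
    metrics=[],
    last_evaluated=None,
    is_recently_evaluated=False,
    overall_status="Not evaluated",
)
print(payload.json())
# {"metrics": [], "last_evaluated": null, "is_recently_evaluated": false,
#  "overall_status": "Not evaluated"}
```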
