diff --git a/src/Redux/actionCreators/SDKActions.js b/src/Redux/actionCreators/SDKActions.js index 57fccb38..e769a5d5 100644 --- a/src/Redux/actionCreators/SDKActions.js +++ b/src/Redux/actionCreators/SDKActions.js @@ -1,5 +1,5 @@ import { isEmpty } from "lodash"; -import { initSdk } from "../../utility/sdk"; +import { getWeb3Address, initSdk } from "../../utility/sdk"; export const SET_SDK = "SET_SDK"; export const SET_SERVICE_CLIENT = "SET_SERVICE_CLIENT"; @@ -24,10 +24,13 @@ export const initializingSdk = (ethereumWalletAddress) => async (dispatch) => { const initializeServiceClient = (organizationId, serviceId) => async (dispatch) => { const sdk = await dispatch(getSdk()); - return await sdk.createServiceClient(organizationId, serviceId); + const serviceClient = await sdk.createServiceClient(organizationId, serviceId); + // dispatch(updateServiceClient(serviceClient)) + return serviceClient; }; export const getSdk = () => async (dispatch, getState) => { + await getWeb3Address(); let sdk = getState().sdkReducer.sdk; if (!isEmpty(sdk)) { return sdk; diff --git a/src/Redux/actionCreators/ServiceTrainingActions.js b/src/Redux/actionCreators/ServiceTrainingActions.js index 1bc8082c..9644a735 100644 --- a/src/Redux/actionCreators/ServiceTrainingActions.js +++ b/src/Redux/actionCreators/ServiceTrainingActions.js @@ -2,7 +2,8 @@ import axios from "axios"; import { LoaderContent } from "../../utility/constants/LoaderContent"; import { startAppLoader, stopAppLoader } from "./LoaderActions"; import { getServiceClient } from "./SDKActions"; - +import { updateMetamaskWallet } from "./UserActions"; +import { modelStatus } from "../reducers/ServiceTrainingReducer"; export const SET_MODEL_DETAILS = "SET_MODEL_DETAILS"; export const SET_MODELS_LIST = "SET_MODELS_LIST"; export const RESET_MODEL_DETAILS = "RESET_MODEL_DETAILS"; @@ -24,9 +25,10 @@ export const resetModelList = () => (dispatch) => { dispatch({ type: RESET_MODEL_LIST }); }; -export const createModel = (organizationId, serviceId, address, newModelParams) => async (dispatch) => { +export const createModel = (organizationId, serviceId, newModelParams) => async (dispatch) => { try { dispatch(startAppLoader(LoaderContent.CREATE_TRAINING_MODEL)); + const address = await dispatch(updateMetamaskWallet()); const serviceName = getServiceNameFromTrainingMethod(newModelParams?.trainingMethod); const params = { modelName: newModelParams?.trainingModelName, @@ -40,7 +42,6 @@ export const createModel = (organizationId, serviceId, address, newModelParams) const serviceClient = await dispatch(getServiceClient(organizationId, serviceId)); const createdModelData = await serviceClient.createModel(address, params); - console.log("createdModelData: ", createdModelData); dispatch(setCurrentModelDetails(createdModelData)); await dispatch(getTrainingModels(organizationId, serviceId, address)); @@ -51,10 +52,11 @@ export const createModel = (organizationId, serviceId, address, newModelParams) } }; -export const updateModel = (organizationId, serviceId, address, updateModelParams) => async (dispatch, getState) => { +export const updateModel = (organizationId, serviceId, updateModelParams) => async (dispatch, getState) => { const currentModelDetails = getState().serviceTrainingReducer.currentModel; try { dispatch(startAppLoader(LoaderContent.UPDATE_MODEL)); + const address = await dispatch(updateMetamaskWallet()); const params = { modelId: currentModelDetails.modelId, address, @@ -80,26 +82,26 @@ export const updateModel = (organizationId, serviceId, address, updateModelParam } }; -export const deleteModel = - (organizationId, serviceId, modelId, methodName, serviceName, address) => async (dispatch) => { - try { - dispatch(startAppLoader(LoaderContent.DELETE_MODEL)); - const params = { - modelId, - method: methodName, - address, - name: serviceName, - }; - const serviceClient = await dispatch(getServiceClient(organizationId, serviceId)); - await serviceClient.deleteModel(params); - await dispatch(getTrainingModels(organizationId, serviceId, address)); - dispatch(resetCurrentModelDetails()); - } catch (error) { - // TODO - } finally { - dispatch(stopAppLoader()); - } - }; +export const deleteModel = (organizationId, serviceId, modelId, methodName, serviceName) => async (dispatch) => { + try { + dispatch(startAppLoader(LoaderContent.DELETE_MODEL)); + const address = await dispatch(updateMetamaskWallet()); + const params = { + modelId, + method: methodName, + address, + name: serviceName, + }; + const serviceClient = await dispatch(getServiceClient(organizationId, serviceId)); + await serviceClient.deleteModel(params); + await dispatch(getTrainingModels(organizationId, serviceId, address)); + dispatch(resetCurrentModelDetails()); + } catch (error) { + // TODO + } finally { + dispatch(stopAppLoader()); + } +}; // export const getServiceName = () => (getState) => { // // const { serviceDetailsReducer, serviceTrainingReducer } = getState(); @@ -127,7 +129,7 @@ export const getTrainingModelStatus = address, }; const numberModelStatus = await serviceClient.getModelStatus(params); - return modelStatus[numberModelStatus]; + return modelStatusByNumber[numberModelStatus]; } catch (err) { // TODO } finally { @@ -148,9 +150,14 @@ export const getTrainingModels = (organizationId, serviceId, address) => async ( }; const response = await serviceClient.getExistingModel(params); + console.log("get models response: ", response); let modelsList = await Promise.all( response.map(async (model) => { + if (model.status !== modelStatus.IN_PROGRESS && model.status !== modelStatus.CREATED) { + return model; + } + const getModelStatusParams = { organizationId, serviceId, @@ -160,8 +167,8 @@ export const getTrainingModels = (organizationId, serviceId, address) => async ( address, }; - const modelStatus = await dispatch(getTrainingModelStatus(getModelStatusParams)); - return { ...model, status: modelStatus }; + const newModelStatus = await dispatch(getTrainingModelStatus(getModelStatusParams)); + return { ...model, status: newModelStatus }; }) ); @@ -174,7 +181,7 @@ export const getTrainingModels = (organizationId, serviceId, address) => async ( } }; -const modelStatus = { +const modelStatusByNumber = { 0: "CREATED", 1: "IN_PROGRESS", 2: "ERRORED", diff --git a/src/Redux/actionCreators/UserActions.js b/src/Redux/actionCreators/UserActions.js index eafb1195..fd342528 100644 --- a/src/Redux/actionCreators/UserActions.js +++ b/src/Redux/actionCreators/UserActions.js @@ -564,13 +564,16 @@ export const registerWallet = (address, type) => async (dispatch) => { export const updateMetamaskWallet = () => async (dispatch, getState) => { const sdk = await dispatch(sdkActions.getSdk()); const address = await sdk.account.getAddress(); - if (getState().userReducer.wallet.value === address) { - return; + + if (getState().userReducer.wallet?.address === address) { + return address; } + const availableUserWallets = await dispatch(fetchAvailableUserWallets()); const addressAlreadyRegistered = availableUserWallets.some( (wallet) => wallet.address.toLowerCase() === address.toLowerCase() ); + if (!addressAlreadyRegistered) { await dispatch(registerWallet(address, walletTypes.METAMASK)); } diff --git a/src/assets/thirdPartyServices/snet/GLMT/index.js b/src/assets/thirdPartyServices/snet/GLMT/index.js index effee646..29b95148 100644 --- a/src/assets/thirdPartyServices/snet/GLMT/index.js +++ b/src/assets/thirdPartyServices/snet/GLMT/index.js @@ -12,17 +12,23 @@ import { useStyles } from "./styles"; import { withStyles } from "@mui/styles"; import meta from "./meta.json"; -const SELECT_MODELS_FROM_LIST_MESSAGE = "You can try training!"; // TODO fill message +const SELECT_MODELS_FROM_LIST_MESSAGE = "You can try training! Use the corresponding tab above."; // TODO fill message class T_GLM extends React.Component { constructor(props) { super(props); + const defaultModel = { value: "null", label: "Default" }; // Backend is expecting "null" (String) as value in defaullt case + const { config } = meta; const { selectedModelId, modelsIds } = props; const modelsList = modelsIds?.length ? modelsIds : []; + if (modelsList?.length) { + modelsList.push(defaultModel); + } + this.submitAction = this.submitAction.bind(this); this.changeSlider = this.changeSlider.bind(this); this.toggleSettings = this.toggleSettings.bind(this); @@ -31,6 +37,7 @@ class T_GLM extends React.Component { this.state = { config: config, + model_id: !selectedModelId ? defaultModel.value : selectedModelId, modelsList: modelsList, users_guide: "https://github.com/iktina/Generative-Language-Models", // TODO replace link code_repo: "https://github.com/iktina/Generative-Language-Models", // TODO replace link @@ -48,7 +55,7 @@ class T_GLM extends React.Component { isSettingsOpen: false, }; - console.clear(); + // console.clear(); } canBeInvoked() { @@ -113,13 +120,8 @@ class T_GLM extends React.Component { const methodDescriptor = VITSTrainingService["inference"]; const request = new methodDescriptor.requestType(); - request.setData(this.constructRequest()); // TODO - - if (!model_id) { - request.setModelId("null"); - } else { - request.setModelId(model_id); - } + request.setData(this.constructRequest()); + request.setModelId(model_id); const props = { request, @@ -127,7 +129,7 @@ class T_GLM extends React.Component { this.setState({ response: { status: "success", - output: JSON.parse(message.getResult()), // TODO + output: JSON.parse(message.getResult()), }, }); }, @@ -144,7 +146,7 @@ class T_GLM extends React.Component { return ( - {!this.state.model_id ? ( + {!this.props.selectedModelId && !this.props.modelsIds?.length ? (

{SELECT_MODELS_FROM_LIST_MESSAGE}

) : ( @@ -402,21 +404,53 @@ class T_GLM extends React.Component { throw new Error("Cannot read data from response"); } + const currentModel = this.state.modelsList.find((model) => model?.value === this.state.model_id); + const modelLabel = !currentModel?.label ? "Default" : currentModel.label; + return ( - - - - - +
+

+ {`Input (model: ${modelLabel}):`} +

+ + + + + +

+ Output: +

+ + - +
); } diff --git a/src/assets/thirdPartyServices/snet/GLMT/t_glm_pb_service.js b/src/assets/thirdPartyServices/snet/GLMT/t_glm_pb_service.js index 6bacc414..ef743eac 100644 --- a/src/assets/thirdPartyServices/snet/GLMT/t_glm_pb_service.js +++ b/src/assets/thirdPartyServices/snet/GLMT/t_glm_pb_service.js @@ -10,15 +10,6 @@ var VITSTrainingService = (function () { return VITSTrainingService; })(); -VITSTrainingService.start_training = { - methodName: "start_training", - service: VITSTrainingService, - requestStream: false, - responseStream: false, - requestType: t_glm_pb.TrainingRequest, - responseType: t_glm_pb.TrainingResponse, -}; - VITSTrainingService.inference = { methodName: "inference", service: VITSTrainingService, @@ -35,37 +26,6 @@ function VITSTrainingServiceClient(serviceHost, options) { this.options = options || {}; } -VITSTrainingServiceClient.prototype.start_training = function start_training(requestMessage, metadata, callback) { - if (arguments.length === 2) { - callback = arguments[1]; - } - var client = grpc.unary(VITSTrainingService.start_training, { - request: requestMessage, - host: this.serviceHost, - metadata: metadata, - transport: this.options.transport, - debug: this.options.debug, - onEnd: function (response) { - if (callback) { - if (response.status !== grpc.Code.OK) { - var err = new Error(response.statusMessage); - err.code = response.status; - err.metadata = response.trailers; - callback(err, null); - } else { - callback(null, response.message); - } - } - }, - }); - return { - cancel: function () { - callback = null; - client.close(); - }, - }; -}; - VITSTrainingServiceClient.prototype.inference = function inference(requestMessage, metadata, callback) { if (arguments.length === 2) { callback = arguments[1]; diff --git a/src/components/ServiceDetails/ExistingModel/ModelDetails/index.js b/src/components/ServiceDetails/ExistingModel/ModelDetails/index.js index a1dea6a5..6b378cb7 100644 --- a/src/components/ServiceDetails/ExistingModel/ModelDetails/index.js +++ b/src/components/ServiceDetails/ExistingModel/ModelDetails/index.js @@ -19,8 +19,9 @@ import { } from "../../../../Redux/actionCreators/ServiceTrainingActions"; import { useLocation, useNavigate, useParams } from "react-router-dom"; import { modelStatus } from "../../../../Redux/reducers/ServiceTrainingReducer"; +import { updateMetamaskWallet } from "../../../../Redux/actionCreators/UserActions"; -const ModelDetails = ({ classes, openEditModel, model, address }) => { +const ModelDetails = ({ classes, openEditModel, model }) => { const dispatch = useDispatch(); const navigate = useNavigate(); const location = useLocation(); @@ -35,7 +36,7 @@ const ModelDetails = ({ classes, openEditModel, model, address }) => { const isInferenceAvailable = model.status === modelStatus.COMPLETED; const handleDeleteModel = async () => { - await dispatch(deleteModel(orgId, serviceId, model.modelId, model.methodName, model.serviceName, address)); + await dispatch(deleteModel(orgId, serviceId, model.modelId, model.methodName, model.serviceName)); setOpen(false); }; @@ -50,6 +51,7 @@ const ModelDetails = ({ classes, openEditModel, model, address }) => { }; const handleGetModelStatus = async () => { + const address = await dispatch(updateMetamaskWallet()); const getModelStatusParams = { organizationId: orgId, serviceId, diff --git a/src/components/ServiceDetails/TrainingModels/CreateModel/ModelInfo/index.js b/src/components/ServiceDetails/TrainingModels/CreateModel/ModelInfo/index.js index e5de5a53..b7cc4bfa 100644 --- a/src/components/ServiceDetails/TrainingModels/CreateModel/ModelInfo/index.js +++ b/src/components/ServiceDetails/TrainingModels/CreateModel/ModelInfo/index.js @@ -7,7 +7,7 @@ import StyledDropdown from "../../../../common/StyledDropdown"; import StyledTextField from "../../../../common/StyledTextField"; import StyledButton from "../../../../common/StyledButton"; -import { loaderActions, userActions } from "../../../../../Redux/actionCreators"; +import { loaderActions } from "../../../../../Redux/actionCreators"; import { createModel, deleteModel } from "../../../../../Redux/actionCreators/ServiceTrainingActions"; import { LoaderContent } from "../../../../../utility/constants/LoaderContent"; import { currentServiceDetails } from "../../../../../Redux/reducers/ServiceDetailsReducer"; @@ -49,7 +49,7 @@ const ModelInfo = ({ classes, cancelEditModel }) => { // try { // const address = await dispatch(userActions.updateMetamaskWallet()); - // await dispatch(updateModel(org_id, service_id, address, updateModelParams)); + // await dispatch(updateModel(org_id, service_id, updateModelParams)); // cancelEditModel(); // } catch (error) { // setAlert({ type: alertTypes.ERROR, message: "Unable to update model. Please try again" }); @@ -59,9 +59,8 @@ const ModelInfo = ({ classes, cancelEditModel }) => { // }; const onDelete = async () => { - const address = await dispatch(userActions.updateMetamaskWallet()); await dispatch( - deleteModel(org_id, service_id, currentModel.modelId, currentModel.methodName, currentModel.serviceName, address) + deleteModel(org_id, service_id, currentModel.modelId, currentModel.methodName, currentModel.serviceName) ); cancelEditModel(); }; @@ -69,7 +68,6 @@ const ModelInfo = ({ classes, cancelEditModel }) => { const onNext = async () => { try { dispatch(loaderActions.startAppLoader(LoaderContent.CONNECT_METAMASK)); - const address = await dispatch(userActions.updateMetamaskWallet()); const newModelParams = { trainingModelName, trainingMethod, @@ -78,7 +76,7 @@ const ModelInfo = ({ classes, cancelEditModel }) => { isRestrictAccessModel, dataLink: trainingDataLink, }; - await dispatch(createModel(org_id, service_id, address, newModelParams)); + await dispatch(createModel(org_id, service_id, newModelParams)); dispatch(loaderActions.stopAppLoader()); // handleNextClick(); cancelEditModel(); diff --git a/src/utility/sdk.js b/src/utility/sdk.js index c556da9b..d02696bb 100644 --- a/src/utility/sdk.js +++ b/src/utility/sdk.js @@ -8,7 +8,7 @@ import PaypalPaymentMgmtStrategy from "./PaypalPaymentMgmtStrategy"; import { ethereumMethods } from "./constants/EthereumUtils"; import { store } from "../"; import ProxyPaymentChannelManagementStrategy from "./ProxyPaymentChannelManagementStrategy"; -import { isUndefined } from "lodash"; +import { isEmpty, isUndefined } from "lodash"; const DEFAULT_GAS_PRICE = 4700000; const DEFAULT_GAS_LIMIT = 210000; @@ -189,25 +189,13 @@ const switchNetwork = async () => { }); }; -const updateSDK = async () => { - const isExpectedNetwork = await isUserAtExpectedEthereumNetwork(); - if (!isExpectedNetwork) { - await switchNetwork(); - } - - const config = { - networkId: await detectEthereumNetwork(), - web3Provider, - defaultGasPrice: DEFAULT_GAS_PRICE, - defaultGasLimit: DEFAULT_GAS_LIMIT, - }; - - sdk = await new SnetSDK(config); - await sdk.setupAccount(); +const clearSdk = () => { + sdk = undefined; }; const addListenersForWeb3 = () => { web3Provider.addListener(ON_ACCOUNT_CHANGE, (accounts) => { + clearSdk(); const event = new CustomEvent("snetMMAccountChanged", { detail: { address: accounts[0] } }); window.dispatchEvent(event); }); @@ -218,17 +206,47 @@ const addListenersForWeb3 = () => { }); }; +export const getWeb3Address = async () => { + defineWeb3Provider(); + const accounts = await web3Provider.request({ method: ethereumMethods.REQUEST_ACCOUNTS }); // TODO + return !isEmpty(accounts) ? accounts[0] : undefined; +}; + export const initSdk = async () => { if (sdk && !(sdk instanceof PaypalSDK)) { return Promise.resolve(sdk); } defineWeb3Provider(); - await web3Provider.request({ method: ethereumMethods.REQUEST_ACCOUNTS }); + await getWeb3Address(); // TODO addListenersForWeb3(); - await updateSDK(); + const isExpectedNetwork = await isUserAtExpectedEthereumNetwork(); + if (!isExpectedNetwork) { + await switchNetwork(); + } + + const config = { + networkId: await detectEthereumNetwork(), + web3Provider, + defaultGasPrice: DEFAULT_GAS_PRICE, + defaultGasLimit: DEFAULT_GAS_LIMIT, + }; + + sdk = await new SnetSDK(config); return Promise.resolve(sdk); }; +export const getSdkConfig = async () => { + defineWeb3Provider(); + const config = { + networkId: await detectEthereumNetwork(), + web3Provider, + defaultGasPrice: DEFAULT_GAS_PRICE, + defaultGasLimit: DEFAULT_GAS_LIMIT, + }; + + return config; +}; + const getMethodNames = (service) => { const ownProperties = Object.getOwnPropertyNames(service); return ownProperties.filter((property) => {