Skip to content

Commit

Permalink
Merge pull request #948 from singnet/training
Browse files Browse the repository at this point in the history
Training
  • Loading branch information
MarinaFedy authored Nov 22, 2024
2 parents 53fe230 + cc4782f commit 78313e4
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 122 deletions.
7 changes: 5 additions & 2 deletions src/Redux/actionCreators/SDKActions.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { isEmpty } from "lodash";
import { initSdk } from "../../utility/sdk";
import { getWeb3Address, initSdk } from "../../utility/sdk";

export const SET_SDK = "SET_SDK";
export const SET_SERVICE_CLIENT = "SET_SERVICE_CLIENT";
Expand All @@ -24,10 +24,13 @@ export const initializingSdk = (ethereumWalletAddress) => async (dispatch) => {

const initializeServiceClient = (organizationId, serviceId) => async (dispatch) => {
const sdk = await dispatch(getSdk());
return await sdk.createServiceClient(organizationId, serviceId);
const serviceClient = await sdk.createServiceClient(organizationId, serviceId);
// dispatch(updateServiceClient(serviceClient))
return serviceClient;
};

export const getSdk = () => async (dispatch, getState) => {
await getWeb3Address();
let sdk = getState().sdkReducer.sdk;
if (!isEmpty(sdk)) {
return sdk;
Expand Down
63 changes: 35 additions & 28 deletions src/Redux/actionCreators/ServiceTrainingActions.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ import axios from "axios";
import { LoaderContent } from "../../utility/constants/LoaderContent";
import { startAppLoader, stopAppLoader } from "./LoaderActions";
import { getServiceClient } from "./SDKActions";

import { updateMetamaskWallet } from "./UserActions";
import { modelStatus } from "../reducers/ServiceTrainingReducer";
export const SET_MODEL_DETAILS = "SET_MODEL_DETAILS";
export const SET_MODELS_LIST = "SET_MODELS_LIST";
export const RESET_MODEL_DETAILS = "RESET_MODEL_DETAILS";
Expand All @@ -24,9 +25,10 @@ export const resetModelList = () => (dispatch) => {
dispatch({ type: RESET_MODEL_LIST });
};

export const createModel = (organizationId, serviceId, address, newModelParams) => async (dispatch) => {
export const createModel = (organizationId, serviceId, newModelParams) => async (dispatch) => {
try {
dispatch(startAppLoader(LoaderContent.CREATE_TRAINING_MODEL));
const address = await dispatch(updateMetamaskWallet());
const serviceName = getServiceNameFromTrainingMethod(newModelParams?.trainingMethod);
const params = {
modelName: newModelParams?.trainingModelName,
Expand All @@ -40,7 +42,6 @@ export const createModel = (organizationId, serviceId, address, newModelParams)

const serviceClient = await dispatch(getServiceClient(organizationId, serviceId));
const createdModelData = await serviceClient.createModel(address, params);
console.log("createdModelData: ", createdModelData);

dispatch(setCurrentModelDetails(createdModelData));
await dispatch(getTrainingModels(organizationId, serviceId, address));
Expand All @@ -51,10 +52,11 @@ export const createModel = (organizationId, serviceId, address, newModelParams)
}
};

export const updateModel = (organizationId, serviceId, address, updateModelParams) => async (dispatch, getState) => {
export const updateModel = (organizationId, serviceId, updateModelParams) => async (dispatch, getState) => {
const currentModelDetails = getState().serviceTrainingReducer.currentModel;
try {
dispatch(startAppLoader(LoaderContent.UPDATE_MODEL));
const address = await dispatch(updateMetamaskWallet());
const params = {
modelId: currentModelDetails.modelId,
address,
Expand All @@ -80,26 +82,26 @@ export const updateModel = (organizationId, serviceId, address, updateModelParam
}
};

export const deleteModel =
(organizationId, serviceId, modelId, methodName, serviceName, address) => async (dispatch) => {
try {
dispatch(startAppLoader(LoaderContent.DELETE_MODEL));
const params = {
modelId,
method: methodName,
address,
name: serviceName,
};
const serviceClient = await dispatch(getServiceClient(organizationId, serviceId));
await serviceClient.deleteModel(params);
await dispatch(getTrainingModels(organizationId, serviceId, address));
dispatch(resetCurrentModelDetails());
} catch (error) {
// TODO
} finally {
dispatch(stopAppLoader());
}
};
export const deleteModel = (organizationId, serviceId, modelId, methodName, serviceName) => async (dispatch) => {
try {
dispatch(startAppLoader(LoaderContent.DELETE_MODEL));
const address = await dispatch(updateMetamaskWallet());
const params = {
modelId,
method: methodName,
address,
name: serviceName,
};
const serviceClient = await dispatch(getServiceClient(organizationId, serviceId));
await serviceClient.deleteModel(params);
await dispatch(getTrainingModels(organizationId, serviceId, address));
dispatch(resetCurrentModelDetails());
} catch (error) {
// TODO
} finally {
dispatch(stopAppLoader());
}
};

// export const getServiceName = () => (getState) => {
// // const { serviceDetailsReducer, serviceTrainingReducer } = getState();
Expand Down Expand Up @@ -127,7 +129,7 @@ export const getTrainingModelStatus =
address,
};
const numberModelStatus = await serviceClient.getModelStatus(params);
return modelStatus[numberModelStatus];
return modelStatusByNumber[numberModelStatus];
} catch (err) {
// TODO
} finally {
Expand All @@ -148,9 +150,14 @@ export const getTrainingModels = (organizationId, serviceId, address) => async (
};

const response = await serviceClient.getExistingModel(params);
console.log("get models response: ", response);

let modelsList = await Promise.all(
response.map(async (model) => {
if (model.status !== modelStatus.IN_PROGRESS && model.status !== modelStatus.CREATED) {
return model;
}

const getModelStatusParams = {
organizationId,
serviceId,
Expand All @@ -160,8 +167,8 @@ export const getTrainingModels = (organizationId, serviceId, address) => async (
address,
};

const modelStatus = await dispatch(getTrainingModelStatus(getModelStatusParams));
return { ...model, status: modelStatus };
const newModelStatus = await dispatch(getTrainingModelStatus(getModelStatusParams));
return { ...model, status: newModelStatus };
})
);

Expand All @@ -174,7 +181,7 @@ export const getTrainingModels = (organizationId, serviceId, address) => async (
}
};

const modelStatus = {
const modelStatusByNumber = {
0: "CREATED",
1: "IN_PROGRESS",
2: "ERRORED",
Expand Down
7 changes: 5 additions & 2 deletions src/Redux/actionCreators/UserActions.js
Original file line number Diff line number Diff line change
Expand Up @@ -564,13 +564,16 @@ export const registerWallet = (address, type) => async (dispatch) => {
export const updateMetamaskWallet = () => async (dispatch, getState) => {
const sdk = await dispatch(sdkActions.getSdk());
const address = await sdk.account.getAddress();
if (getState().userReducer.wallet.value === address) {
return;

if (getState().userReducer.wallet?.address === address) {
return address;
}

const availableUserWallets = await dispatch(fetchAvailableUserWallets());
const addressAlreadyRegistered = availableUserWallets.some(
(wallet) => wallet.address.toLowerCase() === address.toLowerCase()
);

if (!addressAlreadyRegistered) {
await dispatch(registerWallet(address, walletTypes.METAMASK));
}
Expand Down
82 changes: 58 additions & 24 deletions src/assets/thirdPartyServices/snet/GLMT/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,23 @@ import { useStyles } from "./styles";
import { withStyles } from "@mui/styles";
import meta from "./meta.json";

const SELECT_MODELS_FROM_LIST_MESSAGE = "You can try training!"; // TODO fill message
const SELECT_MODELS_FROM_LIST_MESSAGE = "You can try training! Use the corresponding tab above."; // TODO fill message

class T_GLM extends React.Component {
constructor(props) {
super(props);

const defaultModel = { value: "null", label: "Default" }; // Backend is expecting "null" (String) as value in defaullt case

const { config } = meta;
const { selectedModelId, modelsIds } = props;

const modelsList = modelsIds?.length ? modelsIds : [];

if (modelsList?.length) {
modelsList.push(defaultModel);
}

this.submitAction = this.submitAction.bind(this);
this.changeSlider = this.changeSlider.bind(this);
this.toggleSettings = this.toggleSettings.bind(this);
Expand All @@ -31,6 +37,7 @@ class T_GLM extends React.Component {

this.state = {
config: config,
model_id: !selectedModelId ? defaultModel.value : selectedModelId,
modelsList: modelsList,
users_guide: "https://github.com/iktina/Generative-Language-Models", // TODO replace link
code_repo: "https://github.com/iktina/Generative-Language-Models", // TODO replace link
Expand All @@ -48,7 +55,7 @@ class T_GLM extends React.Component {
isSettingsOpen: false,
};

console.clear();
// console.clear();
}

canBeInvoked() {
Expand Down Expand Up @@ -113,21 +120,16 @@ class T_GLM extends React.Component {
const methodDescriptor = VITSTrainingService["inference"];
const request = new methodDescriptor.requestType();

request.setData(this.constructRequest()); // TODO

if (!model_id) {
request.setModelId("null");
} else {
request.setModelId(model_id);
}
request.setData(this.constructRequest());
request.setModelId(model_id);

const props = {
request,
onEnd: ({ message }) => {
this.setState({
response: {
status: "success",
output: JSON.parse(message.getResult()), // TODO
output: JSON.parse(message.getResult()),
},
});
},
Expand All @@ -144,7 +146,7 @@ class T_GLM extends React.Component {

return (
<Grid container direction="column" justify="center">
{!this.state.model_id ? (
{!this.props.selectedModelId && !this.props.modelsIds?.length ? (
<h3>{SELECT_MODELS_FROM_LIST_MESSAGE}</h3>
) : (
<Grid item xs={8} container spacing={2}>
Expand Down Expand Up @@ -402,21 +404,53 @@ class T_GLM extends React.Component {
throw new Error("Cannot read data from response");
}

const currentModel = this.state.modelsList.find((model) => model?.value === this.state.model_id);
const modelLabel = !currentModel?.label ? "Default" : currentModel.label;

return (
<React.Fragment>
<Grid container direction="column" justify="center">
<Grid item xs={12} style={{ textAlign: "left" }}>
<OutlinedTextArea
id="service_output"
name="service_output"
label="Result"
fullWidth={true}
value={output}
rows={8}
/>
</Grid>
<div>
<p
style={{
padding: "0 10px",
fontSize: "18px",
fontWeight: "600",
margin: "0",
}}
>
{`Input (model: ${modelLabel}):`}
</p>

<Grid item xs={12} container justify="center" style={{ textAlign: "center" }}>
<OutlinedTextArea
id="serviceInput"
name="serviceInput"
fullWidth={true}
value={this.state.request}
rows={5}
/>
</Grid>

<p
style={{
padding: "0 10px",
fontSize: "18px",
fontWeight: "600",
margin: "0",
}}
>
Output:
</p>
<Grid item xs={12} container justify="center" style={{ textAlign: "center" }}>
<OutlinedTextArea
id="service_output"
name="service_output"
label="Result"
fullWidth={true}
value={output}
rows={5}
/>
</Grid>
</React.Fragment>
</div>
);
}

Expand Down
40 changes: 0 additions & 40 deletions src/assets/thirdPartyServices/snet/GLMT/t_glm_pb_service.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@ var VITSTrainingService = (function () {
return VITSTrainingService;
})();

VITSTrainingService.start_training = {
methodName: "start_training",
service: VITSTrainingService,
requestStream: false,
responseStream: false,
requestType: t_glm_pb.TrainingRequest,
responseType: t_glm_pb.TrainingResponse,
};

VITSTrainingService.inference = {
methodName: "inference",
service: VITSTrainingService,
Expand All @@ -35,37 +26,6 @@ function VITSTrainingServiceClient(serviceHost, options) {
this.options = options || {};
}

VITSTrainingServiceClient.prototype.start_training = function start_training(requestMessage, metadata, callback) {
if (arguments.length === 2) {
callback = arguments[1];
}
var client = grpc.unary(VITSTrainingService.start_training, {
request: requestMessage,
host: this.serviceHost,
metadata: metadata,
transport: this.options.transport,
debug: this.options.debug,
onEnd: function (response) {
if (callback) {
if (response.status !== grpc.Code.OK) {
var err = new Error(response.statusMessage);
err.code = response.status;
err.metadata = response.trailers;
callback(err, null);
} else {
callback(null, response.message);
}
}
},
});
return {
cancel: function () {
callback = null;
client.close();
},
};
};

VITSTrainingServiceClient.prototype.inference = function inference(requestMessage, metadata, callback) {
if (arguments.length === 2) {
callback = arguments[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ import {
} from "../../../../Redux/actionCreators/ServiceTrainingActions";
import { useLocation, useNavigate, useParams } from "react-router-dom";
import { modelStatus } from "../../../../Redux/reducers/ServiceTrainingReducer";
import { updateMetamaskWallet } from "../../../../Redux/actionCreators/UserActions";

const ModelDetails = ({ classes, openEditModel, model, address }) => {
const ModelDetails = ({ classes, openEditModel, model }) => {
const dispatch = useDispatch();
const navigate = useNavigate();
const location = useLocation();
Expand All @@ -35,7 +36,7 @@ const ModelDetails = ({ classes, openEditModel, model, address }) => {
const isInferenceAvailable = model.status === modelStatus.COMPLETED;

const handleDeleteModel = async () => {
await dispatch(deleteModel(orgId, serviceId, model.modelId, model.methodName, model.serviceName, address));
await dispatch(deleteModel(orgId, serviceId, model.modelId, model.methodName, model.serviceName));
setOpen(false);
};

Expand All @@ -50,6 +51,7 @@ const ModelDetails = ({ classes, openEditModel, model, address }) => {
};

const handleGetModelStatus = async () => {
const address = await dispatch(updateMetamaskWallet());
const getModelStatusParams = {
organizationId: orgId,
serviceId,
Expand Down
Loading

0 comments on commit 78313e4

Please sign in to comment.