diff --git a/microservices-connector/cmd/router/main.go b/microservices-connector/cmd/router/main.go index c49e8ca6..719b774a 100644 --- a/microservices-connector/cmd/router/main.go +++ b/microservices-connector/cmd/router/main.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "io" + "mime/multipart" "net/http" "os" @@ -48,6 +49,7 @@ const ( ChunkSize = 1024 ServiceURL = "serviceUrl" ServiceNode = "node" + DataPrep = "DataPrep" ) type EnsembleStepOutput struct { @@ -477,6 +479,143 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) { } } +func mcDataHandler(w http.ResponseWriter, r *http.Request) { + var isDataHandled bool + serviceName := r.Header.Get("SERVICE_NAME") + defaultNode := mcGraph.Spec.Nodes[defaultNodeName] + for i := range defaultNode.Steps { + step := &defaultNode.Steps[i] + if DataPrep == step.StepName { + if serviceName != "" && serviceName != step.InternalService.ServiceName { + continue + } + log.Info("Starting execution of step", "stepName", step.StepName) + serviceURL := getServiceURLByStepTarget(step, mcGraph.Namespace) + log.Info("ServiceURL is", "serviceURL", serviceURL) + // Parse the multipart form in the request + // err := r.ParseMultipartForm(64 << 20) // 64 MB is the default used by ParseMultipartForm + + // Set no limit on multipart form size + err := r.ParseMultipartForm(0) + if err != nil { + http.Error(w, "Failed to parse multipart form", http.StatusBadRequest) + return + } + // Create a buffer to hold the new form data + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + // Copy all form fields from the original request to the new request + for key, values := range r.MultipartForm.Value { + for _, value := range values { + err := writer.WriteField(key, value) + if err != nil { + handleMultipartError(writer, err) + http.Error(w, "Failed to write form field", http.StatusInternalServerError) + return + } + } + } + // Copy all files from the original request to the new request + for key, fileHeaders := range r.MultipartForm.File { + for _, fileHeader := range fileHeaders { + file, err := fileHeader.Open() + if err != nil { + handleMultipartError(writer, err) + http.Error(w, "Failed to open file", http.StatusInternalServerError) + return + } + defer func() { + if err := file.Close(); err != nil { + log.Error(err, "error closing file") + } + }() + part, err := writer.CreateFormFile(key, fileHeader.Filename) + if err != nil { + handleMultipartError(writer, err) + http.Error(w, "Failed to create form file", http.StatusInternalServerError) + return + } + _, err = io.Copy(part, file) + if err != nil { + handleMultipartError(writer, err) + http.Error(w, "Failed to copy file", http.StatusInternalServerError) + return + } + } + } + + err = writer.Close() + if err != nil { + http.Error(w, "Failed to close writer", http.StatusInternalServerError) + return + } + + req, err := http.NewRequest(r.Method, serviceURL, &buf) + if err != nil { + http.Error(w, "Failed to create new request", http.StatusInternalServerError) + return + } + // Copy headers from the original request to the new request + for key, values := range r.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + http.Error(w, "Failed to send request to backend", http.StatusInternalServerError) + return + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Error(err, "error closing response body stream") + } + }() + // Copy the response headers from the backend service to the original client + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(resp.StatusCode) + // Copy the response body from the backend service to the original client + _, err = io.Copy(w, resp.Body) + if err != nil { + log.Error(err, "failed to copy response body") + } + isDataHandled = true + } + } + + if !isDataHandled { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + if _, err := w.Write([]byte("\n Message: None dataprep endpoint is available! \n")); err != nil { + log.Info("Message: ", "failed to write mcDataHandler response") + } + } +} + +func handleMultipartError(writer *multipart.Writer, err error) { + // In case of an error, close the writer to clean up + werr := writer.Close() + if werr != nil { + log.Error(werr, "Error during close writer") + return + } + // Handle the error as needed, such as logging or returning an error response + log.Error(err, "Error during multipart creation") +} + +func initializeRoutes() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/", mcGraphHandler) + mux.HandleFunc("/dataprep", mcDataHandler) + return mux +} + func main() { flag.Parse() logf.SetLogger(zap.New()) @@ -488,13 +627,13 @@ func main() { os.Exit(1) } - http.HandleFunc("/", mcGraphHandler) + mcRouter := initializeRoutes() server := &http.Server{ // specify the address and port Addr: ":8080", - // specify your HTTP handler - Handler: http.HandlerFunc(mcGraphHandler), + // specify the HTTP routers + Handler: mcRouter, // set the maximum duration for reading the entire request, including the body ReadTimeout: time.Minute, // set the maximum duration before timing out writes of the response diff --git a/microservices-connector/cmd/router/main_test.go b/microservices-connector/cmd/router/main_test.go index f70d0f5d..25afa912 100644 --- a/microservices-connector/cmd/router/main_test.go +++ b/microservices-connector/cmd/router/main_test.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "io" + "mime/multipart" "net/http" "net/http/httptest" "os" @@ -821,3 +822,95 @@ func TestMcGraphHandler_Timeout(t *testing.T) { t.Errorf("expected error message '%s'; got '%s'", expectedErrorMessage, string(body)) } } + +func TestMcDataHandler(t *testing.T) { + // Start a local HTTP server + service1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + _, err := io.ReadAll(req.Body) + if err != nil { + return + } + response := map[string]interface{}{"predictions": "1"} + responseBytes, _ := json.Marshal(response) + _, err = rw.Write(responseBytes) + if err != nil { + return + } + })) + service1Url, err := apis.ParseURL(service1.URL) + if err != nil { + t.Fatalf("Failed to parse model url") + } + defer service1.Close() + + // Create a buffer to store the multipart form data + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + + // Add form fields + err = writer.WriteField("key", "value") + if err != nil { + t.Fatalf("failed to write form field: %v", err) + } + + // Add a file field + part, err := writer.CreateFormFile("file", "filename.txt") + if err != nil { + t.Fatalf("failed to create form file: %v", err) + } + _, err = part.Write([]byte("file content")) + if err != nil { + t.Fatalf("failed to write to form file: %v", err) + } + + // Close the writer to finalize the multipart form + err = writer.Close() + if err != nil { + t.Fatalf("failed to close writer: %v", err) + } + + // Create a new HTTP request with the multipart form data + req := httptest.NewRequest(http.MethodPost, "/dataprep", &buf) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + // Create a ResponseRecorder to capture the response + rr := httptest.NewRecorder() + + // Mock the mcGraph data + mcGraph = &mcv1alpha3.GMConnector{ + Spec: mcv1alpha3.GMConnectorSpec{ + Nodes: map[string]mcv1alpha3.Router{ + "root": { + Steps: []mcv1alpha3.Step{ + { + StepName: "DataPrep", + ServiceURL: service1Url.String(), + Executor: mcv1alpha3.Executor{ + InternalService: mcv1alpha3.GMCTarget{ + NameSpace: "default", + ServiceName: "example-service", + }, + }, + }, + }, + }, + }, + }, + } + + // Call the mcDataHandler function + mcDataHandler(rr, req) + + // Check the response status code + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + // Check the response body if needed + expected := "{\"predictions\":\"1\"}" + if strings.TrimSpace(rr.Body.String()) != expected { + t.Errorf("handler returned unexpected body: got %v want %v", + rr.Body.String(), expected) + } +}