From 0ad27215fa5c67e1372e51cd1ba51646e426ca53 Mon Sep 17 00:00:00 2001 From: FilippoTrotter <129587442+FilippoTrotter@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:32:08 +0200 Subject: [PATCH] feat: accept stdin input for piping (#25) Add support for Stdin input ## Summary by CodeRabbit - **New Features** - Added support for specifying start and end labels in file processing. - Introduced handling of standard input with label-based content processing. - **Enhancements** - Improved flexibility in selecting programming languages for comment handling. - Standardized comment characters for various programming languages. - **Tests** - Added test cases for processing labeled content and standard input. --------- Co-authored-by: Filippo Trotter Co-authored-by: Puria Nafisi Azizi --- cmd/tgcom/main.go | 50 +++++---- internal/file/file.go | 197 +++++++++++++++++++++++----------- internal/file/file_test.go | 49 ++++++++- internal/language/language.go | 54 +++++----- 4 files changed, 240 insertions(+), 110 deletions(-) diff --git a/cmd/tgcom/main.go b/cmd/tgcom/main.go index 7cef0d1..59994e3 100644 --- a/cmd/tgcom/main.go +++ b/cmd/tgcom/main.go @@ -1,9 +1,11 @@ //go:build !vcs + package main import ( "flag" "fmt" + "os" "strings" "github.com/dyne/tgcom/internal/comment" @@ -17,7 +19,7 @@ func main() { endLabelFlag := flag.String("end-label", "", "The end label for a section") actionFlag := flag.String("action", "", "can be comment, uncomment or toggle") dryRunFlag := flag.Bool("dry-run", false, "Perform a dry run without modifying the files") - + lang := flag.String("language", "", "Specify the programming language") flag.Parse() filename := *fileFlag @@ -26,6 +28,10 @@ func main() { endLabel := *endLabelFlag action := *actionFlag dryRun := *dryRunFlag + langStr := *lang + info, _ := os.Stdin.Stat() + isStdin := (info.Mode() & os.ModeCharDevice) == 0 + var modFunc func(string, string) string switch action { @@ -44,12 +50,6 @@ func main() { return } - if filename == "" { - fmt.Println("Please provide a filename to process.") - flag.PrintDefaults() - return - } - if startLabel == "" && endLabel != "" { fmt.Println("Error: 'startLabel' is required when 'endLabel' is provided.") return @@ -57,29 +57,39 @@ func main() { fmt.Println("Error: 'endLabel' is required when 'startLabel' is provided.") return } - if startLabel != "" && lineStr != "" { fmt.Println("Error: Specify either line number/range OR label, not both.") return } - if strings.Contains(filename, ",") { - if err := file.ProcessMultipleFiles(filename, dryRun); err != nil { + if isStdin { + if err := file.ProcessStdin(lineStr, langStr, startLabel, endLabel, modFunc, dryRun); err != nil { fmt.Println("Error processing files:", err) } } else { - if strings.Contains(filename, ":") { - parts := strings.Split(filename, ":") - if len(parts) != 2 { - fmt.Println("Invalid syntax format. Use ':'") - return - } - filename = parts[0] - lineStr = parts[1] + if filename == "" { + fmt.Println("Please provide a filename to process.") + flag.PrintDefaults() + return } - if err := file.ProcessSingleFile(filename, lineStr, startLabel, endLabel, modFunc, dryRun); err != nil { - fmt.Println("Error processing file:", err) + if strings.Contains(filename, ",") { + if err := file.ProcessMultipleFiles(filename, dryRun); err != nil { + fmt.Println("Error processing files:", err) + } + } else { + if strings.Contains(filename, ":") { + parts := strings.Split(filename, ":") + if len(parts) != 2 { + fmt.Println("Invalid syntax format. Use ':'") + return + } + filename = parts[0] + lineStr = parts[1] + } + if err := file.ProcessSingleFile(filename, lineStr, startLabel, endLabel, modFunc, dryRun); err != nil { + fmt.Println("Error processing file:", err) + } } } } diff --git a/internal/file/file.go b/internal/file/file.go index 22d0c40..988b17a 100644 --- a/internal/file/file.go +++ b/internal/file/file.go @@ -88,6 +88,64 @@ func shouldProcessLine(currentLine int, lineNum [2]int, startLabel, endLabel str return lineNum[0] <= currentLine && currentLine <= lineNum[1] } + +// processes input from stdin. +func ProcessStdin(lineStr, startLabel, endLabel, lang string, modFunc func(string, string) string, dryRun bool) error { + var lineNum [2]int + if startLabel == "" && endLabel == "" { + startLine, endLine, err := extractLines(lineStr) + if err != nil { + return err + } + lineNum = [2]int{startLine, endLine} + } + commentChars, err := selectCommentChars("", lang) + if err != nil { + return err + } + + input := os.Stdin + + if dryRun { + return printChanges(input, lineNum, startLabel, endLabel, commentChars, modFunc) + } + + // Process input from stdin directly + scanner := bufio.NewScanner(input) + currentLine := 1 + inSection := false + for scanner.Scan() { + lineContent := scanner.Text() + + // Determine if we are processing based on line numbers or labels + if startLabel != "" && endLabel != "" { + if strings.Contains(lineContent, startLabel) { + inSection = true + } + if inSection { + lineContent = modFunc(lineContent, commentChars) + } + if strings.Contains(lineContent, endLabel) { + inSection = false + } + } else { + if lineNum[0] <= currentLine && currentLine <= lineNum[1] { + lineContent = modFunc(lineContent, commentChars) + } + } + + // Print the modified line to stdout + fmt.Println(lineContent) + currentLine++ + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + func writeChanges(inputFile *os.File, outputFile *os.File, lineNum [2]int, startLabel, endLabel string, commentChars string, modFunc func(string, string) string) error { scanner := bufio.NewScanner(inputFile) writer := bufio.NewWriter(outputFile) @@ -136,7 +194,7 @@ func printChanges(inputFile *os.File, lineNum [2]int, startLabel, endLabel, comm for scanner.Scan() { lineContent := scanner.Text() - + if strings.Contains(lineContent, endLabel) { inSection = false } @@ -190,7 +248,9 @@ func restoreBackup(filename, backupFilename string) { // ProcessSingleFile processes a single file specified by filename. func ProcessSingleFile(filename string, lineStr, startLabel, endLabel string, modFunc func(string, string) string, dryRun bool) error { - commentChars, err := selectCommentChars(filename) + + commentChars, err := selectCommentChars(filename, "") + if err != nil { return fmt.Errorf("error selecting comment characters: %w", err) } @@ -235,7 +295,7 @@ func processFileWithLines(fileInfo string, dryRun bool) error { } lineNum := [2]int{startLine, endLine} - commentChars, err := selectCommentChars(file) + commentChars, err := selectCommentChars(file, "") if err != nil { return fmt.Errorf("error selecting comment characters: %w", err) } @@ -267,64 +327,75 @@ func extractLines(lineStr string) (startLine, endLine int, err error) { return } -func selectCommentChars(filename string) (string, error) { - extension := filepath.Ext(filename) - var commentChars string - switch extension { - case ".go": - commentChars = language.CommentChars["GoLang"] - case ".js": - commentChars = language.CommentChars["JS"] - case ".sh", ".bash": - commentChars = language.CommentChars["Bash"] - case ".cpp", ".cc", ".h", ".c": - commentChars = language.CommentChars["C++/C"] - case ".java": - commentChars = language.CommentChars["Java"] - case ".py": - commentChars = language.CommentChars["Python"] - case ".rb": - commentChars = language.CommentChars["Ruby"] - case ".pl": - commentChars = language.CommentChars["Perl"] - case ".php": - commentChars = language.CommentChars["PHP"] - case ".swift": - commentChars = language.CommentChars["swift"] - case ".kt", ".kts": - commentChars = language.CommentChars["Kotlin"] - case ".R": - commentChars = language.CommentChars["R"] - case ".hs": - commentChars = language.CommentChars["Haskell"] - case ".sql": - commentChars = language.CommentChars["SQL"] - case ".rs": - commentChars = language.CommentChars["Rust"] - case ".scala": - commentChars = language.CommentChars["Scala"] - case ".dart": - commentChars = language.CommentChars["Dart"] - case ".mm": - commentChars = language.CommentChars["Objective-C"] - case ".m": - commentChars = language.CommentChars["MATLAB"] - case ".lua": - commentChars = language.CommentChars["Lua"] - case ".erl": - commentChars = language.CommentChars["Erlang"] - case ".ex", ".exs": - commentChars = language.CommentChars["Elixir"] - case ".ts": - commentChars = language.CommentChars["TS"] - case ".vhdl", ".vhd": - commentChars = language.CommentChars["VHDL"] - case ".v", ".sv": - commentChars = language.CommentChars["Verilog"] - case ".html": - commentChars = language.CommentChars["HTML"] - default: - return "", fmt.Errorf("unsupported file extension: %s", extension) - } - return commentChars, nil +func selectCommentChars(filename, lang string) (string, error) { + if lang != "" { + lang = strings.ToLower(lang) + commentChars, ok := language.CommentChars[lang] + if !ok { + return "", fmt.Errorf("unsupported language: %s", lang) + } + return commentChars, nil + } + + if filename != "" { + extension := filepath.Ext(filename) + switch extension { + case ".go": + return language.CommentChars["golang"], nil + case ".js": + return language.CommentChars["js"], nil + case ".sh", ".bash": + return language.CommentChars["bash"], nil + case ".cpp", ".cc", ".h", ".c": + return language.CommentChars["C"], nil + case ".java": + return language.CommentChars["java"], nil + case ".py": + return language.CommentChars["python"], nil + case ".rb": + return language.CommentChars["ruby"], nil + case ".pl": + return language.CommentChars["perl"], nil + case ".php": + return language.CommentChars["php"], nil + case ".swift": + return language.CommentChars["swift"], nil + case ".kt", ".kts": + return language.CommentChars["kotlin"], nil + case ".R": + return language.CommentChars["r"], nil + case ".hs": + return language.CommentChars["haskell"], nil + case ".sql": + return language.CommentChars["sql"], nil + case ".rs": + return language.CommentChars["rust"], nil + case ".scala": + return language.CommentChars["scala"], nil + case ".dart": + return language.CommentChars["dart"], nil + case ".mm": + return language.CommentChars["objective-c"], nil + case ".m": + return language.CommentChars["matlab"], nil + case ".lua": + return language.CommentChars["lua"], nil + case ".html": + return language.CommentChars["html"], nil + case ".erl": + return language.CommentChars["erlang"], nil + case ".ex", ".exs": + return language.CommentChars["elixir"], nil + case ".ts": + return language.CommentChars["ts"], nil + case ".vhdl", ".vhd": + return language.CommentChars["vhdl"], nil + case ".v", ".sv": + return language.CommentChars["verilog"], nil + default: + return "", fmt.Errorf("unsupported file extension: %s", extension) + } + } + + return "", fmt.Errorf("language not specified and no filename provided") } diff --git a/internal/file/file_test.go b/internal/file/file_test.go index 27e4bdd..8353013 100644 --- a/internal/file/file_test.go +++ b/internal/file/file_test.go @@ -131,6 +131,53 @@ func TestProcessFile(t *testing.T) { }) } +func TestProcessStdin(t *testing.T) { + input := "line 1\nline 2\nline 3\nline 4\n" + modFunc := func(line string, commentChars string) string { + return commentChars + " " + line + } + + // Create pipes for stdin and stdout redirection + rStdin, wStdin, _ := os.Pipe() + rStdout, wStdout, _ := os.Pipe() + + // Write the mock input to the writer end of the stdin pipe + go func() { + defer wStdin.Close() + _, _ = wStdin.Write([]byte(input)) + }() + + // Save original stdin and stdout + oldStdin := os.Stdin + defer func() { os.Stdin = oldStdin }() + oldStdout := os.Stdout + defer func() { os.Stdout = oldStdout }() + + // Redirect stdin and stdout + os.Stdin = rStdin + os.Stdout = wStdout + + err := ProcessStdin("1-3", "", "", "go", modFunc, false) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + wStdout.Close() + + // Read the captured output + var buf bytes.Buffer + io.Copy(&buf, rStdout) + rStdout.Close() + + // Check the output + got := buf.String() + expected := "// line 1\n// line 2\n// line 3\nline 4\n" + if got != expected { + t.Errorf("expected %q, got %q", expected, got) + } + +} + func TestProcessSingleFile(t *testing.T) { // Setup test files with content tests := []struct { @@ -271,7 +318,7 @@ func TestSelectCommentChars(t *testing.T) { } for _, tt := range tests { - commentChars, err := selectCommentChars(tt.filename) + commentChars, err := selectCommentChars(tt.filename, "") if (err != nil) != tt.shouldErr { t.Errorf("selectCommentChars(%s) error = %v", tt.filename, err) } diff --git a/internal/language/language.go b/internal/language/language.go index ad22adc..71a762e 100644 --- a/internal/language/language.go +++ b/internal/language/language.go @@ -2,30 +2,32 @@ package language // GoCommentChars holds the comment characters for Go language. var CommentChars = map[string]string{ - "GoLang": "//", - "JS": "//", - "Bash": "#", - "C++/C": "//", - "Java": "//", - "Python": "#", - "Ruby": "#", - "Perl": "#", - "PHP": "//", - "Swift": "//", - "Kotlin": "//", - "R": "#", - "Haskell": "--", - "SQL": "--", - "Rust": "//", - "Scala": "//", - "Dart": "//", - "Objective-C": "//", - "MATLAB": "%", - "Lua": "--", - "Erlang": "%", - "Elixir": "#", - "TS": "//", - "VHDL": "--", - "Verilog": "//", - "HTML": "", + "golang": "//", + "go": "//", + "js": "//", + "bash": "#", + "c": "//", + "c++": "//", + "java": "//", + "python": "#", + "ruby": "#", + "perl": "#", + "php": "//", + "swift": "//", + "kotlin": "//", + "r": "#", + "haskell": "--", + "sql": "--", + "rust": "//", + "scala": "//", + "dart": "//", + "objective-c": "//", + "matlab": "%", + "lua": "--", + "erlang": "%", + "elixir": "#", + "ts": "//", + "vhdl": "--", + "verilog": "//", + "html": "", }