Skip to content

Commit

Permalink
Segment out delimiting functionality from FileReference
Browse files Browse the repository at this point in the history
  • Loading branch information
jmpaz committed Apr 12, 2024
1 parent 071e211 commit d957d35
Showing 1 changed file with 45 additions and 57 deletions.
102 changes: 45 additions & 57 deletions contextualize/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,61 +3,26 @@

class FileReference:
def __init__(
self,
path: str,
range: tuple = None,
format="md",
label="relative",
clean_contents=False,
self, path, range=None, format="md", label="relative", clean_contents=False
):
self.range = range
self.path = path
self.format = format
self.label = label
self.clean_contents = clean_contents

# prepare the reference string
self.output = self.get_contents()

def get_contents(self):
try:
with open(self.path, "r") as file:
contents = file.read()
self.file_content = contents
contents = self.process(contents)
return contents
except UnicodeDecodeError:
print(f"Skipping unreadable file: {self.path}")
return ""
except FileNotFoundError:
print(f"File not found: {self.path}")
return ""
except Exception as e:
print(f"Error occurred while reading file: {self.path}")
print(f"Error details: {str(e)}")
print(f"Error reading file {self.path}: {str(e)}")
return ""

def process(self, contents):
if self.clean_contents:
contents = self.clean(contents)
if self.range:
contents = self.extract_range(contents, self.range)
if self.format == "md":
max_backticks = self.count_max_backticks(contents)
contents = self.delineate(
contents, self.format, self.get_label(), max_backticks
)
else:
contents = self.delineate(contents, self.format, self.get_label())
return contents

def extract_range(self, contents, range):
start, end = range
lines = contents.split("\n")
return "\n".join(lines[start - 1 : end])

def clean(self, contents):
return contents.replace(" ", "\t")
return process_text(
contents, self.clean_contents, self.range, self.format, self.get_label()
)

def get_label(self):
if self.label == "relative":
Expand All @@ -69,23 +34,46 @@ def get_label(self):
else:
return ""

def count_max_backticks(self, contents):
max_backticks = 0
lines = contents.split("\n")
for line in lines:
if line.startswith("`"):
max_backticks = max(max_backticks, len(line) - len(line.lstrip("`")))
return max_backticks

def delineate(self, contents, format, label, max_backticks=0):
if format == "md":
backticks_str = "`" * (max_backticks + 2) if max_backticks >= 3 else "```"
return f"{backticks_str}{label}\n{contents}\n{backticks_str}"
elif format == "xml":
return f"<file path='{label}'>\n{contents}\n</file>"
else:
return contents


def concat_refs(file_references: list):
return "\n\n".join(ref.output for ref in file_references)


def _clean(text):
return text.replace(" ", "\t")


def _extract_range(text, range):
"""Extracts lines from contents based on range tuple."""
start, end = range
lines = text.split("\n")
return "\n".join(lines[start - 1 : end])


def _count_max_backticks(text):
max_backticks = 0
lines = text.split("\n")
for line in lines:
if line.startswith("`"):
max_backticks = max(max_backticks, len(line) - len(line.lstrip("`")))
return max_backticks


def _delimit(text, format, label, max_backticks=0):
if format == "md":
backticks_str = "`" * (max_backticks + 2) if max_backticks >= 3 else "```"
return f"{backticks_str}{label}\n{text}\n{backticks_str}"
elif format == "xml":
return f"<file path='{label}'>\n{text}\n</file>"
else:
return text


def process_text(text, clean=False, range=None, format="md", label=""):
if clean:
text = _clean(text)
if range:
text = _extract_range(text, range)
max_backticks = _count_max_backticks(text)
contents = _delimit(text, format, label, max_backticks)
return contents

0 comments on commit d957d35

Please sign in to comment.