Skip to content

Commit

Permalink
LayoutPredictor blacklisted classes - No warnings (#24)
Browse files Browse the repository at this point in the history
* feat(LayoutPredictor): Introduce black-listed classes which are filtered out from the response.
Refactor the demo app to draw all recognized items.

* feat(AggProfiler): Enable memory monitoring (RSS) in AggProfiler. Clean up unused code.

* fix(Tag_Transformer): Fix the initialization of Tag_Transformer encoder.
Enable the AggProfiler in the unit tests.

---------

Signed-off-by: Nikos Livathinos <[email protected]>
  • Loading branch information
nikos-livathinos authored Sep 17, 2024
1 parent 1bbb753 commit d526b79
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 219 deletions.
26 changes: 13 additions & 13 deletions demo/demo_layout_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,26 @@ def demo(
draw = ImageDraw.Draw(out_img)

for i, pred in enumerate(preds):
scr = pred["confidence"]
lab = pred["label"]
score = pred["confidence"]
label = pred["label"]
box = [
round(pred["l"]),
round(pred["t"]),
round(pred["r"]),
round(pred["b"]),
]

if lab == "Table":
draw.rectangle(
box,
outline="red",
)
draw.text(
(box[0], box[1]),
text=str(lab),
fill="blue",
)
logger.info("Table %s: bbox=%s", i, box)
# Draw bbox and label
draw.rectangle(
box,
outline="red",
)
draw.text(
(box[0], box[1]),
text=str(label),
fill="blue",
)
logger.info("%s: [label|score|bbox] = ['%s' | %s | %s]", i, label, score, box)

save_fn = os.path.join(viz_dir, os.path.basename(img_fn))
out_img.save(save_fn)
Expand Down
58 changes: 33 additions & 25 deletions docling_ibm_models/layoutmodel/layout_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,6 @@
DEFAULT_NUM_THREADS = 4


# Classes:
CLASSES_MAP = {
0: "background",
1: "Caption",
2: "Footnote",
3: "Formula",
4: "List-item",
5: "Page-footer",
6: "Page-header",
7: "Picture",
8: "Section-header",
9: "Table",
10: "Text",
11: "Title",
12: "Document Index",
13: "Code",
14: "Checkbox-Selected",
15: "Checkbox-Unselected",
16: "Form",
17: "Key-Value Region",
}


class LayoutPredictor:
r"""
Document layout prediction using ONNX
Expand Down Expand Up @@ -69,6 +46,31 @@ def __init__(
------
FileNotFoundError when the model's ONNX file is missing
"""
# Initialize classes map:
self._classes_map = {
0: "background",
1: "Caption",
2: "Footnote",
3: "Formula",
4: "List-item",
5: "Page-footer",
6: "Page-header",
7: "Picture",
8: "Section-header",
9: "Table",
10: "Text",
11: "Title",
12: "Document Index",
13: "Code",
14: "Checkbox-Selected",
15: "Checkbox-Unselected",
16: "Form",
17: "Key-Value Region",
}

# Blacklisted classes
self._black_classes = set(["Form", "Key-Value Region"])

# Set basic params
self._threshold = 0.6 # Score threshold
self._image_size = 640
Expand Down Expand Up @@ -159,13 +161,19 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
)

# Yield output
for label, box, score in zip(labels[0], boxes[0], scores[0]):
for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
# Filter out blacklisted classes
label = self._classes_map[label_idx]
if label in self._black_classes:
continue

# Check against threshold
if score > self._threshold:
yield {
"l": box[0] / self._image_size * w,
"t": box[1] / self._image_size * h,
"r": box[2] / self._image_size * w,
"b": box[3] / self._image_size * h,
"label": CLASSES_MAP[label],
"label": label,
"confidence": score,
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ def __init__(
self._positional_encoding = PositionalEncoding(embed_dim)
self._td_encode = td_encode

encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=n_heads, dim_feedforward=dim_ff
)
self._encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=n_heads, dim_feedforward=dim_ff
),
num_layers=encoder_layers,
encoder_layer, num_layers=encoder_layers, enable_nested_tensor=False
)

self._decoder = TMTransformerDecoder(
Expand Down
13 changes: 12 additions & 1 deletion docling_ibm_models/tableformer/utils/app_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from collections import deque
from statistics import mean, median

from docling_ibm_models.tableformer.utils.mem_monitor import MemMonitor


class SingletonClass(type):
r"""
Expand Down Expand Up @@ -37,11 +39,13 @@ class Profiler:
def __init__(self):
self._section_dts = {} # section name -> sum(section intervals)
self._section_calls = {} # section name -> number of invocations
self._section_kB = {} # section name -> max kB of used heap
self._section_kB = {} # section name -> max kB of used heap (resident set size)

# section name -> beginning of the last interval
self._last_begin = {}

self._mem_monitor = MemMonitor()

def begin(self, section_name, enable=True):
r"""
Mark the beginning of an interval
Expand Down Expand Up @@ -83,13 +87,20 @@ def end(self, section_name, enable=True):
if section_name not in self._last_begin:
return False

# Get memory
kB = self._mem_monitor.get_memory()
if isinstance(kB, dict):
kB = kB["resident"]

dt = time.time() - self._last_begin[section_name]
if section_name not in self._section_dts:
self._section_dts[section_name] = dt
self._section_calls[section_name] = 1
self._section_kB[section_name] = kB
else:
self._section_dts[section_name] += dt
self._section_calls[section_name] += 1
self._section_kB[section_name] = max(kB, self._section_kB[section_name])

return True

Expand Down
175 changes: 175 additions & 0 deletions docling_ibm_models/tableformer/utils/mem_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import os
import platform
import re
from typing import Dict, Union


class MemMonitor:
    r"""
    Memory monitor for Linux
    It supports 2 approaches for extracting memory information:
    - linux-native: It parses the `/proc` pseudo-files. It is available only for Linux
    - psutil: Use the `psutil` library
    ## Linux-Native approach
    The linux-native approach implements 2 methods to extract the memory fields:
    1. The `get_memory()` method:
    - It is very fast
    - It parses the `/proc/<pid>/statm` pseudo-file
    - It contains the following fields:
        size       (1) total program size
                   (same as VmSize in /proc/[pid]/status)
        resident   (2) resident set size
                   (same as VmRSS in /proc/[pid]/status)
        shared     (3) number of resident shared pages (i.e., backed by a file)
                   (same as RssFile+RssShmem in /proc/[pid]/status)
        text       (4) text (code)
        lib        (5) library (unused since Linux 2.6; always 0)
        data       (6) data + stack
        dt         (7) dirty pages (unused since Linux 2.6; always 0)
    2. The `get_memory_full()` method:
    - It is slower to parse but contains more detailed information
    - It uses regex to parse the `/proc/<pid>/status` pseudo-file
    - It contains the following fields:
        VmPeak: Peak virtual memory size.
        VmSize: Virtual memory size.
        VmLck: Locked memory size (see mlock(2)).
        VmPin: Pinned memory size (since Linux 3.2). These are pages that can't be moved because
               something needs to directly access physical memory.
        VmHWM: Peak resident set size ("high water mark").
        VmRSS: Resident set size. Note that the value here is the sum of RssAnon, RssFile, and
               RssShmem.
        RssAnon: Size of resident anonymous memory. (since Linux 4.5).
        RssFile: Size of resident file mappings. (since Linux 4.5).
        RssShmem: Size of resident shared memory (includes System V shared memory, mappings from
                  tmpfs(5), and shared anonymous mappings). (since Linux 4.5).
        VmData, VmStk, VmExe: Size of data, stack, and text segments.
        VmLib: Shared library code size.
        VmPTE: Page table entries size (since Linux 2.6.10).
        VmPMD: Size of second-level page tables (added in Linux 4.0; removed in Linux 4.15).
        VmSwap: Swapped-out virtual memory size by anonymous private pages; shmem swap usage is
                not included (since Linux 2.6.34).
    ## The psutil library
    - Apparently the psutil library parses the `/proc/<pid>/statm`
    - The memory_info() function returns the fields: rss, vms, shared, text, lib, data, dirty
    ## Field mappings
    These are the fields returned by psutil memory_info() and their mapping in the /proc files:
    (I put ? when I am not 100% sure about the mapping)
    | psutil  | /proc/$$/status    | /proc/$$/statm |
    |---------|--------------------|----------------|
    | rss     | VmRSS              | resident       |
    | vms     | VmSize             | size           |
    | shared  | RssFile + RssShmem | shared         |
    | text    | VmExe ?            | text           |
    | lib     | RssShmem ?         | lib            |
    | data    | VmData + VmStk     | data           |
    | dirty   | VmSwap ?           | dt             |
    """

    def __init__(self, enable=True):
        r"""
        Parameters
        ----------
        enable : bool
            When False, `get_memory()` and `get_memory_full()` short-circuit
            and return the sentinel -2 without touching `/proc`.
        """
        self._enable = enable
        self._pid = os.getpid()

        # Pre-compile one regex per memory field of the /proc/<pid>/status pseudo-file.
        # NOTE: not every field exists on every kernel (e.g. RssAnon requires
        # Linux >= 4.5, VmPMD was removed in 4.15), so get_memory_full() may
        # legitimately return fewer keys than there are fields here.
        self._status_fields = [
            "VmPeak",
            "VmSize",
            "VmLck",
            "VmPin",
            "VmHWM",
            "VmRSS",
            "RssAnon",
            "RssFile",
            "RssShmem",
            "VmData",
            "VmStk",
            "VmExe",
            "VmLib",
            "VmPTE",
            "VmPMD",
            "VmSwap",
        ]
        self._status_regex = {}
        for mem_field in self._status_fields:
            # Example line: "VmRSS:     1234 kB" -> group(3) is the number.
            regex_str = r"({}:)(\s+)(\d*)(.*)".format(mem_field)
            self._status_regex[mem_field] = re.compile(regex_str)

    def get_memory_full(self) -> Union[Dict[str, int], int]:
        r"""
        - Parse /proc/<pid>/status to get all memory info.
        - The method returns a dict with the fields self._status_fields
        - This method is SLOW. Unless you need the full memory info, better to use `get_memory`
        The returned values are in kB

        Returns
        -------
        dict
            Mapping of status field name -> value in kB (fields absent on the
            running kernel are omitted).
        int
            Sentinel -2 when the monitor is disabled, -1 on non-Linux platforms.
        """
        if not self._enable:
            return -2
        if platform.system() != "Linux":
            return -1
        pid_fn = "/proc/{}/status".format(self._pid)

        # Dict to collect all memory fields
        memory = {}
        with open(pid_fn, "r") as fn:
            for line in fn:
                for mem_field in self._status_fields:
                    m = self._status_regex[mem_field].match(line)
                    if m is not None:
                        memory[mem_field] = int(m.group(3))
                        break  # a status line holds at most one field
                # Stop reading the file as soon as all fields are collected
                if len(memory) == len(self._status_fields):
                    break

        return memory

    def get_memory(self) -> Union[Dict[str, int], int]:
        r"""
        - Parse /proc/<pid>/statm to get the most important memory fields
        - This is a fast implementation.
        - The method returns a dict with the fields:
          "size", "resident", "shared", "text", "lib", "data", "dt"
        - Check the documentation at the top for a mapping across the various fields
        The returned values are in kB

        Returns
        -------
        dict
            Mapping of statm field name -> value in kB.
        int
            Sentinel -2 when the monitor is disabled, -1 on non-Linux platforms.
        """
        if not self._enable:
            return -2
        if platform.system() != "Linux":
            return -1
        pid_fn = "/proc/{}/statm".format(self._pid)

        with open(pid_fn, "r") as fn:
            statm = fn.read()

        # statm reports counts in pages. Convert using the actual page size
        # instead of assuming 4 kB pages (which is wrong on e.g. 16 kB/64 kB
        # page ARM64 kernels).
        page_kb = os.sysconf("SC_PAGE_SIZE") // 1024
        data = [int(x) * page_kb for x in statm.split()]
        return {
            "size": data[0],
            "resident": data[1],
            "shared": data[2],
            "text": data[3],
            "lib": data[4],
            "data": data[5],
            "dt": data[6],
        }
Loading

0 comments on commit d526b79

Please sign in to comment.