diff --git a/index.html b/index.html
index f76ebb9..5d14a16 100644
--- a/index.html
+++ b/index.html
@@ -10,7 +10,6 @@
body, html {
margin: 0;
- color: white;
}
diff --git a/src/clone-node.ts b/src/clone-node.ts
index 0a28c0f..f8ae30c 100644
--- a/src/clone-node.ts
+++ b/src/clone-node.ts
@@ -84,10 +84,6 @@ export async function cloneNode(
&& (isHTMLElementNode(node) || isSVGElementNode(node))) {
const computedStyle = ownerWindow.getComputedStyle(node)
- if (computedStyle.display === 'none') {
- return ownerDocument.createComment(node.tagName.toLowerCase())
- }
-
const cloned = await cloneElement(node, context)
const clonedStyle = cloned.style
diff --git a/src/copy-css-styles.ts b/src/copy-css-styles.ts
index cdfd1d0..553949e 100644
--- a/src/copy-css-styles.ts
+++ b/src/copy-css-styles.ts
@@ -17,8 +17,8 @@ export function copyCssStyles(
context: Context,
) {
const clonedStyle = cloned.style
- const defaultStyle = getDefaultStyle(node.nodeName, null, context)
- const diffStyle = getDiffStyle(computedStyle, defaultStyle, node)
+ const defaultStyle = getDefaultStyle(node, null, context)
+ const diffStyle = getDiffStyle(computedStyle, defaultStyle)
for (const [name, [value, priority]] of Object.entries(diffStyle)) {
if (ignoredStyle.includes(name)) continue
diff --git a/src/copy-pseudo-class.ts b/src/copy-pseudo-class.ts
index 7dff682..e34d8d2 100644
--- a/src/copy-pseudo-class.ts
+++ b/src/copy-pseudo-class.ts
@@ -43,7 +43,7 @@ export function copyPseudoClass(
if (!content || content === 'none') return
const klasses = [uuid()]
- const defaultStyle = getDefaultStyle(node.nodeName, pseudoClass, context)
+ const defaultStyle = getDefaultStyle(node, pseudoClass, context)
const cloneStyle = [
`content: '${ content.replace(/'|"/g, '') }';`,
]
diff --git a/src/get-default-style.ts b/src/get-default-style.ts
index 4790abe..76be746 100644
--- a/src/get-default-style.ts
+++ b/src/get-default-style.ts
@@ -1,11 +1,42 @@
-import { uuid } from './utils'
+import { isSVGElementNode, uuid } from './utils'
import type { Context } from './context'
-export function getDefaultStyle(nodeName: string, pseudoElement: string | null, context: Context) {
- nodeName = nodeName.toLowerCase()
+const ignoredStyles = [
+ 'width',
+ 'height',
+]
+
+const includedAttributes = [
+ 'stroke',
+ 'fill',
+]
+
+export function getDefaultStyle(
+ node: HTMLElement | SVGElement,
+ pseudoElement: string | null,
+ context: Context,
+) {
const { defaultComputedStyles, ownerDocument } = context
- const key = `${ nodeName }${ pseudoElement ?? '' }`
+
+ const nodeName = node.nodeName.toLowerCase()
+ const isSvgNode = isSVGElementNode(node) && nodeName !== 'svg'
+ const attributes = isSvgNode
+ ? includedAttributes
+ .map(name => [name, node.getAttribute(name)])
+ .filter(([, value]) => value !== null)
+ : []
+
+ const key = [
+ isSvgNode && 'svg',
+ nodeName,
+ attributes.map((name, value) => `${ name }=${ value }`).join(','),
+ pseudoElement,
+ ]
+ .filter(Boolean)
+ .join(':')
+
if (defaultComputedStyles.has(key)) return defaultComputedStyles.get(key)!
+
let sandbox = context.sandbox
if (!sandbox) {
if (ownerDocument) {
@@ -21,24 +52,35 @@ export function getDefaultStyle(nodeName: string, pseudoElement: string | null,
}
}
if (!sandbox) return {}
+
const sandboxWindow = sandbox.contentWindow
if (!sandboxWindow) return {}
const sandboxDocument = sandboxWindow.document
- const el = sandboxDocument.createElement(nodeName)
- sandboxDocument.body.appendChild(el)
- // Ensure that there is some content, so properties like margin are applied
+
+ let root: HTMLElement | SVGSVGElement
+ let el: Element
+ if (isSvgNode) {
+ root = sandboxDocument.createElementNS('http://www.w3.org/2000/svg', 'svg')
+ el = root.ownerDocument.createElementNS(root.namespaceURI, nodeName)
+ attributes.forEach(([name, value]) => {
+ el.setAttributeNS(null, name!, value!)
+ })
+ root.appendChild(el)
+ } else {
+ root = el = sandboxDocument.createElement(nodeName)
+ }
el.textContent = ' '
- const style = sandboxWindow.getComputedStyle(el, pseudoElement)
+ sandboxDocument.body.appendChild(root)
+ const computedStyle = sandboxWindow.getComputedStyle(el, pseudoElement)
const styles: Record = {}
- for (let i = style.length - 1; i >= 0; i--) {
- const name = style.item(i)
- if (name === 'width' || name === 'height') {
- styles[name] = 'auto'
- } else {
- styles[name] = style.getPropertyValue(name)
- }
+ for (let len = computedStyle.length, i = 0; i < len; i++) {
+ const name = computedStyle.item(i)
+ if (ignoredStyles.includes(name)) continue
+ styles[name] = computedStyle.getPropertyValue(name)
}
- sandboxDocument.body.removeChild(el)
+ sandboxDocument.body.removeChild(root)
+
defaultComputedStyles.set(key, styles)
+
return styles
}
diff --git a/src/get-diff-style.ts b/src/get-diff-style.ts
index bacdc70..0e0e6b3 100644
--- a/src/get-diff-style.ts
+++ b/src/get-diff-style.ts
@@ -6,7 +6,6 @@ const getPrefix = (name: string) => name
export function getDiffStyle(
style: CSSStyleDeclaration,
defaultStyle: Record,
- node?: HTMLElement | SVGElement,
) {
const diffStyle: Record = {}
const diffStylePrefixs: string[] = []
@@ -23,11 +22,7 @@ export function getDiffStyle(
prefixTree[prefix][name] = [value, priority]
}
- if (
- defaultStyle[name] === value
- && !priority
- && (node && !node.getAttribute(name))
- ) continue
+ if (defaultStyle[name] === value && !priority) continue
if (prefix) {
diffStylePrefixs.push(prefix)
diff --git a/test/fixtures/svg.color.html b/test/fixtures/svg.color.html
index 1c5ea4d..6ae7994 100644
--- a/test/fixtures/svg.color.html
+++ b/test/fixtures/svg.color.html
@@ -25,9 +25,9 @@
height="120"
requiredExtensions="http://www.w3.org/1999/xhtml"
>
-
+
diff --git a/test/fixtures/svg.symbol.html b/test/fixtures/svg.symbol.html
new file mode 100644
index 0000000..de27480
--- /dev/null
+++ b/test/fixtures/svg.symbol.html
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
+
diff --git a/test/fixtures/svg.symbol.png b/test/fixtures/svg.symbol.png
new file mode 100644
index 0000000..ed4a8f0
Binary files /dev/null and b/test/fixtures/svg.symbol.png differ