
Generate: SinkCache can handle iterative prompts #27907

Merged
merged 9 commits into huggingface:main from sink_multiple_tokens on Dec 8, 2023

Conversation

gante (Member) commented Dec 8, 2023

What does this PR do?

Fixes the case where SinkCache is used in a chatbot, receiving new prompts after generating an answer. Fix developed with @tomaarsen.

Here's an example of a script that works after this PR:

from transformers import AutoTokenizer, SinkCache, AutoModelForCausalLM, TextStreamer
import torch
from datasets import load_dataset

# Loading the model & tokenizer
model_id = "HuggingFaceH4/zephyr-7b-beta"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Loading the prompts to simulate user interactions
prompt_dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
prompts = [prompt for prompts in prompt_dataset["prompt"] for prompt in prompts]

# Prepare generation settings
cache = SinkCache(window_length=1024, num_sink_tokens=4)
streamer = TextStreamer(tokenizer)

input_ids = torch.tensor([], device=model.device, dtype=torch.int)
for prompt in prompts:
    # Tokenize the prompt with the correct chat template
    chat = [{"role": "user", "content": prompt}]
    input_ids = torch.cat((input_ids, tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)), dim=1)
    # input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)

    # Perform the generation
    gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True, streamer=streamer)

    # input_ids = torch.cat((input_ids, gen_out), dim=1)
    input_ids = gen_out

    # If desired, decode the output from this prompt
    decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

@gante gante marked this pull request as ready for review December 8, 2023 13:46
@@ -209,8 +217,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if len(self.key_cache) <= layer_idx:
    return 0
cache_length = self.key_cache[layer_idx].shape[-2]
return min(cache_length, self.window_length - 1)
gante (Member Author)

This line was a bit of a hack; get_max_length means we no longer need it :)

get_seq_length now always does what the function name and the docstring say it does.
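
For context, a minimal sketch of how the two methods now divide responsibilities (an illustrative paraphrase, not the exact transformers source):

from typing import List, Optional

import torch


class SketchSinkCache:
    # Illustrative only: shows the contracts of get_seq_length / get_max_length.

    def __init__(self, window_length: int, num_sink_tokens: int):
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.key_cache: List[torch.Tensor] = []  # one tensor per layer, [batch, heads, seq_len, head_dim]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Exactly what the name says: how many tokens are currently cached for this layer.
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        # Hard capacity of the cache; an unbounded cache (e.g. DynamicCache) returns None instead.
        return self.window_length

Callers that need "how much cache is still usable" can now combine the two, instead of relying on get_seq_length being clamped.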

@@ -239,8 +250,8 @@ def update(
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
sin = cache_kwargs.get("sin")[: self.window_length]
gante (Member Author)

Slicing here is needed if more than one token is fed at once, after the cache is full.

kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
    kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]
gante (Member Author) commented Dec 8, 2023

The attention mask must be sliced to match the length of key_states, which might have been sliced in .update() (for fixed-length caches)

@@ -187,3 +189,37 @@ def test_sink_cache_hard(self):
        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))

    @require_auto_gptq
    def test_sink_cache_iterative_prompts(self):
gante (Member Author)

This test would fail on main
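
For illustration, a minimal sketch of what such a test can look like (my reconstruction based on the script in the PR header, not the actual test body; the checkpoint, prompts, and assertion are placeholders):

def test_sink_cache_iterative_prompts_sketch(self):
    # Placeholder checkpoint; the real test runs a quantized model behind @require_auto_gptq.
    model_id = "HuggingFaceH4/zephyr-7b-beta"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

    cache = SinkCache(window_length=128, num_sink_tokens=4)
    input_ids = torch.tensor([], device=model.device, dtype=torch.int)
    for prompt in ["Tell me a story about a fox.", "Now continue that story."]:
        chat = [{"role": "user", "content": prompt}]
        new_ids = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)
        input_ids = torch.cat((input_ids, new_ids), dim=1)
        prompt_length = input_ids.shape[1]
        gen_out = model.generate(
            input_ids, do_sample=False, max_new_tokens=128, past_key_values=cache, use_cache=True
        )
        # The second turn feeds several new tokens on top of an already-populated (and full) cache,
        # which is exactly the scenario that failed before this PR.
        self.assertGreater(gen_out.shape[1], prompt_length)
        input_ids = gen_out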

tomaarsen (Member) commented Dec 8, 2023

Some additional results @gante

image

In this experiment, the models are called with individual tokens exclusively. Using SinkCache makes the memory usage linear at a very low cost in perplexity.


Additionally, I'm running a test that calls Mistral-7B with this PR using multiple tokens at once. I took indices 0 to 63k of a book from pg19, kept only 20% of all indices, and then fed the model the tokens between subsequent indices, carrying the running cache across calls. The NLL loss is then converted to perplexity (a rough sketch of this chunked evaluation is given right below).
Note: we can't compare these numbers with the perplexities from the previous graph; we should only observe whether the model's perplexity eventually increases.
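
A rough sketch of the chunked evaluation (my reconstruction under stated assumptions; the checkpoint, dataset handling, and chunk boundaries are guesses, not Tom's exact script):

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# One long book from pg19: keep the first 63k tokens, then keep only ~20% of the indices
# as chunk boundaries, so the model is fed variable-sized chunks of tokens.
text = load_dataset("emozilla/pg19-test", split="test")[0]["text"]
tokens = tokenizer(text, return_tensors="pt").input_ids[0, :63_000]
boundaries = torch.randperm(len(tokens))[: len(tokens) // 5].sort().values.tolist()

cache = SinkCache(window_length=1024, num_sink_tokens=4)  # updated in place across calls
nlls = []
prev = 0
for boundary in boundaries:
    if boundary - prev < 2:  # need at least two tokens to score a next-token NLL
        prev = boundary
        continue
    chunk = tokens[prev:boundary].unsqueeze(0).to(model.device)
    with torch.no_grad():
        out = model(chunk, past_key_values=cache, use_cache=True)
    # Next-token NLL inside the chunk (the token straddling the chunk boundary is skipped for simplicity).
    logits = out.logits[:, :-1].float()
    targets = chunk[:, 1:]
    nll = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), targets.reshape(-1), reduction="none"
    )
    nlls.append(nll.cpu())
    prev = boundary

print("perplexity:", torch.cat(nlls).mean().exp().item())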

image

The same script crashes on main. In this test, the perplexity stays constant, which is good.
Edit: I have now continued with more tests:
image

  • transformers_multi_attn_sink_1024_4_pr-27907_left_cache: This PR, with SinkCache(1024, 4).
  • transformers_multi_attn_sink_1024_4_pr-27907_right_cache: This PR, with SinkCache(1024, 4) & the change @ArthurZucker proposed regarding slicing the cache from the right.
  • transformers_multi_b31905d1: main without a special SinkCache.

The perplexity diverges quite heavily between the SinkCache and non-SinkCache runs, which is not ideal. Perhaps this is indicative of some error/bug, or perhaps not; it's a bit hard to tell. Beyond that, the left- and right-slicing cache implementations behave identically (unless I made some measuring mistake), which is a bit odd. I don't have 100% confidence in this fix anymore, I'm afraid.

  • Tom Aarsen

ArthurZucker (Collaborator) left a comment

Thanks

@@ -239,8 +250,8 @@ def update(
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
sin = cache_kwargs.get("sin")[: self.window_length]
ArthurZucker (Collaborator)

Suggested change
sin = cache_kwargs.get("sin")[: self.window_length]
sin = cache_kwargs.get("sin")[-self.window_length:]

Would that not make more sense, since that's the side we slice the cache from?

tomaarsen (Member)

I'll run some tests for this

tomaarsen (Member)

Updated my previous message with my findings

gante (Member Author) commented Dec 8, 2023

@ArthurZucker @tomaarsen that's the beauty of sin and cos: the rerotation applied on the sink caches is based on the relative angles, and the relative angles are the same regardless of the side we slice 🙌

If you place a debugger, you can see that the sliced tensors are different, but the rerotation coefficients are exactly the same!
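
A quick numerical check of this (a standalone sketch of the RoPE angle-difference identity, not the SinkCache internals): the coefficient for re-rotating a key by k positions is cos/sin of k times the per-dimension frequency, which depends only on the shift k, so a left slice and a right slice of the cos/sin tables produce identical re-rotation coefficients.

import torch

# Plain RoPE cos/sin table (float64 just to keep the comparison clean).
dim, base, seq_len = 8, 10000.0, 4096
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
angles = torch.outer(torch.arange(seq_len, dtype=torch.float64), inv_freq)
cos, sin = angles.cos(), angles.sin()


def rerotation(cos, sin, shift):
    # Angle-difference identities: both products equal cos/sin of (shift * inv_freq),
    # i.e. they depend on the relative shift only, not on the absolute positions.
    return (
        cos[shift:] * cos[:-shift] + sin[shift:] * sin[:-shift],
        sin[shift:] * cos[:-shift] - cos[shift:] * sin[:-shift],
    )


window, shift = 1024, 7
left = rerotation(cos[:window], sin[:window], shift)      # slice from the start
right = rerotation(cos[-window:], sin[-window:], shift)   # slice from the end
print(torch.allclose(left[0], right[0]), torch.allclose(left[1], right[1]))  # True True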

Comment on lines 410 to 413
kv_max_length = past_key_value.get_max_length()
if kv_max_length is not None and kv_seq_len > kv_max_length:
    kv_seq_len = kv_max_length
attention_mask = attention_mask[:, :, :, -kv_seq_len:]
ArthurZucker (Collaborator)

Ouch, not a fan of that.
We have something similar for the Mistral sliding window, but I'm not in favor of keeping this. It should either go in the attention mask converter, which should slice it, or in the cache_kwargs, as it's clearly sink-cache and window-cache specific 😉

gante (Member Author)

The alternative would be to make prepare_inputs_for_generation prepare the sliced mask in advance (which is how TF/JAX do it). Going to check if it is feasible.

gante (Member Author) commented Dec 8, 2023

@ArthurZucker as per your suggestion, I've reworked the PR to avoid post hoc attention_mask slicing -- there is a new function to get the usable cache length, and that function is used to obtain kv_seq_len
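
Roughly, the new helper behaves like this (a sketch of the intended logic as a method on the cache classes; hedged, not necessarily the exact final implementation):

def get_usable_length(self, new_seq_length: int, layer_idx: int = 0) -> int:
    # Sketch: how many already-cached tokens remain usable once `new_seq_length`
    # new tokens are appended. Unbounded caches (get_max_length() is None) just
    # report the current length; fixed-size caches clamp it so that
    # previous + new never exceeds the maximum length.
    max_length = self.get_max_length()
    previous_seq_length = self.get_seq_length(layer_idx)
    if max_length is not None and previous_seq_length + new_seq_length > max_length:
        return max_length - new_seq_length
    return previous_seq_length

In the attention layers, kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) then replaces the previous get_seq_length call (see the diff further down), so the attention mask no longer needs post hoc slicing.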

@tomaarsen the rework seems to have resulted in a qualitative improvement in the results (e.g. see the test case), so I suspect that I've inadvertently fixed a bug 👀 Would you be able to rerun your benchmarks for SinkCache?

tomaarsen (Member)

@gante I get ValueError: Attention weights should be of size (1, 32, 5, 1027), but is torch.Size([1, 32, 5, 1024]) upon running my modified script. Do you get an error like this with the multi-step generation script from your PR?

gante (Member Author) commented Dec 8, 2023

@tomaarsen no, the script in the PR header runs endlessly without issues 🤔 LMK if you can find a reproducer

tomaarsen (Member)

I have the same, that script works fine. Hmmm

gante (Member Author) commented Dec 8, 2023

Got a reproducer: change max_new_tokens in the script above to 512 👀 having a look!

gante (Member Author) commented Dec 8, 2023

@tomaarsen should be fixed

tomaarsen (Member)

@gante Works great now!
image

Red is the baseline; I can only run it to about ~15k sequence length before my PC completely freezes.

gante (Member Author) commented Dec 8, 2023

🙌

ArthurZucker (Collaborator) left a comment

Thanks for the prompt update

# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
    attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
ArthurZucker (Collaborator)

okay better! thanks

@@ -268,7 +268,7 @@ def forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
ArthurZucker (Collaborator)

slice_length or window_length might be better names? But it's a nit, feel free to ignore.

@gante gante merged commit ce0bbd5 into huggingface:main Dec 8, 2023
21 checks passed
@gante gante deleted the sink_multiple_tokens branch December 8, 2023 20:02
fxmarty added a commit to fxmarty/transformers that referenced this pull request Dec 8, 2023
fxmarty added a commit that referenced this pull request Dec 8, 2023
YJHMITWEB commented Mar 4, 2024

@tomaarsen @gante @ArthurZucker Hi, thanks for the work. Do you know how we can reproduce the benchmark test with the latest transformers repo implementation? Are there any scripts? Really appreciated!

gante (Member Author) commented Mar 5, 2024

@YJHMITWEB you can adapt the following scripts to your needs. Note: Llama + attention sinks is probably not working on the latest version; Mistral should be. We are reworking how the cache works, so things may be a bit rough :) Worst case scenario, use v4.36.

Script to gather perplexity
"""
Adapted from https://github.com/mit-han-lab/streaming-llm

Note: Although this script measures latency, it is not optimized whatsoever!
The latency is only tracked to see the impact of speed over time.

Usage:

python benchmark/perplexity.py --experiment attention_sinks
python benchmark/perplexity.py --experiment transformers
python benchmark/perplexity.py --experiment windowed
"""


import argparse
import itertools
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

def compute_perplexity(
  model,
  tokenizer,
  dataset,
  experiment: str,
  output_dir: str = "outputs",
  data_column: str = "text",
  num_samples: int = 1,
  num_tokens: Optional[int] = None,
  overwrite: bool = False,
) -> None:
  output_dir = Path(output_dir)
  output_dir.mkdir(parents=True, exist_ok=True)
  output_file = output_dir / f"{experiment}.csv"

  if output_file.exists() and not overwrite:
      raise ValueError(
          f"The {output_file!r} output file already exists - if you really want to override it, then use `--overwrite`."
      )

  logs = defaultdict(list)
  loss_fn = CrossEntropyLoss(reduction="none")
  past_key_values = None
  num_processed_tokens = 0
  for text in itertools.islice(dataset, num_samples):
      encodings = tokenizer(text[data_column], return_tensors="pt")

      seq_len = encodings.input_ids.size(1)
      print(f"sequence length: {seq_len}")
      pbar = tqdm(range(0, seq_len - 1))

      for idx in pbar:
          start_t = time.time()
          input_ids = encodings.input_ids[:, idx : idx + 1].to(model.device)
          with torch.no_grad():
              outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
              logits = outputs.logits.view(-1, model.config.vocab_size)
              past_key_values = outputs.past_key_values
              label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
              neg_log_likelihood = loss_fn(logits, label)
              perplexity = neg_log_likelihood.exp()
          pbar.set_description(f"nll: {neg_log_likelihood.item():>5.2f}, ppl: {perplexity.item():>8.2f}")

          # Store data and save every 10 tokens
          logs["input_length"].append(idx + 1)
          logs["nll"].append(neg_log_likelihood.item())
          logs["ppl"].append(perplexity.item())
          logs["overall_ppl"].append(torch.tensor(logs["nll"]).mean().exp().item())
          logs["cuda_vram_allocated"].append(torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024)  # in GB
          logs["latency"].append(time.time() - start_t)
          if num_processed_tokens % 10 == 0:
              try:
                  pd.DataFrame(logs).to_csv(output_file, index=False)
              except KeyboardInterrupt as ex:
                  # If there's a Keyboard Interrupt, still write the file, and then stop
                  pd.DataFrame(logs).to_csv(output_file, index=False)
                  raise ex

          num_processed_tokens += 1
          if num_tokens and num_processed_tokens >= num_tokens:
              return


def main():
  parser = argparse.ArgumentParser()
  # How to call this experiment?
  parser.add_argument(
      "--experiment", type=str, default="main"
  )

  # Model args
  # parser.add_argument("--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
  parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-hf")
  parser.add_argument("--revision", type=str, default="main")
  parser.add_argument("--trust_remote_code", action="store_true")

  # Dataset args
  parser.add_argument("--dataset_name", type=str, default="emozilla/pg19-test")
  parser.add_argument("--data_column", type=str, default="text")
  parser.add_argument("--task", type=str, default=None)
  parser.add_argument("--split", type=str, default="test", choices=["validation", "test"])
  # parser.add_argument("--num_samples", type=int, default=1)
  parser.add_argument("--num_tokens", type=int, default=5000)
  parser.add_argument("--dtype", type=str, default="fp16")

  # Where to log
  parser.add_argument("--output_dir", type=str, default="/home/joao/joao_scripts/perplexity/outputs")
  parser.add_argument("--overwrite", action="store_true")

  args = parser.parse_args()

  if args.dtype == "fp16":
      dtype = torch.float16
  elif args.dtype == "fp32":
      dtype = torch.float32
  elif args.dtype == "bf16":
      dtype = torch.bfloat16
  else:
      raise ValueError(f"Unknown dtype: {args.dtype}")

  model = AutoModelForCausalLM.from_pretrained(
      args.model_name_or_path,
      revision=args.revision,
      trust_remote_code=bool(args.trust_remote_code),
      attn_implementation="eager",
      torch_dtype=dtype,
      device_map="auto",
  )
  model.eval()
  tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code))

  # Set up the dataset
  dataset = load_dataset(args.dataset_name, args.task, split=args.split, streaming=True)

  compute_perplexity(
      model,
      tokenizer,
      dataset,
      args.experiment,
      output_dir=args.output_dir,
      data_column=args.data_column,
      num_samples=1,  # <- No support for more than one instance now
      num_tokens=args.num_tokens,
      overwrite=args.overwrite,
  )


if __name__ == "__main__":
  main()
Script to plot perplexity (and other metrics)
"""
First run `perplexity.py` to generate one or more `csv` files.
This script can plot those csv files.

Usage:
python benchmark/plot_perplexity.py
python benchmark/plot_perplexity.py --features perplexity latency --title "Log perplexity & latency of Llama 2 7B as a function of input lengths"
"""

import argparse
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

FEATURE_DF_MAP = {
  "perplexity": "overall_ppl",
  "vram": "cuda_vram_allocated",
  "latency": "latency",
}
FEATURE_STYLE_MAP = {
  "perplexity": "-",
  "vram": "--",
  "latency": ":",
}
FEATURE_LABEL_MAP = {
  "perplexity": "Perplexity (log), lower is better",
  "vram": "CUDA VRAM Usage (GB), lower is better",
  "latency": "Time per token (sec), lower is better",
}


def plot(
  features: List[str],
  output_dir: str = "outputs",
  title: Optional[str] = None,
  perplexity_limit: Optional[float] = None,
  skip_first: int = 100,
):
  output_dir = Path(output_dir)

  fig, ax = plt.subplots()
  ax.set_xlabel("Input Sequence Length")

  for feature_i, feature in enumerate(features):
      # If we already plotted on this ax, make a new one
      if feature_i:
          ax = ax.twinx()

      for file in output_dir.glob("*.csv"):
          experiment = file.stem
          df = pd.read_csv(file)
          X = df["input_length"][skip_first:]
          Y = df[FEATURE_DF_MAP[feature]][skip_first:]
          if feature == "perplexity":
              Y = np.log(Y)
          if feature == "latency":
              poly = np.polyfit(X, Y, 20)
              poly_y = np.poly1d(poly)(X)
              ax.plot(X, poly_y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")
          else:
              ax.plot(X, Y, FEATURE_STYLE_MAP[feature], label=f"{experiment} {feature}")

      ax.set_ylabel(FEATURE_LABEL_MAP[feature])
      if perplexity_limit and feature == "perplexity":
          ax.set_ylim(top=min(ax.get_ylim()[1], perplexity_limit))

      ax.legend(loc=[1, 2, 7][feature_i])  # upper right, upper left, center right

  ax.set_title(title.replace("\\n", "\n") if title else "Log perplexity as a function of input lengths")
  fig.tight_layout()

  return fig


def main():
  parser = argparse.ArgumentParser()
  # Where csv files have been logged
  parser.add_argument("--output_dir", type=str, default="/home/joao/joao_scripts/perplexity/outputs")
  parser.add_argument(
      "--features", choices=["perplexity", "vram", "latency"], nargs="+", default=["perplexity", "vram"]
  )
  parser.add_argument("--title", type=str, default=None)
  parser.add_argument("--log_perplexity_limit", type=float, default=5.0)
  # Perplexity starts a bit unstable, so we skip the start
  parser.add_argument("--skip_first", type=int, default=100)

  args = parser.parse_args()

  figure = plot(
      args.features,
      output_dir=args.output_dir,
      title=args.title,
      perplexity_limit=args.log_perplexity_limit,
      skip_first=args.skip_first,
  )

  # Add your own code here if you'd like to change the figure
  features = "_".join(args.features)
  save_path = f"/home/joao/joao_scripts/perplexity/outputs/plot_{features}.png"
  plt.savefig(save_path, dpi=600)
  print(f"plot saved to {save_path}")


if __name__ == "__main__":
  main()

Credits to @tomaarsen, who shared these with me :)

@YJHMITWEB

@gante @tomaarsen Thanks!
