Skip to content

Commit

Permalink
Merge pull request #18 from webcoderz/main
Browse files Browse the repository at this point in the history
fixing civitai nodes to accept civitai api key
  • Loading branch information
WASasquatch authored Jun 9, 2024
2 parents a8a0062 + 293ca26 commit e839c45
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
18 changes: 10 additions & 8 deletions CivitAI_Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class CivitAI_Model:
debug_response = False
warning = False

def __init__(self, model_id, save_path, model_paths, model_types=[], model_version=None, download_chunks=None, max_download_retries=None, warning=True, debug_response=False):
def __init__(self, model_id, save_path, model_paths, model_types=[], token=None, model_version=None, download_chunks=None, max_download_retries=None, warning=True, debug_response=False):
self.model_id = model_id
self.version = model_version
self.type = None
Expand All @@ -53,6 +53,9 @@ def __init__(self, model_id, save_path, model_paths, model_types=[], model_versi
if download_chunks:
self.num_chunks = int(download_chunks)

if token:
self.token = token

if max_download_retries:
self.max_retries = int(max_download_retries)

Expand Down Expand Up @@ -90,13 +93,13 @@ def details(self):
if file_id and file_id == file_version:
self.name = name
self.name_friendly = file.get('name_friendly')
self.download_url = file.get('downloadUrl')
self.download_url = f"{file.get('downloadUrl')}?token={self.token}"
self.trained_words = file.get('trained_words')
self.file_details = file
self.file_id = file_version
self.model_id = self.model_id
self.version = int(file.get('id'))
self.type = filget.get('model_type', 'Model')
self.type = file.get('model_type', 'Model')
self.file_size = file.get('sizeKB', 0) * 1024
hashes = file.get('hashes')
if hashes:
Expand Down Expand Up @@ -154,7 +157,7 @@ def details(self):
for file in files:
download_url = file.get('downloadUrl')
if download_url == model_download_url:
self.download_url = download_url
self.download_url = download_url + f"?token={self.token}"
self.file_details = file
self.file_id = file.get('id')
self.name = file.get('name')
Expand All @@ -175,7 +178,7 @@ def details(self):
for file in files:
download_url = file.get('downloadUrl')
if download_url == model_download_url:
self.download_url = download_url
self.download_url = download_url + f"?token={self.token}"
self.file_details = file
self.file_id = file.get('id')
self.name = file.get('name')
Expand Down Expand Up @@ -307,11 +310,11 @@ def get_total_file_size(url):

response = requests.head(self.download_url)
total_file_size = total_file_size = get_total_file_size(self.download_url)

response = requests.get(self.download_url, stream=True)
if response.status_code != requests.codes.ok:
raise Exception(f"{ERR_PREFIX}Failed to download {self.type} file from CivitAI. Status code: {response.status_code}")

with open(save_path, 'wb') as file:
file.seek(total_file_size - 1)
file.write(b'\0')
Expand All @@ -333,7 +336,6 @@ def get_total_file_size(url):
future.result()

total_pbar.close()

model_sha256 = CivitAI_Model.calculate_sha256(save_path)
if model_sha256 == self.file_sha256:
print(f"{MSG_PREFIX}Loading {self.type}: {self.name} (https://civitai.com/models/{self.model_id}/?modelVersionId={self.version})")
Expand Down
9 changes: 5 additions & 4 deletions civitai_checkpoint_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def INPUT_TYPES(cls):
"ckpt_name": (checkpoints,),
},
"optional": {
"api_key": ("STRING", {"default": None, "multiline": False}),
"download_chunks": ("INT", {"default": 4, "min": 1, "max": 12, "step": 1}),
"download_path": (list(checkpoint_paths.keys()),),
},
Expand All @@ -53,7 +54,7 @@ def INPUT_TYPES(cls):

CATEGORY = "CivitAI/Loaders"

def load_checkpoint(self, ckpt_air, ckpt_name, download_chunks=None, download_path=None, extra_pnginfo=None):
def load_checkpoint(self, ckpt_air, ckpt_name, api_key=None, download_chunks=None, download_path=None, extra_pnginfo=None):

if extra_pnginfo:
if not extra_pnginfo['workflow']['extra'].__contains__('ckpt_airs'):
Expand Down Expand Up @@ -82,7 +83,7 @@ def load_checkpoint(self, ckpt_air, ckpt_name, download_chunks=None, download_pa
else:
download_path = CHECKPOINTS[0]

civitai_model = CivitAI_Model(model_id=ckpt_id, model_version=version_id, model_types=["Checkpoint",], save_path=download_path, model_paths=CHECKPOINTS, download_chunks=download_chunks)
civitai_model = CivitAI_Model(model_id=ckpt_id, model_version=version_id, model_types=["Checkpoint",], token=api_key, save_path=download_path, model_paths=CHECKPOINTS, download_chunks=download_chunks)

if not civitai_model.download():
return None, None, None
Expand All @@ -107,5 +108,5 @@ def load_checkpoint(self, ckpt_air, ckpt_name, download_chunks=None, download_pa
print(f"{MSG_PREFIX}Loading checkpoint from disk: {ckpt_path}")

out = self.ckpt_loader.load_checkpoint(ckpt_name=ckpt_name)

return out[0], out[1], out[2], { "extra_pnginfo": extra_pnginfo }
return out[0], out[1], out[2], { "extra_pnginfo": extra_pnginfo }
7 changes: 4 additions & 3 deletions civitai_lora_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def INPUT_TYPES(cls):

},
"optional": {
"api_key": ("STRING", {"default": None, "multiline": False}),
"download_chunks": ("INT", {"default": 4, "min": 1, "max": 12, "step": 1}),
"download_path": (list(lora_paths),),
},
Expand All @@ -57,7 +58,7 @@ def INPUT_TYPES(cls):

CATEGORY = "CivitAI/Loaders"

def load_lora(self, model, clip, lora_air, lora_name, strength_model, strength_clip, download_chunks=None, download_path=None, extra_pnginfo=None):
def load_lora(self, model, clip, lora_air, lora_name, strength_model, strength_clip, api_key=None, download_chunks=None, download_path=None, extra_pnginfo=None):

if extra_pnginfo:
if not extra_pnginfo['workflow']['extra'].__contains__('lora_airs'):
Expand Down Expand Up @@ -86,7 +87,7 @@ def load_lora(self, model, clip, lora_air, lora_name, strength_model, strength_c
else:
download_path = LORAS[0]

civitai_model = CivitAI_Model(model_id=lora_id, model_version=version_id, model_types=["LORA", "LoCon"], save_path=download_path, model_paths=LORAS, download_chunks=download_chunks)
civitai_model = CivitAI_Model(model_id=lora_id, model_version=version_id, model_types=["LORA", "LoCon"], token=api_key, save_path=download_path, model_paths=LORAS, download_chunks=download_chunks)

if not civitai_model.download():
return model, clip
Expand All @@ -112,4 +113,4 @@ def load_lora(self, model, clip, lora_air, lora_name, strength_model, strength_c

model_lora, clip_lora = self.lora_loader.load_lora(model, clip, lora_name, strength_model, strength_clip)

return model_lora, clip_lora, { "extra_pnginfo": extra_pnginfo }
return model_lora, clip_lora, { "extra_pnginfo": extra_pnginfo }

0 comments on commit e839c45

Please sign in to comment.