feat:whitelist(#584,#599)
Guovin committed Dec 12, 2024
1 parent 75544ac commit 2979d7d
Showing 9 changed files with 119 additions and 54 deletions.
4 changes: 4 additions & 0 deletions config/whitelist.txt
@@ -0,0 +1,4 @@
# 这是接口或订阅源的白名单,白名单内的接口或订阅源获取的接口将不会参与测速,优先排序至结果最前。
# 填写频道名称会直接保留该记录至最终结果,如:CCTV-1,接口地址,只填写接口地址则对所有频道生效,多条记录换行输入。
# This is the whitelist for interfaces and subscription sources. Interfaces in the whitelist, and interfaces fetched from whitelisted subscription sources, skip speed testing and are sorted to the front of the results.
# Filling in a channel name keeps that record in the final result as-is, e.g. CCTV-1,interface address; filling in only an interface address applies it to all channels. Enter multiple records on separate lines.
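
For illustration, a whitelist.txt that follows these rules might look like the following (the addresses are placeholders):

CCTV-1,http://example.com/cctv1.m3u8
http://example.com/backup/stream.m3u8

The first record pins a specific interface to CCTV-1; the second, a bare interface address, applies to all channels.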
4 changes: 3 additions & 1 deletion main.py
@@ -30,6 +30,7 @@
format_interval,
check_ipv6_support,
resource_path,
get_whitelist_urls
)


@@ -71,8 +72,9 @@ async def visit_page(self, channel_names=None):
if config.open_method[setting]:
if setting == "subscribe":
subscribe_urls = config.subscribe_urls
whitelist_urls = get_whitelist_urls()
task = asyncio.create_task(
task_func(subscribe_urls, callback=self.update_progress)
task_func(subscribe_urls, whitelist=whitelist_urls, callback=self.update_progress)
)
elif setting == "hotel_foodie" or setting == "hotel_fofa":
task = asyncio.create_task(task_func(callback=self.update_progress))
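
A minimal sketch of the new flow introduced above, with fetch_subscribe as a hypothetical stand-in for the surrounding update task:

from updates.subscribe import get_channels_by_subscribe_urls
from utils.tools import get_whitelist_urls

async def fetch_subscribe(subscribe_urls, update_progress):
    # urls listed in config/whitelist.txt are handed to the subscribe task
    # and bypass speed testing downstream
    whitelist_urls = get_whitelist_urls()
    return await get_channels_by_subscribe_urls(
        subscribe_urls, whitelist=whitelist_urls, callback=update_progress
    )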
43 changes: 21 additions & 22 deletions updates/multicast/update_tmp.py
@@ -1,20 +1,19 @@
import sys
import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from updates.subscribe import get_channels_by_subscribe_urls
from driver.utils import get_soup_driver
from utils.config import config
import utils.constants as constants
from utils.channel import format_channel_name, get_name_url
from utils.tools import get_pbar_remaining, resource_path
from utils.channel import format_channel_name
from utils.tools import get_pbar_remaining, resource_path, get_name_url
import json

# import asyncio
from requests import Session
from collections import defaultdict
import re
from time import time
from tqdm import tqdm

@@ -40,7 +39,7 @@ def get_region_urls_from_IPTV_Multicast_source():
region_url[name]["移动"] = mobile
region_url[name]["电信"] = telecom
with open(
resource_path("updates/multicast/multicast_map.json"), "w", encoding="utf-8"
resource_path("updates/multicast/multicast_map.json"), "w", encoding="utf-8"
) as f:
json.dump(region_url, f, ensure_ascii=False, indent=4)

@@ -51,7 +50,7 @@ def get_multicast_urls_info_from_region_list():
"""
urls_info = []
with open(
resource_path("updates/multicast/multicast_map.json"), "r", encoding="utf-8"
resource_path("updates/multicast/multicast_map.json"), "r", encoding="utf-8"
) as f:
region_url = json.load(f)
urls_info = [
@@ -71,9 +70,9 @@ async def get_multicast_region_result():
multicast_region_urls_info, multicast=True
)
with open(
resource_path("updates/multicast/multicast_region_result.json"),
"w",
encoding="utf-8",
resource_path("updates/multicast/multicast_region_result.json"),
"w",
encoding="utf-8",
) as f:
json.dump(multicast_result, f, ensure_ascii=False, indent=4)

@@ -83,7 +82,7 @@ def get_multicast_region_type_result_txt():
Get multicast region type result txt
"""
with open(
resource_path("updates/multicast/multicast_map.json"), "r", encoding="utf-8"
resource_path("updates/multicast/multicast_map.json"), "r", encoding="utf-8"
) as f:
region_url = json.load(f)
session = Session()
@@ -92,9 +91,9 @@ def get_multicast_region_type_result_txt():
response = session.get(url)
content = response.text
with open(
resource_path(f"config/rtp/{region}_{type}.txt"),
"w",
encoding="utf-8",
resource_path(f"config/rtp/{region}_{type}.txt"),
"w",
encoding="utf-8",
) as f:
f.write(content)

@@ -109,11 +108,11 @@ def get_multicast_region_result_by_rtp_txt(callback=None):
filename.rsplit(".", 1)[0]
for filename in os.listdir(rtp_path)
if filename.endswith(".txt")
and "_" in filename
and (
filename.rsplit(".", 1)[0].partition("_")[0] in config_region_list
or config_region_list & {"all", "ALL", "全部"}
)
and "_" in filename
and (
filename.rsplit(".", 1)[0].partition("_")[0] in config_region_list
or config_region_list & {"all", "ALL", "全部"}
)
]

total_files = len(rtp_file_list)
@@ -127,7 +126,7 @@ def get_multicast_region_result_by_rtp_txt(callback=None):
for filename in rtp_file_list:
region, _, type = filename.partition("_")
with open(
os.path.join(rtp_path, f"{filename}.txt"), "r", encoding="utf-8"
os.path.join(rtp_path, f"{filename}.txt"), "r", encoding="utf-8"
) as f:
for line in f:
name_url = get_name_url(line, pattern=constants.rtp_pattern)
@@ -146,9 +145,9 @@ def get_multicast_region_result_by_rtp_txt(callback=None):
)

with open(
resource_path("updates/multicast/multicast_region_result.json"),
"w",
encoding="utf-8",
resource_path("updates/multicast/multicast_region_result.json"),
"w",
encoding="utf-8",
) as f:
json.dump(multicast_result, f, ensure_ascii=False, indent=4)

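
The substantive change in this file is that get_name_url now comes from utils.tools rather than utils.channel (the rest of the diff is re-indentation). For a single rtp list line it behaves roughly as below; the regex is a simplified stand-in for constants.rtp_pattern:

from utils.tools import get_name_url

line = "CCTV-1,rtp://239.3.1.241:8000"
pattern = r"(.*?),(rtp://.*)"  # stand-in for constants.rtp_pattern
print(get_name_url(line, pattern=pattern))
# [{'name': 'CCTV-1', 'url': 'rtp://239.3.1.241:8000'}]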
30 changes: 18 additions & 12 deletions updates/subscribe/request.py
@@ -1,28 +1,31 @@
import utils.constants as constants
from tqdm.asyncio import tqdm_asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from time import time

from requests import Session, exceptions
from utils.config import config
from tqdm.asyncio import tqdm_asyncio

import utils.constants as constants
from utils.channel import format_channel_name
from utils.config import config
from utils.retry import retry_func
from utils.channel import get_name_url, format_channel_name
from utils.tools import (
merge_objects,
get_pbar_remaining,
format_url_with_cache,
add_url_info,
get_name_url
)
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict


async def get_channels_by_subscribe_urls(
urls,
multicast=False,
hotel=False,
retry=True,
error_print=True,
callback=None,
urls,
multicast=False,
hotel=False,
retry=True,
error_print=True,
whitelist=None,
callback=None,
):
"""
Get the channels by subscribe urls
@@ -53,6 +56,7 @@ def process_subscribe_channels(subscribe_info):
else:
subscribe_url = subscribe_info
channels = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
in_whitelist = whitelist and (subscribe_url in whitelist)
try:
response = None
try:
@@ -95,6 +99,8 @@ def process_subscribe_channels(subscribe_info):
else f"{subscribe_name}"
)
)
if in_whitelist:
info = "!" + info
url = add_url_info(url, info)
url = format_url_with_cache(
url, cache=subscribe_url if (multicast or hotel) else None
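
When a subscribe url is whitelisted, every interface parsed from it gets "!" prepended to its info segment before the info is attached to the url; process_nested_dict later matches that marker (force_str="!") to keep such urls out of speed testing. A rough sketch of the marking, assuming add_url_info simply appends "$info":

def add_url_info(url, info):
    # simplified stand-in for utils.tools.add_url_info
    return f"{url}${info}" if info else url

whitelist = ["http://example.com/sub.m3u"]  # hypothetical whitelist entry
subscribe_url = "http://example.com/sub.m3u"
in_whitelist = bool(whitelist) and subscribe_url in whitelist

info = "sub-name"  # hypothetical subscribe name
if in_whitelist:
    info = "!" + info  # "!" marks the interface as whitelisted downstream
print(add_url_info("http://example.com/live.m3u8", info))
# http://example.com/live.m3u8$!sub-name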
35 changes: 17 additions & 18 deletions utils/channel.py
@@ -17,31 +17,20 @@
sort_urls,
)
from utils.tools import (
get_name_url,
check_url_by_patterns,
get_total_urls,
process_nested_dict,
add_url_info,
remove_cache_info,
resource_path,
write_content_into_txt,
get_whitelist_urls,
get_whitelist_name_urls
)


def get_name_url(content, pattern, multiline=False, check_url=True):
"""
Get channel name and url from content
"""
flag = re.MULTILINE if multiline else 0
matches = re.findall(pattern, content, flag)
channels = [
{"name": match[0].strip(), "url": match[1].strip()}
for match in matches
if (check_url and match[1].strip()) or not check_url
]
return channels


def get_channel_data_from_file(channels, file, use_old):
def get_channel_data_from_file(channels, file, use_old, whitelist):
"""
Get the channel data from the file
"""
@@ -61,6 +50,9 @@ def get_channel_data_from_file(channels, file, use_old):
category_dict = channels[current_category]
if name not in category_dict:
category_dict[name] = []
if name in whitelist:
for whitelist_url in whitelist[name]:
category_dict[name].append((whitelist_url, None, None, "important"))
if use_old and url:
info = url.partition("$")[2]
origin = None
@@ -78,11 +70,15 @@ def get_channel_items():
"""
user_source_file = resource_path(config.source_file)
channels = defaultdict(lambda: defaultdict(list))
whitelist = get_whitelist_name_urls()
whitelist_len = len(list(whitelist.keys()))
if whitelist_len:
print(f"Found {whitelist_len} channel in whitelist")

if os.path.exists(user_source_file):
with open(user_source_file, "r", encoding="utf-8") as file:
channels = get_channel_data_from_file(
channels, file, config.open_use_old_result
channels, file, config.open_use_old_result, whitelist
)

if config.open_use_old_result:
@@ -551,7 +547,10 @@ async def process_sort_channel_list(data, ipv6=False, callback=None):
"""
ipv6_proxy = None if (not config.open_ipv6 or ipv6) else constants.ipv6_proxy
need_sort_data = copy.deepcopy(data)
process_nested_dict(need_sort_data, seen=set(), flag=r"cache:(.*)", force_str="!")
whitelist_urls = get_whitelist_urls()
if whitelist_urls:
print(f"Found {len(whitelist_urls)} whitelist urls")
process_nested_dict(need_sort_data, seen=set(whitelist_urls), flag=r"cache:(.*)", force_str="!")
result = {}
semaphore = asyncio.Semaphore(10)

Expand All @@ -574,7 +573,7 @@ async def limited_get_speed(info, ipv6_proxy, callback):
await asyncio.gather(*tasks)
for cate, obj in data.items():
for name, info_list in obj.items():
info_list = sort_urls(name, info_list)
info_list = sort_urls(name, info_list, whitelist=whitelist_urls)
append_data_to_info_data(
result,
cate,
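
Two hooks land in this file: get_channel_data_from_file injects each whitelisted channel's urls as (url, date, resolution, origin) tuples with origin="important", and process_sort_channel_list seeds the seen set of process_nested_dict with the whitelist urls so they skip the speed test entirely. A sketch of the injection, using hypothetical sample data:

from collections import defaultdict

channels = defaultdict(lambda: defaultdict(list))
# shape returned by get_whitelist_name_urls(); the entry is hypothetical
whitelist = {"CCTV-1": ["http://example.com/cctv1.m3u8"]}

current_category, name = "Sample Category", "CCTV-1"
if name in whitelist:
    for whitelist_url in whitelist[name]:
        # origin "important" makes sort_urls rank the entry first
        channels[current_category][name].append((whitelist_url, None, None, "important"))
print(channels["Sample Category"]["CCTV-1"])
# [('http://example.com/cctv1.m3u8', None, None, 'important')]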
1 change: 1 addition & 0 deletions utils/config.py
@@ -37,6 +37,7 @@ def get_resolution_value(resolution_str):
class ConfigManager:

def __init__(self):
self.config = None
self.load()

def __getattr__(self, name, *args, **kwargs):
2 changes: 2 additions & 0 deletions utils/constants.py
@@ -4,6 +4,8 @@

output_path = "output"

whitelist_path = os.path.join(config_path, "whitelist.txt")

result_path = os.path.join(output_path, "result_new.txt")

cache_path = os.path.join(output_path, "cache.pkl")
4 changes: 3 additions & 1 deletion utils/speed.py
@@ -204,14 +204,16 @@ async def get_speed(url, ipv6_proxy=None, callback=None):
callback()


def sort_urls(name, data, logger=None):
def sort_urls(name, data, logger=None, whitelist=None):
"""
Sort the urls with info
"""
filter_data = []
if logger is None:
logger = get_logger(constants.sort_log_path, level=INFO, init=True)
for url, date, resolution, origin in data:
if whitelist and url in whitelist:
origin = "important"
result = {
"url": remove_cache_info(url),
"date": date,
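
sort_urls now promotes whitelisted urls by overriding their origin before the entries are scored. A minimal sketch of just that check, assuming (url, date, resolution, origin) tuples as in the diff:

def promote_whitelisted(data, whitelist=None):
    # mirrors the new check at the top of sort_urls' loop
    promoted = []
    for url, date, resolution, origin in data:
        if whitelist and url in whitelist:
            origin = "important"  # whitelisted urls sort to the front
        promoted.append((url, date, resolution, origin))
    return promoted

data = [("http://example.com/a.m3u8", None, None, "subscribe")]
print(promote_whitelisted(data, whitelist=["http://example.com/a.m3u8"]))
# [('http://example.com/a.m3u8', None, None, 'important')]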
50 changes: 50 additions & 0 deletions utils/tools.py
@@ -7,6 +7,7 @@
import socket
import sys
import urllib.parse
from collections import defaultdict
from logging.handlers import RotatingFileHandler
from time import time

@@ -525,3 +526,52 @@ def write_content_into_txt(content, path=None, newline=True, callback=None):

if callback:
callback()


def get_name_url(content, pattern, multiline=False, check_url=True):
"""
Get name and url from content
"""
flag = re.MULTILINE if multiline else 0
matches = re.findall(pattern, content, flag)
channels = [
{"name": match[0].strip(), "url": match[1].strip()}
for match in matches
if (check_url and match[1].strip()) or not check_url
]
return channels


def get_whitelist_urls():
"""
Get the whitelist urls
"""
whitelist_file = resource_path(constants.whitelist_path)
urls = []
url_pattern = constants.url_pattern
if os.path.exists(whitelist_file):
with open(whitelist_file, "r", encoding="utf-8") as f:
for line in f:
match = re.search(url_pattern, line)
if match:
urls.append(match.group().strip())
return urls


def get_whitelist_name_urls():
"""
Get the whitelist name urls
"""
whitelist_file = resource_path(constants.whitelist_path)
name_urls = defaultdict(list)
txt_pattern = constants.txt_pattern
if os.path.exists(whitelist_file):
with open(whitelist_file, "r", encoding="utf-8") as f:
for line in f:
name_url = get_name_url(line, pattern=txt_pattern)
if name_url and name_url[0]:
name = name_url[0]["name"]
url = name_url[0]["url"]
if url not in name_urls[name]:
name_urls[name].append(url)
return name_urls
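
The two readers split the whitelist file by shape: get_whitelist_urls collects every url found on any line, while get_whitelist_name_urls keeps only name,url records, grouped by channel name. Assuming a config/whitelist.txt like the example near the top of this commit, and that constants.txt_pattern matches "name,url" lines, usage would look roughly like:

from utils.tools import get_whitelist_urls, get_whitelist_name_urls

# config/whitelist.txt (hypothetical):
#   CCTV-1,http://example.com/cctv1.m3u8
#   http://example.com/backup/stream.m3u8
print(get_whitelist_urls())
# ['http://example.com/cctv1.m3u8', 'http://example.com/backup/stream.m3u8']
print(dict(get_whitelist_name_urls()))
# {'CCTV-1': ['http://example.com/cctv1.m3u8']}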
