diff --git a/prediction_prophet/functions/research.py b/prediction_prophet/functions/research.py index c8daf20f..fa171da1 100644 --- a/prediction_prophet/functions/research.py +++ b/prediction_prophet/functions/research.py @@ -1,5 +1,6 @@ import logging import typing as t +from math import ceil from langchain.text_splitter import RecursiveCharacterTextSplitter from prediction_prophet.functions.create_embeddings_from_results import create_embeddings_from_results @@ -35,7 +36,7 @@ def research( initial_subqueries_limit: int = 20, subqueries_limit: int = 4, max_results_per_search: int = 5, - min_scraped_sites: int = 0, + min_scraped_sites_ratio: float = 0.0, scrape_content_split_chunk_size: int = 800, scrape_content_split_chunk_overlap: int = 225, top_k_per_query: int = 8, @@ -45,14 +46,6 @@ def research( logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(), tavily_storage: TavilyStorage | None = None, ) -> Research: - # Validate args - if min_scraped_sites > max_results_per_search * subqueries_limit: - raise ValueError( - f"min_scraped_sites ({min_scraped_sites}) must be less than or " - f"equal to max_results_per_search ({max_results_per_search}) * " - f"subqueries_limit ({subqueries_limit})." - ) - logger.info("Started subqueries generation") all_queries = generate_subqueries(query=goal, limit=initial_subqueries_limit, model=model, temperature=temperature, api_key=openai_api_key) @@ -93,13 +86,13 @@ def research( scraped = [result for result in scraped if result.content != ""] unique_scraped_websites = set([result.url for result in scraped]) - if len(scraped) < min_scraped_sites: + if min_scraped_sites_ratio and len(unique_scraped_websites) < ceil(min_scraped_sites_ratio * len(websites_to_scrape)): # Get urls that were not scraped raise ValueError( f"Only successfully scraped content from " f"{len(unique_scraped_websites)} websites, out of a possible " f"{len(websites_to_scrape)} websites, which is less than the " - f"minimum required ({min_scraped_sites}). The following websites " + f"minimum required ({min_scraped_sites_ratio}). The following websites " f"were not scraped: {websites_to_scrape - unique_scraped_websites}" )