-
Notifications
You must be signed in to change notification settings - Fork 15.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(add): LLM integration of Cloudflare Workers AI
- Loading branch information
Showing
3 changed files
with
300 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Cloudflare Workers AI\n", | ||
"\n", | ||
"[Cloudflare AI document](https://developers.cloudflare.com/workers-ai/models/text-generation/) listed all text embeddings models available.\n", | ||
"\n", | ||
"Both Cloudflare account ID and API token are required. Find how to obtain them from [this document](https://developers.cloudflare.com/workers-ai/get-started/rest-api/)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.chains import LLMChain\n", | ||
"from langchain.llms.cloudflare_workersai import CloudflareWorkersAI\n", | ||
"from langchain.prompts import PromptTemplate\n", | ||
"\n", | ||
"template = \"\"\"Human: {question}\n", | ||
"\n", | ||
"AI Assistant: \"\"\"\n", | ||
"\n", | ||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Get authentication before running LLM." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import getpass\n", | ||
"\n", | ||
"my_account_id = getpass.getpass(\"Enter your Cloudflare account ID:\\n\\n\")\n", | ||
"my_api_token = getpass.getpass(\"Enter your Cloudflare API token:\\n\\n\")\n", | ||
"llm = CloudflareWorkersAI(account_id=my_account_id, api_token=my_api_token)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"\"AI Assistant: Ah, a fascinating question! The answer to why roses are red is a bit complex, but I'll do my best to explain it in a simple and polite manner.\\nRoses are red due to the presence of a pigment called anthocyanin. Anthocyanin is a type of flavonoid, a class of plant compounds that are responsible for the red, purple, and blue colors found in many fruits and vegetables.\\nNow, you might be wondering why roses specifically have this pigment. The answer lies in the evolutionary history of roses. You see, roses have been around for millions of years, and their red color has likely played a crucial role in attracting pollinators like bees and butterflies. These pollinators are drawn to the bright colors of roses, which helps the plants reproduce and spread their seeds.\\nSo, to summarize, the reason roses are red is because of the anthocyanin pigment, which is a result of millions of years of evolutionary pressures shaping the plant's coloration to attract pollinators. I hope that helps clarify things for\"" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n", | ||
"\n", | ||
"question = \"Why are roses red?\"\n", | ||
"llm_chain.run(question)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Ah | , | a | most | excellent | question | , | my | dear | human | ! | * | ad | just | s | glass | es | * | The | sky | appears | blue | due | to | a | phenomen | on | known | as | Ray | le | igh | scatter | ing | . | When | sun | light | enters | Earth | ' | s | atmosphere | , | it | enc | oun | ters | tiny | mole | cules | of | g | ases | such | as | nit | ro | gen | and | o | xygen | . | These | mole | cules | scatter | the | light | in | all | directions | , | but | they | scatter | shorter | ( | blue | ) | w | avel | ength | s | more | than | longer | ( | red | ) | w | avel | ength | s | . | This | is | known | as | Ray | le | igh | scatter | ing | . | \n", | ||
" | As | a | result | , | the | blue | light | is | dispers | ed | throughout | the | atmosphere | , | giving | the | sky | its | characteristic | blue | h | ue | . | The | blue | light | appears | more | prominent | during | sun | r | ise | and | sun | set | due | to | the | scatter | ing | of | light | by | the | Earth | ' | s | atmosphere | at | these | times | . | \n", | ||
" | I | hope | this | explanation | has | been | helpful | , | my | dear | human | ! | Is | there | anything | else | you | would | like | to | know | ? | * | sm | iles | * | * | | " | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Using streaming\n", | ||
"for chunk in llm.stream(\"Why is sky blue?\"):\n", | ||
" print(chunk, end=\" | \", flush=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.18" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import json | ||
import logging | ||
from typing import Any, Dict, Iterator, List, Optional | ||
|
||
import requests | ||
from langchain_core.outputs import GenerationChunk | ||
|
||
from langchain.callbacks.manager import CallbackManagerForLLMRun | ||
from langchain.llms.base import LLM | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CloudflareWorkersAI(LLM): | ||
"""Langchain LLM class to help to access Cloudflare Workers AI service. | ||
To use, you must provide an API token and | ||
account ID to access Cloudflare Workers AI, and | ||
pass it as a named parameter to the constructor. | ||
Example: | ||
.. code-block:: python | ||
from langchain.llms.cloudflare_workersai import CloudflareWorkersAI | ||
my_account_id = "my_account_id" | ||
my_api_token = "my_secret_api_token" | ||
llm_model = "@cf/meta/llama-2-7b-chat-int8" | ||
cf_ai = CloudflareWorkersAI( | ||
account_id=my_account_id, | ||
api_token=my_api_token, | ||
model=llm_model | ||
) | ||
""" | ||
|
||
account_id: str | ||
api_token: str | ||
model: str = "@cf/meta/llama-2-7b-chat-int8" | ||
base_url: str = "https://api.cloudflare.com/client/v4/accounts" | ||
streaming: bool = False | ||
endpoint_url: str = "" | ||
|
||
def __init__(self, **kwargs: Any) -> None: | ||
"""Initialize the Cloudflare Workers AI class.""" | ||
super().__init__(**kwargs) | ||
|
||
self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}" | ||
|
||
@property | ||
def _llm_type(self) -> str: | ||
"""Return type of LLM.""" | ||
return "cloudflare" | ||
|
||
@property | ||
def _default_params(self) -> Dict[str, Any]: | ||
"""Default parameters""" | ||
return {} | ||
|
||
@property | ||
def _identifying_params(self) -> Dict[str, Any]: | ||
"""Identifying parameters""" | ||
return { | ||
"account_id": self.account_id, | ||
"api_token": self.api_token, | ||
"model": self.model, | ||
"base_url": self.base_url, | ||
} | ||
|
||
def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response: | ||
"""Call Cloudflare Workers API""" | ||
headers = {"Authorization": f"Bearer {self.api_token}"} | ||
data = {"prompt": prompt, "stream": self.streaming, **params} | ||
response = requests.post(self.endpoint_url, headers=headers, json=data) | ||
return response | ||
|
||
def _process_response(self, response: requests.Response) -> str: | ||
"""Process API response""" | ||
if response.ok: | ||
data = response.json() | ||
return data["result"]["response"] | ||
else: | ||
raise ValueError(f"Request failed with status {response.status_code}") | ||
|
||
def _stream( | ||
self, | ||
prompt: str, | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
**kwargs: Any, | ||
) -> Iterator[GenerationChunk]: | ||
"""Streaming prediction""" | ||
original_steaming: bool = self.streaming | ||
self.streaming = True | ||
_response_prefix_count = len("data: ") | ||
_response_stream_end = b"data: [DONE]" | ||
for chunk in self._call_api(prompt, kwargs).iter_lines(): | ||
if chunk == _response_stream_end: | ||
break | ||
if len(chunk) > _response_prefix_count: | ||
try: | ||
data = json.loads(chunk[_response_prefix_count:]) | ||
except Exception as e: | ||
logger.debug(chunk) | ||
raise e | ||
if data is not None and "response" in data: | ||
yield GenerationChunk(text=data["response"]) | ||
if run_manager: | ||
run_manager.on_llm_new_token(data["response"]) | ||
logger.debug("stream end") | ||
self.streaming = original_steaming | ||
|
||
def _call( | ||
self, | ||
prompt: str, | ||
stop: Optional[List[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
**kwargs: Any, | ||
) -> str: | ||
"""Regular prediction""" | ||
if self.streaming: | ||
return "".join( | ||
[c.text for c in self._stream(prompt, run_manager, **kwargs)] | ||
) | ||
else: | ||
response = self._call_api(prompt, kwargs) | ||
return self._process_response(response) |
46 changes: 46 additions & 0 deletions
46
libs/langchain/tests/integration_tests/llms/test_cloudflare_workersai.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import responses | ||
|
||
from langchain.llms.cloudflare_workersai import CloudflareWorkersAI | ||
|
||
|
||
@responses.activate | ||
def test_cloudflare_workersai_call() -> None: | ||
responses.add( | ||
responses.POST, | ||
"https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8", | ||
json={"result": {"response": "4"}}, | ||
status=200, | ||
) | ||
|
||
llm = CloudflareWorkersAI( | ||
account_id="my_account_id", | ||
api_token="my_api_token", | ||
model="@cf/meta/llama-2-7b-chat-int8", | ||
) | ||
output = llm("What is 2 + 2?") | ||
|
||
assert output == "4" | ||
|
||
|
||
@responses.activate | ||
def test_cloudflare_workersai_stream() -> None: | ||
response_body = ['data: {"response": "Hello"}', "data: [DONE]"] | ||
responses.add( | ||
responses.POST, | ||
"https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8", | ||
body="\n".join(response_body), | ||
status=200, | ||
) | ||
|
||
llm = CloudflareWorkersAI( | ||
account_id="my_account_id", | ||
api_token="my_api_token", | ||
model="@cf/meta/llama-2-7b-chat-int8", | ||
streaming=True, | ||
) | ||
|
||
outputs = [] | ||
for chunk in llm.stream("Say Hello"): | ||
outputs.append(chunk) | ||
|
||
assert "".join(outputs) == "Hello" |