diff --git a/.github/ISSUE_TEMPLATE/usage-question.md b/.github/ISSUE_TEMPLATE/usage-question.md index 9c6edda..8ef1cba 100644 --- a/.github/ISSUE_TEMPLATE/usage-question.md +++ b/.github/ISSUE_TEMPLATE/usage-question.md @@ -1,6 +1,6 @@ --- name: Usage Question -about: Ask a question about text2vec usage +about: Ask a question about usage title: '' labels: question assignees: '' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fa4f84a..6d66952 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Contributing -We are happy to accept your contributions to make `text2vec` better and more awesome! To avoid unnecessary work on either +We are happy to accept your contributions to make this repo better and more awesome! To avoid unnecessary work on either side, please stick to the following process: 1. Check if there is already [an issue](https://github.com/shibing624/similarities/issues) for your concern. diff --git a/README.md b/README.md index 455ff64..126c1cc 100644 --- a/README.md +++ b/README.md @@ -12,86 +12,37 @@ Similarities is a toolkit for similarity calculation and semantic search, suppor similarities:相似度计算、语义匹配搜索工具包。 -**similarities** 实现了多种相似度计算、匹配搜索算法,支持文本、图像,python3开发,pip安装,开箱即用。 +**similarities** 实现了多种相似度计算、语义匹配检索算法,支持亿级数据文搜文、文搜图、图搜图,python3开发,pip安装,开箱即用。 **Guide** -- [Feature](#Feature) -- [Evaluation](#Evaluation) +- [Features](#Features) - [Install](#install) - [Usage](#usage) - [Contact](#Contact) -- [Reference](#reference) +- [Acknowledgements](#Acknowledgements) -# Feature +## Features -### 文本相似度计算(文本匹配) -- 余弦相似(Cosine Similarity):两向量求余弦 -- 点积(Dot Product):两向量归一化后求内积 -- 汉明距离(Hamming Distance),编辑距离(Levenshtein Distance),欧氏距离(Euclidean Distance),曼哈顿距离(Manhattan Distance)等 +### 文本相似度计算 + 文本搜索 -#### 语义模型 -- [CoSENT文本匹配模型](https://github.com/shibing624/similarities/blob/main/similarities/similarity.py#L79)【推荐】 -- BERT模型(文本向量表征) -- SentenceBERT文本匹配模型 +- 语义匹配模型【推荐】:本项目基于text2vec实现了CoSENT模型的文本相似度计算和文本搜索,支持中英文、多语言多种SentenceBERT类预训练模型,支持 Cos Similarity/Dot Product/Hamming Distance/Euclidean Distance 等多种相似度计算方法,支持 SemanticSearch/Faiss/Annoy/Hnsw 等多种文本搜索算法,支持亿级数据高效检索 +- 字面匹配模型:本项目实现了Word2Vec、BM25、RankBM25、TFIDF、SimHash、同义词词林、知网Hownet义原匹配等多种字面匹配模型 -#### 字面模型 -- [Word2Vec文本浅层语义表征](https://github.com/shibing624/similarities/blob/main/similarities/literalsim.py#L374)【推荐】 -- 同义词词林 -- 知网Hownet义原匹配 -- BM25、RankBM25 -- TFIDF -- SimHash +### 图像相似度计算/图文相似度计算 + 图搜图/文搜图 +- 英文CLIP(Contrastive Language-Image Pre-Training)模型:OpenAI提出的图文匹配模型,可用于图文特征(embeddings)、相似度计算、图文检索、零样本图片分类,本项目实现了[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)等CLIP系列模型的图文检索功能 +- 中文CLIP模型【推荐】:阿里使用~2亿图文对训练,发布的中文CLIP模型,支持[OFA-Sys/chinese-clip-vit-base-patch16](https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16)等Chinese-CLIP系列模型,本项目基于PyTorch实现了中文CLIP模型的向量表征、构建索引(基于autofaiss)、批量检索、后台服务(基于Fastapi)、前端展现(基于gradio)功能 +- 图像特征提取:本项目基于cv2实现了pHash、dHash、wHash、aHash、SIFT等多种图像特征提取算法 -### 图像相似度计算(图像匹配) -#### 语义模型 -- [CLIP(Contrastive Language-Image Pre-Training)](https://github.com/shibing624/similarities/blob/main/similarities/imagesim.py#L25) -- VGG(doing) -- ResNet(doing) +## Demo -#### 特征提取 -- [pHash](https://github.com/shibing624/similarities/blob/main/similarities/imagesim.py#L164)【推荐】, dHash, wHash, aHash -- SIFT, Scale Invariant Feature Transform(SIFT) -- SURF, Speeded Up Robust Features(SURF)(doing) - -### 图文相似度计算 -- [CLIP(Contrastive Language-Image 
Pre-Training)](https://github.com/shibing624/similarities/blob/main/similarities/imagesim.py#L25) - -### 匹配搜索 -- [SemanticSearch](https://github.com/shibing624/similarities/blob/main/similarities/similarity.py#L185):向量相似检索,使用Cosine - Similarty + topk高效计算,比一对一暴力计算快一个数量级 - -# Demo - -Compute similarity score Demo: https://huggingface.co/spaces/shibing624/text2vec - -Semantic Search Demo: https://huggingface.co/spaces/shibing624/similarities +Text Search Demo: https://huggingface.co/spaces/shibing624/similarities ![](docs/hf_search.png) -# Evaluation -### 文本匹配和文本检索 -#### 中文文本匹配模型评测结果 - -| Model | ATEC | BQ | LCQMC | PAWSX | STS-B | Avg | QPS | -| :---- | :-: | :-: | :-: | :-: | :-: | :-: | :-: | -| Word2Vec | 20.00 | 31.49 | 59.46 | 2.57 | 55.78 | 33.86 | 10283 | -| SBERT-multi | 18.42 | 38.52 | 63.96 | 10.14 | 78.90 | 41.99 | 2371 | -| Text2vec | 31.93 | 42.67 | 70.16 | 17.21 | 79.30 | **48.25** | 2572 | - -> 结果值使用spearman系数 - -Model: -- Cilin -- Hownet -- SimHash -- TFIDF - - - -# Install +## Install ``` pip3 install torch # conda install pytorch @@ -106,279 +57,58 @@ cd similarities python3 setup.py install ``` -# Usage +## Usage -### 1. 文本语义相似度计算 +### 1. 文本相似度计算 example: [examples/text_similarity_demo.py](https://github.com/shibing624/similarities/blob/main/examples/text_similarity_demo.py) ```python -from similarities import Similarity +from similarities import BertSimilarity -m = Similarity() +m = BertSimilarity(model_name_or_path="shibing624/text2vec-base-chinese") r = m.similarity('如何更换花呗绑定银行卡', '花呗更改绑定银行卡') print(f"similarity score: {float(r)}") # similarity score: 0.855146050453186 ``` -Similarity的默认方法: -```python -Similarity(corpus: Union[List[str], Dict[str, str]] = None, - model_name_or_path="shibing624/text2vec-base-chinese", - max_seq_length=128) -``` - -- 返回值:余弦值`score`范围是[-1, 1],值越大越相似 -- `corpus`:搜索用的doc集,仅搜索时需要,输入格式:句子列表`List[str]`或者{corpus_id: sentence}的`Dict[str, str]`格式 -- `model_name_or_path`:模型名称或者模型路径,默认会从HF model hub下载并使用中文语义匹配模型[shibing624/text2vec-base-chinese](https://huggingface.co/shibing624/text2vec-base-chinese),如果是多语言景,可以替换为多语言匹配模型[shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual) -- `max_seq_length`:输入句子的最大长度,最大为匹配模型支持的最大长度,BERT系列是512 +### 2. 文本搜索 -### 2. 文本语义匹配搜索 - -一般在文档候选集中找与query最相似的文本,常用于QA场景的问句相似匹配、文本相似检索等任务。 +一般在文档候选集中找与query最相似的文本,常用于QA场景的问句相似匹配、文本搜索(百万内数据集)等任务。 example: [examples/text_semantic_search_demo.py](https://github.com/shibing624/similarities/blob/main/examples/text_semantic_search_demo.py) -```python -import sys - -sys.path.append('..') -from similarities import Similarity - -# 1.Compute cosine similarity between two sentences. 
-sentences = ['如何更换花呗绑定银行卡', - '花呗更改绑定银行卡'] -corpus = [ - '花呗更改绑定银行卡', - '我什么时候开通了花呗', - '俄罗斯警告乌克兰反对欧盟协议', - '暴风雨掩埋了东北部;新泽西16英寸的降雪', - '中央情报局局长访问以色列叙利亚会谈', - '人在巴基斯坦基地的炸弹袭击中丧生', -] -model = Similarity(model_name_or_path="shibing624/text2vec-base-chinese") -print(model) -similarity_score = model.similarity(sentences[0], sentences[1]) -print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}") - -print('-' * 50 + '\n') -# 2.Compute similarity between two list -similarity_scores = model.similarity(sentences, corpus) -print(similarity_scores.numpy()) -for i in range(len(sentences)): - for j in range(len(corpus)): - print(f"{sentences[i]} vs {corpus[j]}, score: {similarity_scores.numpy()[i][j]:.4f}") - -print('-' * 50 + '\n') -# 3.Semantic Search -model.add_corpus(corpus) -res = model.most_similar(queries=sentences, topn=3) -print(res) -for q_id, c in res.items(): - print('query:', sentences[q_id]) - print("search top 3:") - for corpus_id, s in c.items(): - print(f'\t{model.corpus[corpus_id]}: {s:.4f}') -``` - -output: - -```shell -如何更换花呗绑定银行卡 vs 花呗更改绑定银行卡, score: 0.8551 -... - -如何更换花呗绑定银行卡 vs 花呗更改绑定银行卡, score: 0.8551 -如何更换花呗绑定银行卡 vs 我什么时候开通了花呗, score: 0.7212 -如何更换花呗绑定银行卡 vs 俄罗斯警告乌克兰反对欧盟协议, score: 0.1450 -如何更换花呗绑定银行卡 vs 暴风雨掩埋了东北部;新泽西16英寸的降雪, score: 0.2167 -如何更换花呗绑定银行卡 vs 中央情报局局长访问以色列叙利亚会谈, score: 0.2517 -如何更换花呗绑定银行卡 vs 人在巴基斯坦基地的炸弹袭击中丧生, score: 0.0809 -花呗更改绑定银行卡 vs 花呗更改绑定银行卡, score: 1.0000 -花呗更改绑定银行卡 vs 我什么时候开通了花呗, score: 0.6807 -花呗更改绑定银行卡 vs 俄罗斯警告乌克兰反对欧盟协议, score: 0.1714 -花呗更改绑定银行卡 vs 暴风雨掩埋了东北部;新泽西16英寸的降雪, score: 0.2162 -花呗更改绑定银行卡 vs 中央情报局局长访问以色列叙利亚会谈, score: 0.2728 -花呗更改绑定银行卡 vs 人在巴基斯坦基地的炸弹袭击中丧生, score: 0.1279 - -query: 如何更换花呗绑定银行卡 -search top 3: - 花呗更改绑定银行卡: 0.8551 - 我什么时候开通了花呗: 0.7212 - 中央情报局局长访问以色列叙利亚会谈: 0.2517 -``` - -> 余弦`score`的值范围[-1, 1],值越大,表示该query与corpus的文本越相似。 - -#### 多语言文本语义相似度计算和匹配搜索 +#### 多语言文本相似度计算和文本搜索 -多语言:包括中、英、韩、日、德、意等多国语言 +使用[shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)模型,支持中、英、韩、日、德、意等多国语言 example: [examples/text_semantic_search_multilingual_demo.py](https://github.com/shibing624/similarities/blob/main/examples/text_semantic_search_multilingual_demo.py) -### 3. 快速近似文本语义匹配搜索 +### 3. 近似文本搜索 支持Annoy、Hnswlib的近似语义匹配搜索,常用于百万数据集的匹配搜索任务。 example: [examples/fast_text_semantic_search_demo.py](https://github.com/shibing624/similarities/blob/main/examples/fast_text_semantic_search_demo.py) -### 4. 基于字面的文本相似度计算和匹配搜索 +### 4. 
基于字面的文本相似度计算和文本搜索 支持同义词词林(Cilin)、知网Hownet、词向量(WordEmbedding)、Tfidf、SimHash、BM25等算法的相似度计算和字面匹配搜索,常用于文本匹配冷启动。 example: [examples/literal_text_semantic_search_demo.py](https://github.com/shibing624/similarities/blob/main/examples/literal_text_semantic_search_demo.py) -```python -from similarities import SimHashSimilarity, TfidfSimilarity, BM25Similarity, \ - WordEmbeddingSimilarity, CilinSimilarity, HownetSimilarity - -text1 = "如何更换花呗绑定银行卡" -text2 = "花呗更改绑定银行卡" - -corpus = [ - '花呗更改绑定银行卡', - '我什么时候开通了花呗', - '俄罗斯警告乌克兰反对欧盟协议', - '暴风雨掩埋了东北部;新泽西16英寸的降雪', - '中央情报局局长访问以色列叙利亚会谈', - '人在巴基斯坦基地的炸弹袭击中丧生', -] - -queries = [ - '我的花呗开通了?', - '乌克兰被俄罗斯警告' -] -m = TfidfSimilarity() -print(text1, text2, ' sim score: ', m.similarity(text1, text2)) - -m.add_corpus(corpus) -res = m.most_similar(queries, topn=3) -print('sim search: ', res) -for q_id, c in res.items(): - print('query:', queries[q_id]) - print("search top 3:") - for corpus_id, s in c.items(): - print(f'\t{m.corpus[corpus_id]}: {s:.4f}') -``` - -output: - -```shell -如何更换花呗绑定银行卡 花呗更改绑定银行卡 sim score: 0.8203384355246909 - -sim search: {0: {2: 0.9999999403953552, 1: 0.43930041790008545, 0: 0.0}, 1: {0: 0.7380483150482178, 1: 0.0, 2: 0.0}} -query: 我的花呗开通了? -search top 3: - 我什么时候开通了花呗: 1.0000 - 花呗更改绑定银行卡: 0.4393 - 俄罗斯警告乌克兰反对欧盟协议: 0.0000 -... -``` +### 5. 图像相似度计算和图片搜索 -### 5. 图像相似度计算和匹配搜索 - -支持[CLIP](similarities/imagesim.py)、pHash、SIFT等算法的图像相似度计算和匹配搜索。 +支持CLIP、pHash、SIFT等算法的图像相似度计算和匹配搜索,中文 CLIP 模型支持图搜图,文搜图、还支持中英文图文互搜。 example: [examples/image_semantic_search_demo.py](https://github.com/shibing624/similarities/blob/main/examples/image_semantic_search_demo.py) -```python -import sys -import glob -from PIL import Image - -sys.path.append('..') -from similarities import ImageHashSimilarity, SiftSimilarity, ClipSimilarity - - -def sim_and_search(m): - print(m) - # similarity - sim_scores = m.similarity(imgs1, imgs2) - print('sim scores: ', sim_scores) - for (idx, i), j in zip(enumerate(image_fps1), image_fps2): - s = sim_scores[idx] if isinstance(sim_scores, list) else sim_scores[idx][idx] - print(f"{i} vs {j}, score: {s:.4f}") - # search - m.add_corpus(corpus_imgs) - queries = imgs1 - res = m.most_similar(queries, topn=3) - print('sim search: ', res) - for q_id, c in res.items(): - print('query:', image_fps1[q_id]) - print("search top 3:") - for corpus_id, s in c.items(): - print(f'\t{m.corpus[corpus_id].filename}: {s:.4f}') - print('-' * 50 + '\n') - -image_fps1 = ['data/image1.png', 'data/image3.png'] -image_fps2 = ['data/image12-like-image1.png', 'data/image10.png'] -imgs1 = [Image.open(i) for i in image_fps1] -imgs2 = [Image.open(i) for i in image_fps2] -corpus_fps = glob.glob('data/*.jpg') + glob.glob('data/*.png') -corpus_imgs = [Image.open(i) for i in corpus_fps] - -# 2. 
image and image similarity score -sim_and_search(ClipSimilarity()) # the best result -sim_and_search(ImageHashSimilarity(hash_function='phash')) -sim_and_search(SiftSimilarity()) -``` - -output: - -```shell -Similarity: ClipSimilarity, matching_model: CLIPModel -sim scores: tensor([[0.9580, 0.8654], - [0.6558, 0.6145]]) - -data/image1.png vs data/image12-like-image1.png, score: 0.9580 -data/image3.png vs data/image10.png, score: 0.6145 - -sim search: {0: {6: 0.9999999403953552, 0: 0.9579654932022095, 4: 0.9326782822608948}, 1: {8: 0.9999997615814209, 4: 0.6729235649108887, 0: 0.6558331847190857}} - -query: data/image1.png -search top 3: - data/image1.png: 1.0000 - data/image12-like-image1.png: 0.9580 - data/image8-like-image1.png: 0.9327 -``` - ![image_sim](docs/image_sim.png) -### 6. 图文互搜 -CLIP 模型不仅支持以图搜图,还支持中英文图文互搜: -```python -import sys -import glob -from PIL import Image -sys.path.append('..') -from similarities import ImageHashSimilarity, SiftSimilarity, ClipSimilarity - -m = ClipSimilarity() -print(m) -# similarity score between text and image -image_fps = ['data/image3.png', # yellow flower image - 'data/image1.png'] # tiger image -texts = ['a yellow flower', '老虎'] -imgs = [Image.open(i) for i in image_fps] -sim_scores = m.similarity(imgs, texts) - -print('sim scores: ', sim_scores) -for (idx, i), j in zip(enumerate(image_fps), texts): - s = sim_scores[idx][idx] - print(f"{i} vs {j}, score: {s:.4f}") -``` -output: -```shell -sim scores: tensor([[0.3220, 0.2409], - [0.1677, 0.2959]]) -data/image3.png vs a yellow flower, score: 0.3220 -data/image1.png vs 老虎, score: 0.2112 -``` - -# Contact +## Contact - Issue(建议) :[![GitHub issues](https://img.shields.io/github/issues/shibing624/similarities.svg)](https://github.com/shibing624/similarities/issues) @@ -387,7 +117,7 @@ data/image1.png vs 老虎, score: 0.2112 -# Citation +## Citation 如果你在研究中使用了similarities,请按如下格式引用: @@ -408,11 +138,11 @@ BibTeX: } ``` -# License +## License 授权协议为 [The Apache License 2.0](/LICENSE),可免费用做商业用途。请在产品说明中附加similarities的链接和授权协议。 -# Contribute +## Contribute 项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点: @@ -421,11 +151,15 @@ BibTeX: 之后即可提交PR。 -# Reference +## Acknowledgements - [A Simple but Tough-to-Beat Baseline for Sentence Embeddings[Sanjeev Arora and Yingyu Liang and Tengyu Ma, 2017]](https://openreview.net/forum?id=SyK00v5xx) -- [liuhuanyong/SentenceSimilarity](https://github.com/liuhuanyong/SentenceSimilarity) -- [shibing624/text2vec](https://github.com/shibing624/text2vec) -- [qwertyforce/image_search](https://github.com/qwertyforce/image_search) +- [https://github.com/liuhuanyong/SentenceSimilarity](https://github.com/liuhuanyong/SentenceSimilarity) +- [https://github.com/qwertyforce/image_search](https://github.com/qwertyforce/image_search) - [ImageHash - Official Github repository](https://github.com/JohannesBuchner/imagehash) -- [openai/CLIP](https://github.com/openai/CLIP) +- [https://github.com/openai/CLIP](https://github.com/openai/CLIP) +- [https://github.com/OFA-Sys/Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP) +- [https://github.com/UKPLab/sentence-transformers](https://github.com/UKPLab/sentence-transformers) +- [https://github.com/rom1504/clip-retrieval](https://github.com/rom1504/clip-retrieval) + +Thanks for their great work! 
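A minimal sketch of the text-search flow that the README's 文本搜索 section above now only links to, assuming the `BertSimilarity` API added in this diff (`add_corpus` / `most_similar`); the model name and the sentences are taken from the bundled demo scripts:

```python
from similarities import BertSimilarity

corpus = [
    '花呗更改绑定银行卡',
    '我什么时候开通了花呗',
    '俄罗斯警告乌克兰反对欧盟协议',
]
model = BertSimilarity(model_name_or_path="shibing624/text2vec-base-chinese")
model.add_corpus(corpus)  # embeds the corpus once; ids are assigned incrementally
res = model.most_similar(queries=['如何更换花呗绑定银行卡'], topn=2)
# res: {query_id: {corpus_id: cosine_score, ...}}, scores in [-1, 1]
for q_id, hits in res.items():
    for corpus_id, score in hits.items():
        print(model.corpus[corpus_id], f'{score:.4f}')
```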
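Likewise, a minimal sketch of the image-text matching described in the README's 图像相似度计算 section, assuming the `ClipSimilarity` API added in this diff; the model name comes from the diff and the image/caption pair (`data/image1.png` / 老虎) from `examples/data/image_info.csv`:

```python
from PIL import Image
from similarities import ClipSimilarity

# Chinese-CLIP model; scores image-image, image-text and text-text pairs
m = ClipSimilarity(model_name_or_path="OFA-Sys/chinese-clip-vit-base-patch16")
img = Image.open('data/image1.png')  # tiger image from examples/data
print(float(m.similarity(img, '老虎')))  # image vs. Chinese caption
print(float(m.similarity(img, img)))  # identical images, score ~1.0
```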
diff --git a/examples/data/image_info.csv b/examples/data/image_info.csv new file mode 100644 index 0000000..76de10a --- /dev/null +++ b/examples/data/image_info.csv @@ -0,0 +1,10 @@ +image_path,text,caption,format +data/image1.png,image1,老虎,png +data/image11-like-image10.png,image11-like-image10,狮子,png +data/image13-like-image1.png,image13-like-image1,果盘,png +data/image5.png,image5,男人,png +data/image8-like-image1.png,image8-like-image1,老虎下山,png +data/image10.png,image10,母狮子,png +data/image12-like-image1.png,image12-like-image1,老虎特写,png +data/image3.png,image3,黄花,png +data/image6-like-image5.png,image6-like-image5,女人,png diff --git a/examples/data/toy_corpus/corpus_100.txt b/examples/data/toy_corpus/corpus_100.txt new file mode 100644 index 0000000..c336123 --- /dev/null +++ b/examples/data/toy_corpus/corpus_100.txt @@ -0,0 +1,103 @@ +有人在煎肉。 +花呗更改绑定的银行卡 +花呗是啥 +花呗改绑定银行卡 +在伊拉克的暴力袭击中有41人死亡,22人受伤 +莫斯科反普京抗议者形成人链 +黑白相间的狗在河里游泳。 +一男一女在宝莱坞跳舞。 +坐在婴儿床上的婴儿,伸手去摸一个大男孩的脸。 +聋子顾客告星巴克,说他们被嘲笑了 +曼谷警方向抗议者发射催泪瓦斯 +中国外交部发言人姜瑜表示,网络黑客是一个全球性问题。 +科学家们认为,星尘捕获了数千颗尘埃。 +吉尔罗伊警方和联邦调查局特工称Gehring很合作,但周六表示,他没有透露孩子们发生了什么事。 +“我真的很喜欢他,现在也是,”科恩·阿隆昨天告诉“先驱报”。 +范表示,该委员会主要关注的是未来在台湾投资的期限。 +一只小猴子抓它自己。 +7人死于基督教学校枪击案 +英国坎特伯雷大主教转行剑桥 +猴子拉狗的尾巴。 +那人用棍子打了另一个人。 +约旦当选为联合国安理会成员 +滑板运动员试图在低矮的墙边玩一个把戏。 +西班牙失业率创新高 +周一,一家分裂的最高法院裁定,国会可以强制国家公共图书馆为电脑配备反色情过滤器。 +格拉斯哥直升机死亡人数上升到9人 +一个女人和一条狗站在草地上。 +一只棕色的狗正把爪子放在笔记本电脑的键盘上。 +发言人说:“自十一月起,我们已与警方通力合作。 +台风菲托登陆中国东部 +埃及对西奈武装分子发动进攻 +持枪歹徒在巴基斯坦杀害11名外国游客 +丹尼尔斯说,在今晚的每人2000美元的总统筹款活动上,他不知道布什会对他说些什么。 +从窗户射出的阳光下的盆栽植物。 +北约士兵在阿富汗袭击中丧生 +一个穿着红色夹克的男人拿着他的自行车拍照。 +旋风在印度留下毁灭的痕迹。 +两条狗在草地上奔跑。 +一个人弹钢琴。 +那只黑狗躺在草地上。 +巴西世界杯五颜六色的支持 +Key说marae错过了一次机会 +“聪明但胡思乱想”变成了“药物治疗”。 +一个男人正死在手术室里。 +意大利新政府达成协议 +两只大狗在草地上奔跑。 +曼德拉的病情已经“好转”了。 +顺便说一句,他当时有适当的延期。 +劳工部分析师认为,来自企业调查的工资统计数据提供了更准确的经济图景,因为调查数据基于更大的样本。 +一只狗在车里。 +一个年轻的女孩骑着一匹棕色的马。 +所以回答问题。 +獾在挖洞。 +一个男人和一只狗在打拔河。 +韩国驱逐老年美国人 +一个人在切东西。 +一只绿眼睛的灰猫看着相机。 +一些女性还担心,很多女性为了追求身体的完美而转向整容手术。 +思科高管表示,他们对13亿美元的现金流和净利润的增加感到鼓舞,但希望能出现反弹。 +猫躺在地毯上的黑白图像。 +美国股市下跌,希腊谈木材;苹果下跌 +当不知情的计算机用户打开包含“谢谢!”、“Re:Details”或“Re:That Movie”等常见标题的电子邮件中的文件附件时,Sobig.F会传播。 +可能是很多事情。 +一个骑马的人骑着一匹白马。 +该公司通过新的大型机将.NET和J2EE支持添加到其企业应用程序环境(EnterpriseApplicationEnvironment,EAE)中。 +乡村音乐电台KKCS暂停了两名播放迪克西小鸡歌曲的音乐主持人的职务,这违反了一名乐队成员批评布什总统后实施的禁令。 +施洗约翰从出生时就是拿撒勒人。 +Det总督察NormanMcKinlay说,“有证据表明在这一地区有一具或多具尸体”。 +俄罗斯人离开叙利亚进入黎巴嫩 +一组背景中有树木的驳船。 +一个女人和两个戴着帽子的男人在外面社交。 +伊朗表示核谈判仍存在严重问题 +澳大利亚电视台记者彼得·劳埃德在新加坡面临3项新的药物指控。 +以色列批准在和谈前释放第一批巴勒斯坦囚犯 +在南卡罗来纳州海滨汽车旅馆枪击事件中,3人死亡,1人受伤 +中国煤矿事故10人死亡 +中国对台风卡尔梅吉高度戒备 +在海浪上骑着冲浪板的人。 +袋鼠在吃东西。 +我们都相信我们正在得到我们所期望的。 +红色双层巴士载客。 +CCAG支持罗兰德在2002年州长选举中的对手比尔·库里(Bill Curry)。 +如果你想买十亿美元的XYZ公司。普通股,谁在乎呢? 
+虽然非洲其他一些地区被用作恐怖组织的集结地,但马拉维以前并不是对基地组织进行调查的主要重点。 +人们聚集在伯利恒过圣诞节。 +两个人在打架。 +1月至6月,该消费集团创造了44.7亿美元的利润和200亿美元的收入,占花旗集团利润和收入的53%。 +上周,该公司的法律代表开始与原告的律师会面,讨论一项和解协议。 +俄罗斯经济增长乏力,普京宣誓就职 +一张白色和紫色摩托车的特写照片。 +结果,24名球员在第一轮比赛中超过了标准杆。 +巴拿马外交部公报说,该工作组将促进技术转让,促进生物燃料的生产和消费。 +肉被扔进平底锅里。 +食品价格上涨引起伊朗的担忧 +里德列举了债务上限谈判取得的“巨大进展”。 +一只棕色的狗在水中涉水。 +2名巴勒斯坦人在以色列空袭中丧生 +科罗拉多枪击案嫌疑人接受了精神病医生的治疗 +松散的变化始于一部虚构的作品。 +雅虎同意以11亿美元现金收购Tumblr +爱沙尼亚官员表示,导致爱沙尼亚政府网站暂时关闭的一些网络攻击来自俄罗斯政府的计算机,包括俄罗斯总统弗拉基米尔·普京(Vladimir Putin)的办公室。 +一个男人在弹吉他和唱歌。 +克里夫兰骑士赢得了选秀詹姆斯的权利,赢得了周四晚上的年度彩票。 diff --git a/examples/faiss_bert_search_client_demo.py b/examples/faiss_bert_search_client_demo.py new file mode 100644 index 0000000..d7649dd --- /dev/null +++ b/examples/faiss_bert_search_client_demo.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: Use Faiss for text similarity search demo +""" + +import sys + +sys.path.append('..') +from similarities import BertClient + + +def main(): + # Client + client = BertClient('http://0.0.0.0:8001') + # 获取嵌入 + text_input = "This is a sample text." + emb = client.get_emb(text_input) + print(f"Embedding for '{text_input}': {emb}") + # 获取相似度 + similarity = client.get_similarity("This is a sample text.", "This is another sample text.") + print(f"Similarity between item1 and item2: {similarity}") + # 搜索 + search_input = "This is a sample text." + search_results = client.search(search_input) + print(f"Search results for '{search_input}': {search_results}") + + +if __name__ == '__main__': + main() diff --git a/examples/faiss_bert_search_server_demo.py b/examples/faiss_bert_search_server_demo.py new file mode 100644 index 0000000..8ce3b1f --- /dev/null +++ b/examples/faiss_bert_search_server_demo.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: Use Faiss for text similarity search demo +""" + +import sys + +sys.path.append('..') +from similarities import bert_embedding, bert_index, bert_filter, bert_server + + +def main(): + # Build embedding + bert_embedding( + input_dir='data/toy_corpus/', + embeddings_dir='tmp_embeddings_dir/', + embeddings_name='emb.npy', + corpus_file='tmp_data_dir/corpus.npy', + model_name="shibing624/text2vec-base-chinese", + batch_size=12, + device=None, + normalize_embeddings=True, + ) + + # Build index + bert_index( + embeddings_dir='tmp_embeddings_dir/', + index_dir="tmp_index_dir/", + index_name="faiss.index", + max_index_memory_usage="1G", + current_memory_available="2G", + use_gpu=False, + nb_cores=None, + ) + + # Filter(search) support multi query, batch search + sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡'] + bert_filter( + queries=sentences, + output_file=f"tmp_outputs/result.json", + model_name="shibing624/text2vec-base-chinese", + index_dir='tmp_index_dir/', + index_name="faiss.index", + corpus_file="tmp_data_dir/corpus.npy", + num_results=5, + threshold=None, + device=None, + ) + + # Server + bert_server( + model_name="shibing624/text2vec-base-chinese", + index_dir='tmp_index_dir/', + index_name="faiss.index", + corpus_file="tmp_data_dir/corpus.npy", + num_results=5, + threshold=None, + device=None, + port=8001, + ) + + +if __name__ == '__main__': + main() diff --git a/examples/faiss_clip_search_client_demo.py b/examples/faiss_clip_search_client_demo.py new file mode 100644 index 0000000..5d9ac53 --- /dev/null +++ b/examples/faiss_clip_search_client_demo.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: Use Faiss for image similarity search demo +""" + +import 
sys + +import numpy as np + +sys.path.append('..') +from similarities import ClipClient, ClipItem + + +def main(): + # Client + client = ClipClient('http://0.0.0.0:8002') + + # 获取嵌入,支持获取文本嵌入、图片嵌入 + text_input = "This is a sample text." + emb = client.get_emb(text=text_input) + print(f"Embedding for '{text_input}': {emb}") + # input image + image_input = "data/image1.png" + emb = client.get_emb(image=image_input) + print(f"Embedding for '{image_input}': {emb}") + + # 获取相似度,支持计算图文相似度、图片相似度 + item1 = ClipItem(image="data/image1.png") + item2 = ClipItem(text="老虎") + similarity = client.get_similarity(item1, item2) + print(f"Similarity between item1 and item2: {similarity}") + + # 搜索 + # 1. 文搜图 + search_input = "This is a sample text." + search_results = client.search(text=search_input) + print(f"Search results for '{search_input}': {search_results}") + # 2. 图搜图 + search_input = "data/image1.png" + search_results = client.search(image=search_input) + print(f"Search results for '{search_input}': {search_results}") + # 3. 向量搜图 + search_results = client.search(emb=np.random.randn(512).tolist()) + print(f"Search results for emb search: {search_results}") + + +if __name__ == '__main__': + main() diff --git a/examples/faiss_clip_search_server_demo.py b/examples/faiss_clip_search_server_demo.py new file mode 100644 index 0000000..389147b --- /dev/null +++ b/examples/faiss_clip_search_server_demo.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: Use Faiss for image similarity search demo +""" + +import sys + +import numpy as np + +sys.path.append('..') +from similarities import clip_embedding, clip_index, clip_filter, clip_server + + +def main(): + # Build embedding + clip_embedding( + input_data_or_path='data/image_info.csv', + columns=None, + header=0, + delimiter=',', + image_embeddings_dir='tmp_image_embeddings_dir/', + text_embeddings_dir=None, + embeddings_name='emb.npy', + corpus_file='tmp_data_dir/corpus.csv', + model_name="OFA-Sys/chinese-clip-vit-base-patch16", + batch_size=12, + enable_image=True, + enabel_text=False, + device=None, + normalize_embeddings=True, + ) + + # Build index + clip_index( + image_embeddings_dir='tmp_image_embeddings_dir/', + text_embeddings_dir=None, + image_index_dir='tmp_image_index_dir/', + text_index_dir=None, + index_name='faiss.index', + max_index_memory_usage='1G', + current_memory_available='2G', + use_gpu=False, + nb_cores=None, + ) + + # Filter(search) 文搜图, support multi query, batch search + sentences = ['老虎', '花朵'] + clip_filter( + texts=sentences, + output_file=f"tmp_image_outputs/result_txt.json", + model_name="OFA-Sys/chinese-clip-vit-base-patch16", + index_dir='tmp_image_index_dir/', + index_name="faiss.index", + corpus_file="tmp_data_dir/corpus.csv", + num_results=5, + threshold=None, + device=None, + ) + + # Filter(search) 图搜图, support multi query, batch search + images = ['data/image1.png', 'data/image3.png'] + clip_filter( + images=images, + output_file=f"tmp_image_outputs/result_img.json", + model_name="OFA-Sys/chinese-clip-vit-base-patch16", + index_dir='tmp_image_index_dir/', + index_name="faiss.index", + corpus_file="tmp_data_dir/corpus.csv", + num_results=5, + threshold=None, + device=None, + ) + + # Filter(search) 向量搜图, support multi query, batch search + clip_filter( + embeddings=np.random.randn(1, 512), + output_file=f"tmp_image_outputs/result_img.json", + model_name="OFA-Sys/chinese-clip-vit-base-patch16", + index_dir='tmp_image_index_dir/', + index_name="faiss.index", + 
corpus_file="tmp_data_dir/corpus.csv", + num_results=5, + threshold=None, + device=None, + ) + + # Start Server + clip_server( + model_name="OFA-Sys/chinese-clip-vit-base-patch16", + index_dir='tmp_image_index_dir/', + index_name="faiss.index", + corpus_file="tmp_data_dir/corpus.csv", + num_results=5, + threshold=None, + device=None, + port=8002, + debug=True, + ) + + +if __name__ == '__main__': + main() diff --git a/examples/fast_text_semantic_search_demo.py b/examples/fast_text_semantic_search_demo.py index 898077e..bb91cf7 100644 --- a/examples/fast_text_semantic_search_demo.py +++ b/examples/fast_text_semantic_search_demo.py @@ -3,11 +3,13 @@ @author:XuMing(xuming624@qq.com) @description: Fast similarity search demo """ + +import os import sys sys.path.append('..') -from similarities.fastsim import AnnoySimilarity -from similarities.fastsim import HnswlibSimilarity +from similarities import AnnoySimilarity +from similarities import HnswlibSimilarity sentences = ['如何更换花呗绑定银行卡', '花呗更改绑定银行卡'] @@ -30,12 +32,13 @@ def annoy_demo(): print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}") model.add_corpus(corpus) model.build_index() - model.save_index('annoy_model.index') + index_file = 'annoy_model.index' + model.save_index(index_file) print(model.most_similar("men喜欢这首歌")) # Semantic Search batch del model model = AnnoySimilarity() - model.load_index('annoy_model.index') + model.load_index(index_file) print(model.most_similar("men喜欢这首歌")) queries = ["如何更换花呗绑定银行卡", "men喜欢这首歌"] res = model.most_similar(queries, topn=3) @@ -46,7 +49,7 @@ def annoy_demo(): for corpus_id, s in c.items(): print(f'\t{model.corpus[corpus_id]}: {s:.4f}') - # os.remove('annoy_model.bin') + os.remove(index_file) print('-' * 50 + '\n') @@ -59,12 +62,13 @@ def hnswlib_demo(): print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}") model.add_corpus(corpus) model.build_index() - model.save_index('hnsw_model.index') + index_file = 'hnsw_model.index' + model.save_index(index_file) print(model.most_similar("men喜欢这首歌")) # Semantic Search batch del model model = HnswlibSimilarity() - model.load_index('hnsw_model.index') + model.load_index(index_file) print(model.most_similar("men喜欢这首歌")) queries = ["如何更换花呗绑定银行卡", "men喜欢这首歌"] res = model.most_similar(queries, topn=3) @@ -75,7 +79,7 @@ def hnswlib_demo(): for corpus_id, s in c.items(): print(f'\t{model.corpus[corpus_id]}: {s:.4f}') - # os.remove('hnsw_model.bin') + os.remove(index_file) print('-' * 50 + '\n') diff --git a/examples/image_semantic_serach_demo.py b/examples/image_semantic_serach_demo.py index ed9a7ec..0e11f22 100644 --- a/examples/image_semantic_serach_demo.py +++ b/examples/image_semantic_serach_demo.py @@ -34,7 +34,7 @@ def sim_and_search(m): def clip_demo(): - m = ClipSimilarity() + m = ClipSimilarity(model_name_or_path="openai/clip-vit-base-patch32") print(m) # similarity score between text and image image_fps = [ diff --git a/examples/search_gradio_demo.py b/examples/search_gradio_demo.py index a1eca53..1ca7674 100644 --- a/examples/search_gradio_demo.py +++ b/examples/search_gradio_demo.py @@ -5,9 +5,9 @@ """ import gradio as gr -from similarities import Similarity +from similarities import BertSimilarity -sim_model = Similarity() +sim_model = BertSimilarity() def load_file(path): diff --git a/examples/similarity_gradio_demo.py b/examples/similarity_gradio_demo.py index 1b70e1b..4d265d9 100644 --- a/examples/similarity_gradio_demo.py +++ b/examples/similarity_gradio_demo.py @@ -4,9 +4,9 @@ @description: pip 
install gradio """ import gradio as gr -from similarities import Similarity +from similarities import BertSimilarity -sim_model = Similarity() +sim_model = BertSimilarity() def ai_text(sentence1, sentence2): diff --git a/examples/text_semantic_search_demo.py b/examples/text_semantic_search_demo.py index b46bdde..58f872f 100644 --- a/examples/text_semantic_search_demo.py +++ b/examples/text_semantic_search_demo.py @@ -6,7 +6,7 @@ import sys sys.path.append('..') -from similarities import Similarity +from similarities import BertSimilarity # 1.Compute cosine similarity between two sentences. sentences = ['如何更换花呗绑定银行卡', @@ -19,7 +19,7 @@ '中央情报局局长访问以色列叙利亚会谈', '人在巴基斯坦基地的炸弹袭击中丧生', ] -model = Similarity(model_name_or_path="shibing624/text2vec-base-chinese") +model = BertSimilarity(model_name_or_path="shibing624/text2vec-base-chinese") print(model) similarity_score = model.similarity(sentences[0], sentences[1]) print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}") @@ -42,3 +42,6 @@ print("search top 3:") for corpus_id, s in id_score_dict.items(): print(f'\t{model.corpus[corpus_id]}: {s:.4f}') + +print('-' * 50 + '\n') +print(model.search(sentences[0], topn=3)) diff --git a/examples/text_semantic_search_multilingual_demo.py b/examples/text_semantic_search_multilingual_demo.py index 1d10c30..16b909b 100644 --- a/examples/text_semantic_search_multilingual_demo.py +++ b/examples/text_semantic_search_multilingual_demo.py @@ -7,7 +7,7 @@ import sys sys.path.append('..') -from similarities import Similarity +from similarities import BertSimilarity # Two lists of sentences sentences1 = [ @@ -26,7 +26,7 @@ '敏捷的棕色狐狸跳过了懒狗', ] -model = Similarity(model_name_or_path="shibing624/text2vec-base-multilingual") +model = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual") # 使用的是多语言文本匹配模型 scores = model.similarity(sentences1, sentences2) print('1:use Similarity compute cos scores\n') @@ -53,12 +53,12 @@ ] model.add_corpus(corpus) -model.save_index('en_corpus_emb.json') +model.save_embeddings('en_corpus_emb.json') res = model.most_similar(queries=sentences1, topn=3) print(res) del model -model = Similarity(model_name_or_path="shibing624/text2vec-base-multilingual") -model.load_index('en_corpus_emb.json') +model = BertSimilarity(model_name_or_path="shibing624/text2vec-base-multilingual") +model.load_embeddings('en_corpus_emb.json') res = model.most_similar(queries=sentences1, topn=3) print(res) for q_id, c in res.items(): diff --git a/examples/text_similarity_demo.py b/examples/text_similarity_demo.py index 87a282c..5812190 100644 --- a/examples/text_similarity_demo.py +++ b/examples/text_similarity_demo.py @@ -6,8 +6,8 @@ import sys sys.path.append('..') -from similarities import Similarity +from similarities import BertSimilarity -m = Similarity() +m = BertSimilarity() r = m.similarity('如何更换花呗绑定银行卡', '花呗更改绑定银行卡') print(f"similarity score: {float(r)}") diff --git a/requirements.txt b/requirements.txt index a5f5f88..6a40e34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -text2vec>=1.1.5 +text2vec>=1.2.8 jieba>=0.39 loguru transformers Pillow -hnswlib -annoy opencv-python +autofaiss +fire \ No newline at end of file diff --git a/setup.py b/setup.py index 9a35f46..e18f4ae 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ license="Apache License 2.0", zip_safe=False, python_requires=">=3.6.0", + entry_points={"console_scripts": ["similarities = similarities.cli:main"]}, classifiers=[ "Development Status :: 5 - Production/Stable", "Intended 
Audience :: Developers", @@ -37,14 +38,13 @@ ], keywords='similarities,Chinese Text Similarity Calculation Tool,similarity,word2vec', install_requires=[ - "text2vec>=1.1.5", - "transformers", + "text2vec>=1.2.9", "jieba>=0.39", "loguru", "Pillow", - # "hnswlib", - # "opencv-python", - # "annoy", + "fire", + "autofaiss", + "transformers", ], packages=find_packages(), ) diff --git a/similarities/__init__.py b/similarities/__init__.py index a7f2fe5..63667cd 100644 --- a/similarities/__init__.py +++ b/similarities/__init__.py @@ -2,15 +2,15 @@ """ @author:XuMing(xuming624@qq.com) @description: - -This package contains implementations of pairwise similarity queries. """ # bring classes directly into package namespace, to save some typing from similarities.version import __version__ -from similarities.similarity import Similarity -from similarities.fastsim import AnnoySimilarity, HnswlibSimilarity -from similarities.literalsim import ( +from similarities.bert_similarity import BertSimilarity +from similarities.bert_similarity import BertSimilarity as Similarity + +from similarities.fast_bert_similarity import AnnoySimilarity, HnswlibSimilarity +from similarities.literal_similarity import ( SimHashSimilarity, TfidfSimilarity, BM25Similarity, @@ -20,11 +20,16 @@ SameCharsSimilarity, SequenceMatcherSimilarity, ) -from similarities.imagesim import ( +from similarities.image_similarity import ( ImageHashSimilarity, - ClipSimilarity, SiftSimilarity, ) +from similarities.clip_similarity import ClipSimilarity +from similarities.clip_module import ClipModule from similarities.data_loader import SearchDataLoader from similarities import evaluation from similarities import utils +from similarities.faiss_bert_similarity import bert_embedding, bert_index, bert_filter, bert_server +from similarities.faiss_clip_similarity import clip_embedding, clip_index, clip_filter, clip_server +from similarities.faiss_bert_similarity import BertClient +from similarities.faiss_clip_similarity import ClipClient, ClipItem diff --git a/similarities/bert_similarity.py b/similarities/bert_similarity.py new file mode 100644 index 0000000..685fd38 --- /dev/null +++ b/similarities/bert_similarity.py @@ -0,0 +1,214 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: + +Compute similarity: +1. Compute the similarity between two sentences +2. Retrieves most similar sentence of a query against a corpus of documents. +""" + +import json +import os +from typing import List, Union, Dict + +import numpy as np +from loguru import logger +from text2vec import SentenceModel +from similarities.utils.util import cos_sim, semantic_search, dot_score + +from similarities.similarity import SimilarityABC + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" + + +class BertSimilarity(SimilarityABC): + """ + Sentence Similarity: + 1. Compute the similarity between two sentences + 2. Retrieves most similar sentence of a query against a corpus of documents. + + The index supports adding new documents dynamically. + """ + + def __init__( + self, + corpus: Union[List[str], Dict[str, str]] = None, + model_name_or_path="shibing624/text2vec-base-chinese", + device=None, + ): + """ + Initialize the similarity object. + :param model_name_or_path: Transformer model name or path, like: + 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', 'bert-base-uncased', 'bert-base-chinese', + 'shibing624/text2vec-base-chinese', ... 
+            model in HuggingFace Model Hub and release from https://github.com/shibing624/text2vec
+        :param corpus: Corpus of documents to use for similarity queries.
+        :param device: Device (like 'cuda' / 'cpu') to use for the computation.
+        """
+        if isinstance(model_name_or_path, str):
+            self.sentence_model = SentenceModel(
+                model_name_or_path,
+                device=device
+            )
+        elif hasattr(model_name_or_path, "encode"):
+            self.sentence_model = model_name_or_path
+        else:
+            raise ValueError("model_name_or_path must be a transformers model name/path or an object with an encode() method")
+        self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
+        self.corpus = {}
+        self.corpus_embeddings = []
+        if corpus is not None:
+            self.add_corpus(corpus)
+
+    def __len__(self):
+        """Get length of corpus."""
+        return len(self.corpus)
+
+    def __str__(self):
+        base = f"Similarity: {self.__class__.__name__}, matching_model: {self.sentence_model}"
+        if self.corpus:
+            base += f", corpus size: {len(self.corpus)}"
+        return base
+
+    def get_sentence_embedding_dimension(self):
+        """
+        Get the dimension of the sentence embeddings.
+
+        Returns
+        -------
+        int or None
+            The dimension of the sentence embeddings, or None if it cannot be determined.
+        """
+        if hasattr(self.sentence_model, "get_sentence_embedding_dimension"):
+            return self.sentence_model.get_sentence_embedding_dimension()
+        else:
+            return getattr(self.sentence_model.bert.pooler.dense, "out_features", None)
+
+    def add_corpus(self, corpus: Union[List[str], Dict[str, str]]):
+        """
+        Extend the corpus with new documents.
+        :param corpus: corpus of documents to use for similarity queries.
+        :return: self.corpus, self.corpus embeddings
+        """
+        new_corpus = {}
+        start_id = len(self.corpus) if self.corpus else 0
+        # list corpus: assign new incremental ids; dict corpus: keep the caller-provided ids
+        if isinstance(corpus, list):
+            for id, doc in enumerate(corpus):
+                if doc not in self.corpus.values():
+                    new_corpus[start_id + id] = doc
+        else:
+            for id, doc in corpus.items():
+                if doc not in self.corpus.values():
+                    new_corpus[id] = doc
+        self.corpus.update(new_corpus)
+        logger.info(f"Start computing corpus embeddings, new docs: {len(new_corpus)}")
+        corpus_embeddings = self.get_embeddings(list(new_corpus.values()), show_progress_bar=True).tolist()
+        if self.corpus_embeddings:
+            self.corpus_embeddings = self.corpus_embeddings + corpus_embeddings
+        else:
+            self.corpus_embeddings = corpus_embeddings
+        logger.info(f"Add {len(new_corpus)} docs, total: {len(self.corpus)}, emb len: {len(self.corpus_embeddings)}")
+
+    def get_embeddings(
+            self,
+            sentences: Union[str, List[str]],
+            batch_size: int = 32,
+            show_progress_bar: bool = False,
+            convert_to_numpy: bool = True,
+            convert_to_tensor: bool = False,
+            device: str = None,
+            normalize_embeddings: bool = False,
+    ):
+        """
+        Returns the embeddings for a batch of sentences.
+        """
+        return self.sentence_model.encode(
+            sentences,
+            batch_size=batch_size,
+            show_progress_bar=show_progress_bar,
+            convert_to_numpy=convert_to_numpy,
+            convert_to_tensor=convert_to_tensor,
+            device=device,
+            normalize_embeddings=normalize_embeddings,
+        )
+
+    def similarity(self, a: Union[str, List[str]], b: Union[str, List[str]], score_function: str = "cos_sim"):
+        """
+        Compute similarity between two texts.
+ :param a: list of str or str + :param b: list of str or str + :param score_function: function to compute similarity, default cos_sim + :return: similarity score, torch.Tensor, Matrix with res[i][j] = cos_sim(a[i], b[j]) + """ + if score_function not in self.score_functions: + raise ValueError(f"score function: {score_function} must be either (cos_sim) for cosine similarity" + " or (dot) for dot product") + score_function = self.score_functions[score_function] + text_emb1 = self.get_embeddings(a) + text_emb2 = self.get_embeddings(b) + + return score_function(text_emb1, text_emb2) + + def distance(self, a: Union[str, List[str]], b: Union[str, List[str]]): + """Compute cosine distance between two texts.""" + return 1 - self.similarity(a, b) + + def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10, + score_function: str = "cos_sim"): + """ + Find the topn most similar texts to the queries against the corpus. + It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries. + :param queries: str or list of str + :param topn: int + :param score_function: function to compute similarity, default cos_sim + :return: Dict[str, Dict[str, float]], {query_id: {corpus_id: similarity_score}, ...} + """ + if isinstance(queries, str) or not hasattr(queries, '__len__'): + queries = [queries] + if isinstance(queries, list): + queries = {id: query for id, query in enumerate(queries)} + if score_function not in self.score_functions: + raise ValueError(f"score function: {score_function} must be either (cos_sim) for cosine similarity" + " or (dot) for dot product") + score_function = self.score_functions[score_function] + result = {qid: {} for qid, query in queries.items()} + queries_ids_map = {i: id for i, id in enumerate(list(queries.keys()))} + queries_texts = list(queries.values()) + queries_embeddings = self.get_embeddings(queries_texts, convert_to_tensor=True) + corpus_embeddings = np.array(self.corpus_embeddings, dtype=np.float32) + all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn, score_function=score_function) + for idx, hits in enumerate(all_hits): + for hit in hits[0:topn]: + result[queries_ids_map[idx]][hit['corpus_id']] = hit['score'] + + return result + + def save_embeddings(self, emb_path: str = "corpus_emb.json"): + """ + Save corpus embeddings to json file. + :param emb_path: json file path + :return: + """ + corpus_emb = {id: {"doc": self.corpus[id], "doc_emb": emb} for id, emb in + zip(self.corpus.keys(), self.corpus_embeddings)} + with open(emb_path, "w", encoding="utf-8") as f: + json.dump(corpus_emb, f, ensure_ascii=False) + logger.debug(f"Save corpus embeddings to file: {emb_path}.") + + def load_embeddings(self, emb_path: str = "corpus_emb.json"): + """ + Load corpus embeddings from json file. 
+ :param emb_path: json file path + :return: list of corpus embeddings, dict of corpus ids map, dict of corpus + """ + try: + with open(emb_path, "r", encoding="utf-8") as f: + corpus_emb = json.load(f) + corpus_embeddings = [] + for id, corpus_dict in corpus_emb.items(): + self.corpus[int(id)] = corpus_dict["doc"] + corpus_embeddings.append(corpus_dict["doc_emb"]) + self.corpus_embeddings = corpus_embeddings + except (IOError, json.JSONDecodeError): + logger.error("Error: Could not load corpus embeddings from file.") diff --git a/similarities/cli.py b/similarities/cli.py new file mode 100644 index 0000000..a7f3e00 --- /dev/null +++ b/similarities/cli.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: cli entry point +""" + +import fire + +from similarities.faiss_bert_similarity import bert_embedding, bert_index, bert_filter, bert_server +from similarities.faiss_clip_similarity import clip_embedding, clip_index, clip_filter, clip_server + + +def main(): + """Main entry point""" + + fire.Fire( + { + "bert_embedding": bert_embedding, + "bert_index": bert_index, + "bert_filter": bert_filter, + "bert_server": bert_server, + "clip_embedding": clip_embedding, + "clip_index": clip_index, + "clip_filter": clip_filter, + "clip_server": clip_server, + } + ) + + +if __name__ == "__main__": + main() diff --git a/similarities/clip_model.py b/similarities/clip_module.py similarity index 82% rename from similarities/clip_model.py rename to similarities/clip_module.py index 8485885..c089a78 100644 --- a/similarities/clip_model.py +++ b/similarities/clip_module.py @@ -13,10 +13,8 @@ from tqdm import trange from transformers import ChineseCLIPProcessor, ChineseCLIPModel, CLIPProcessor, CLIPModel -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -class CLIPModel(nn.Module): +class ClipModule(nn.Module): """ CLIP model for text and image embeddings @@ -25,22 +23,37 @@ class CLIPModel(nn.Module): chinese model url: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16 english model url: https://huggingface.co/openai/clip-vit-base-patch32 processor_name: str, default None + device: str, default None + is_chinese_model: bool, default None, if None, auto detect by model_name """ - def __init__(self, model_name: str = "OFA-Sys/chinese-clip-vit-base-patch16", processor_name=None): - super(CLIPModel, self).__init__() + def __init__( + self, + model_name: str = "OFA-Sys/chinese-clip-vit-base-patch16", + processor_name: str = None, + device: str = None, + is_chinese_model: bool = None, + ): + super(ClipModule, self).__init__() + if device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device self.model_name = model_name if processor_name is None: processor_name = model_name - if 'chinese' in model_name: - self.processor = ChineseCLIPProcessor.from_pretrained(processor_name) + if is_chinese_model is None: + is_chinese_model = 'chinese' in model_name + self.is_chinese_model = is_chinese_model + if is_chinese_model: self.model = ChineseCLIPModel.from_pretrained(model_name) + self.processor = ChineseCLIPProcessor.from_pretrained(processor_name) else: self.model = CLIPModel.from_pretrained(model_name) self.processor = CLIPProcessor.from_pretrained(processor_name) def __str__(self): - return f"model_name: {self.model_name} CLIPModel({self.model})" + return f"model_name: {self.model_name} ClipModule({self.model})" def forward(self, features): image_embeds = [] @@ -58,7 +71,12 @@ def forward(self, 
features): output_attentions=features.get('output_attentions', None), output_hidden_states=features.get('output_hidden_states', None), ) - text_embeds = self.model.text_projection(text_outputs[1]) + if self.is_chinese_model: + # refer chinese clip: https://github.com/huggingface/transformers/blob/main/src/transformers/models/chinese_clip/modeling_chinese_clip.py#L1431 + pooled_output = text_outputs[0][:, 0, :] + else: + pooled_output = text_outputs[1] + text_embeds = self.model.text_projection(pooled_output) sentence_embedding = [] image_features = iter(image_embeds) @@ -80,7 +98,7 @@ def tokenize(self, texts): image_text_info = [] for idx, data in enumerate(texts): - if isinstance(data, Image.Image): # An Image + if isinstance(data, (Image.Image, np.ndarray)): # An Image images.append(data) image_text_info.append(0) else: # A text @@ -102,7 +120,7 @@ def save(self, output_path: str): @staticmethod def load(input_path: str): - return CLIPModel(model_name=input_path) + return ClipModule(model_name=input_path) def _text_length(self, text): """ @@ -121,7 +139,7 @@ def _text_length(self, text): return sum([len(t) for t in text]) # Sum of length of individual strings @staticmethod - def batch_to_device(batch): + def batch_to_device(batch, device): """ send a pytorch batch to a device (CPU/GPU) """ @@ -157,7 +175,7 @@ def encode( if isinstance(sentences, str) or not hasattr(sentences, '__len__'): sentences = [sentences] input_was_string = True - self.model.to(device) + self.model.to(self.device) all_embeddings = [] length_sorted_idx = np.argsort([-self._text_length(sent) for sent in sentences]) @@ -166,7 +184,7 @@ def encode( for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): sentences_batch = sentences_sorted[start_index:start_index + batch_size] features = self.tokenize(sentences_batch) - features = self.batch_to_device(features) + features = self.batch_to_device(features, self.device) with torch.no_grad(): out_features = self.forward(features) diff --git a/similarities/clip_similarity.py b/similarities/clip_similarity.py new file mode 100644 index 0000000..a236688 --- /dev/null +++ b/similarities/clip_similarity.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: Image similarity and image retrieval + +refer: https://colab.research.google.com/drive/1leOzG-AQw5MkzgA4qNW5fb3yc-oJ4Lo4 +Adjust the code to compare similarity score and search. +""" + +from typing import List, Union, Dict + +import numpy as np +from PIL import Image +from loguru import logger + +from similarities.clip_module import ClipModule +from similarities.similarity import SimilarityABC +from similarities.utils.util import cos_sim, semantic_search, dot_score + + +class ClipSimilarity(SimilarityABC): + """ + Compute CLIP similarity between two images and retrieves most + similar image for a given image corpus. 
+ + CLIP: https://github.com/openai/CLIP.git + english model: openai/clip-vit-base-patch32 + chinese model: OFA-Sys/chinese-clip-vit-base-patch16 + """ + + def __init__( + self, + corpus: Union[List[Image.Image], Dict[str, Image.Image]] = None, + model_name_or_path='OFA-Sys/chinese-clip-vit-base-patch16' + ): + self.clip_model = ClipModule(model_name_or_path) # load the CLIP model + self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score} + self.corpus = {} + self.corpus_embeddings = [] + if corpus is not None: + self.add_corpus(corpus) + + def __len__(self): + """Get length of corpus.""" + return len(self.corpus) + + def __str__(self): + base = f"Similarity: {self.__class__.__name__}, matching_model: {self.clip_model.__class__.__name__}" + if self.corpus: + base += f", corpus size: {len(self.corpus)}" + return base + + def _convert_to_rgb(self, img): + """Convert image to RGB mode.""" + if img.mode != 'RGB': + img = img.convert('RGB') + return img + + def get_embeddings( + self, + text_or_img: Union[List[Image.Image], Image.Image, str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = False, + ): + """ + Returns the embeddings for a batch of images. + :param text_or_img: list of str or Image.Image or image list + :param batch_size: batch size + :param show_progress_bar: show progress bar + :return: np.ndarray, embeddings for the given images + """ + if isinstance(text_or_img, str): + text_or_img = [text_or_img] + if isinstance(text_or_img, Image.Image): + text_or_img = [text_or_img] + if isinstance(text_or_img, list) and isinstance(text_or_img[0], Image.Image): + text_or_img = [self._convert_to_rgb(i) for i in text_or_img] + return self.clip_model.encode(text_or_img, batch_size=batch_size, show_progress_bar=show_progress_bar) + + def add_corpus(self, corpus: Union[List[Image.Image], Dict[str, Image.Image]]): + """ + Extend the corpus with new documents. + + Parameters + ---------- + corpus : list of str or dict + """ + corpus_new = {} + start_id = len(self.corpus) if self.corpus else 0 + if isinstance(corpus, list): + for id, doc in enumerate(corpus): + if doc not in list(self.corpus.values()): + corpus_new[start_id + id] = doc + else: + for id, doc in corpus.items(): + if doc not in list(self.corpus.values()): + corpus_new[id] = doc + self.corpus.update(corpus_new) + logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}") + corpus_embeddings = self.get_embeddings(list(corpus_new.values()), show_progress_bar=True).tolist() + if self.corpus_embeddings: + self.corpus_embeddings += corpus_embeddings + else: + self.corpus_embeddings = corpus_embeddings + logger.info(f"Add {len(corpus)} docs, total: {len(self.corpus)}, emb size: {len(self.corpus_embeddings)}") + + def similarity( + self, + a: Union[List[Image.Image], Image.Image, str, List[str]], + b: Union[List[Image.Image], Image.Image, str, List[str]], + score_function: str = "cos_sim" + ): + """ + Compute similarity between two texts. 
+ :param a: list of str or str + :param b: list of str or str + :param score_function: function to compute similarity, default cos_sim + :return: similarity score, torch.Tensor, Matrix with res[i][j] = cos_sim(a[i], b[j]) + """ + if score_function not in self.score_functions: + raise ValueError(f"score function: {score_function} must be either (cos_sim) for cosine similarity" + " or (dot) for dot product") + score_function = self.score_functions[score_function] + text_emb1 = self.get_embeddings(a) + text_emb2 = self.get_embeddings(b) + + return score_function(text_emb1, text_emb2) + + def distance(self, a: Union[str, List[str]], b: Union[str, List[str]]): + """Compute cosine distance between two texts.""" + return 1 - self.similarity(a, b) + + def most_similar(self, queries, topn: int = 10): + """ + Find the topn most similar texts to the queries against the corpus. + :param queries: text or image + :param topn: int + :return: Dict[str, Dict[str, float]], {query_id: {corpus_id: similarity_score}, ...} + """ + if isinstance(queries, str) or not hasattr(queries, '__len__'): + queries = [queries] + if isinstance(queries, list): + queries = {id: query for id, query in enumerate(queries)} + result = {qid: {} for qid, query in queries.items()} + queries_ids_map = {i: id for i, id in enumerate(list(queries.keys()))} + queries_texts = list(queries.values()) + queries_embeddings = self.get_embeddings(queries_texts) + corpus_embeddings = np.array(self.corpus_embeddings, dtype=np.float32) + all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn) + for idx, hits in enumerate(all_hits): + for hit in hits[0:topn]: + result[queries_ids_map[idx]][hit['corpus_id']] = hit['score'] + + return result diff --git a/similarities/faiss_bert_similarity.py b/similarities/faiss_bert_similarity.py new file mode 100644 index 0000000..c666ea1 --- /dev/null +++ b/similarities/faiss_bert_similarity.py @@ -0,0 +1,327 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: Use faiss to build index +""" +import json +import os +from glob import glob +from typing import List, Optional + +import faiss +import fire +import numpy as np +import requests +from loguru import logger +from text2vec import SentenceModel + +from similarities.utils.util import cos_sim + + +def bert_embedding( + input_dir: str, + embeddings_dir: str = 'tmp_embeddings_dir/', + embeddings_name: str = 'emb.npy', + corpus_file: str = 'tmp_data_dir/corpus.npy', + model_name: str = "shibing624/text2vec-base-chinese", + batch_size: int = 32, + device: Optional[str] = None, + normalize_embeddings: bool = False, +): + sentences = set() + input_files = glob(f'{input_dir}/**/*.txt', recursive=True) + logger.info(f'Load input files success. input files: {input_files}') + for file in input_files: + with open(file, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + sentences.add(line) + sentences = list(sentences) + logger.info(f'Load sentences success. sentences num: {len(sentences)}, top3: {sentences[:3]}') + assert len(sentences) > 0, f"sentences is empty, please check input files: {input_files}" + + model = SentenceModel(model_name_or_path=model_name, device=device) + logger.info(f'Load model success. 
model: {model}') + + # Start the multi processes pool on all available CUDA devices + pool = model.start_multi_process_pool() + + # Compute the embeddings using the multi processes pool + emb = model.encode_multi_process( + sentences, + pool, + batch_size=batch_size, + normalize_embeddings=normalize_embeddings + ) + logger.info(f"Embeddings computed. Shape: {emb.shape}") + + model.stop_multi_process_pool(pool) + # Save the embeddings + os.makedirs(embeddings_dir, exist_ok=True) + embeddings_file = os.path.join(embeddings_dir, embeddings_name) + np.save(embeddings_file, emb) + logger.debug(f"Embeddings saved to {embeddings_file}") + corpus_dir = os.path.dirname(corpus_file) + os.makedirs(corpus_dir, exist_ok=True) + np.save(corpus_file, sentences) + logger.debug(f"Sentences saved to {corpus_file}") + logger.info(f"Input dir: {input_dir}, saved embeddings dir: {embeddings_dir}") + + +def bert_index( + embeddings_dir: str, + index_dir: str = "tmp_index_dir/", + index_name: str = "faiss.index", + max_index_memory_usage: str = "4G", + current_memory_available: str = "8G", + use_gpu: bool = False, + nb_cores: Optional[int] = None, +): + """indexes text embeddings using autofaiss""" + from autofaiss import build_index # pylint: disable=import-outside-toplevel + + logger.debug(f"Starting build index from {embeddings_dir}") + if embeddings_dir and os.path.exists(embeddings_dir): + logger.debug( + f"Embedding path exist, building index " + f"using embeddings {embeddings_dir} ; saving in {index_dir}" + ) + index_file = os.path.join(index_dir, index_name) + index_infos_path = os.path.join(index_dir, index_name + ".json") + try: + build_index( + embeddings=embeddings_dir, + index_path=index_file, + index_infos_path=index_infos_path, + max_index_memory_usage=max_index_memory_usage, + current_memory_available=current_memory_available, + nb_cores=nb_cores, + use_gpu=use_gpu, + ) + logger.info(f"Index {embeddings_dir} done, saved in {index_file}, index infos in {index_infos_path}") + except Exception as e: # pylint: disable=broad-except + logger.error(f"Index {embeddings_dir} failed, {e}") + raise e + else: + logger.warning(f"Embeddings dir {embeddings_dir} not exist") + + +def batch_search_index( + queries, + model, + faiss_index, + sentences, + num_results, + threshold, + debug=True, +): + """Search index with text inputs (batch search)""" + result = [] + # Query embeddings need to be normalized for cosine similarity + query_features = model.encode(queries, normalize_embeddings=True) + + for query, query_feature in zip(queries, query_features): + query_feature = query_feature.reshape(1, -1) + if threshold is not None: + _, d, i = faiss_index.range_search(query_feature, threshold) + if debug: + logger.debug(f"Found {i.shape} items with query '{query}' and threshold {threshold}") + else: + d, i = faiss_index.search(query_feature, num_results) + i = i[0] + d = d[0] + if debug: + logger.debug(f"Found {num_results} items with query '{query}'") + logger.debug(f"The minimum distance is {min(d):.2f} and the maximum is {max(d):.2f}") + logger.debug( + "You may want to increase your result, use --num_results parameter. " + "Or use the --threshold parameter." 
+        # Collect faiss search results with scores
+        text_scores = []
+        for ed, ei in zip(d, i):
+            sentence = sentences[ei]
+            if debug:
+                logger.debug(f"Found: {sentence}, similarity: {ed}, id: {ei}")
+            text_scores.append((sentence, float(ed), int(ei)))
+        # Sort by score desc
+        query_result = sorted(text_scores, key=lambda x: x[1], reverse=True)
+        result.append(query_result)
+    return result
+
+
+def bert_filter(
+        queries: List[str],
+        output_file: str = "tmp_outputs/result.json",
+        model_name: str = "shibing624/text2vec-base-chinese",
+        index_dir: str = 'tmp_index_dir/',
+        index_name: str = "faiss.index",
+        corpus_file: str = "tmp_data_dir/corpus.npy",
+        num_results: int = 10,
+        threshold: Optional[float] = None,
+        device: Optional[str] = None,
+):
+    """Entry point of bert filter"""
+    assert isinstance(queries, list), f"queries type error, queries: {queries}"
+    index_file = os.path.join(index_dir, index_name)
+    assert os.path.exists(index_file), f"index file {index_file} does not exist"
+    faiss_index = faiss.read_index(index_file)
+    model = SentenceModel(model_name_or_path=model_name, device=device)
+    sentences = np.load(corpus_file)
+    logger.info(f'Load model success. model: {model}, index: {faiss_index}, sentences size: {len(sentences)}')
+
+    result = batch_search_index(queries, model, faiss_index, sentences, num_results, threshold)
+    # Save results
+    if output_file:
+        os.makedirs(os.path.dirname(output_file), exist_ok=True)
+        with open(output_file, 'w', encoding='utf-8') as f:
+            for q, sorted_text_scores in zip(queries, result):
+                json.dump(
+                    {'query': q,
+                     'results': [{'sentence': i, 'similarity': j, 'id': k} for i, j, k in sorted_text_scores]},
+                    f,
+                    ensure_ascii=False,
+                )
+                f.write('\n')
+        logger.info(f"Query size: {len(queries)}, saved result to {output_file}")
+    return result
+
+
+def bert_server(
+        model_name: str = "shibing624/text2vec-base-chinese",
+        index_dir: str = 'tmp_index_dir/',
+        index_name: str = "faiss.index",
+        corpus_file: str = "tmp_data_dir/corpus.npy",
+        num_results: int = 10,
+        threshold: Optional[float] = None,
+        device: Optional[str] = None,
+        port: int = 8001,
+        debug: bool = False,
+):
+    """Main entry point of the bert search backend; starts the endpoints."""
+    import uvicorn
+    from fastapi import FastAPI
+    from pydantic import BaseModel, Field
+    from starlette.middleware.cors import CORSMiddleware
+
+    logger.info("Starting bert server")
+    index_file = os.path.join(index_dir, index_name)
+    assert os.path.exists(index_file), f"index file {index_file} does not exist"
+    faiss_index = faiss.read_index(index_file)
+    model = SentenceModel(model_name_or_path=model_name, device=device)
+    sentences = np.load(corpus_file)
+    logger.info(f'Load model success. 
model: {model}, index: {faiss_index}, sentences size: {len(sentences)}') + + # define the app + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"]) + + class Item(BaseModel): + input: str = Field(..., max_length=512) + + @app.get('/') + async def index(): + return {"message": "index, docs url: /docs"} + + @app.post('/emb') + async def emb(item: Item): + try: + q = item.input + embeddings = model.encode(q) + result_dict = {'emb': embeddings.tolist()} + logger.debug(f"Successfully get sentence embeddings, q:{q}, res shape: {embeddings.shape}") + return result_dict + except Exception as e: + logger.error(e) + return {'status': False, 'msg': e}, 400 + + @app.post('/similarity') + async def similarity(item1: Item, item2: Item): + try: + q1 = item1.input + q2 = item2.input + emb1 = model.encode(q1) + emb2 = model.encode(q2) + sim_score = cos_sim(emb1, emb2).tolist()[0][0] + result_dict = {'similarity': sim_score} + logger.debug(f"Successfully get similarity score, q1:{q1}, q2:{q2}, res: {sim_score}") + return result_dict + except Exception as e: + logger.error(e) + return {'status': False, 'msg': e}, 400 + + @app.post('/search') + async def search(item: Item): + try: + q = item.input + results = batch_search_index([q], model, faiss_index, sentences, num_results, threshold, debug=debug) + sorted_text_scores = results[0][0] + result_dict = {'result': sorted_text_scores} + logger.debug(f"Successfully search done, q:{q}, res size: {len(sorted_text_scores)}") + return result_dict + except Exception as e: + logger.error(f"search error: {e}") + return {'status': False, 'msg': e}, 400 + + logger.info("Server starting!") + uvicorn.run(app, host="0.0.0.0", port=port) + + +class BertClient: + def __init__(self, base_url: str = "http://0.0.0.0:8001"): + self.base_url = base_url + + def _post(self, endpoint: str, data: dict) -> dict: + response = requests.post(f"{self.base_url}/{endpoint}", json=data) + response.raise_for_status() + return response.json() + + def get_emb(self, input_text: str) -> List[float]: + try: + data = {"input": input_text} + response = self._post("emb", data) + return response.get("emb", []) + except Exception as e: + logger.error(f"get_emb error: {e}") + return [] + + def get_similarity(self, input_text1: str, input_text2: str) -> float: + try: + data1 = {"input": input_text1} + data2 = {"input": input_text2} + response = self._post("similarity", {"item1": data1, "item2": data2}) + return response.get("similarity", 0.0) + except Exception as e: + logger.error(f"get_similarity error: {e}") + return 0.0 + + def search(self, input_text: str): + try: + data = {"input": input_text} + response = self._post("search", data) + return response.get("result", []) + except Exception as e: + logger.error(f"search error: {e}") + return [] + + +def main(): + """Main entry point""" + fire.Fire( + { + "bert_embedding": bert_embedding, + "bert_index": bert_index, + "bert_filter": bert_filter, + "bert_server": bert_server, + } + ) + + +if __name__ == "__main__": + main() diff --git a/similarities/faiss_clip_similarity.py b/similarities/faiss_clip_similarity.py new file mode 100644 index 0000000..ea4b1e2 --- /dev/null +++ b/similarities/faiss_clip_similarity.py @@ -0,0 +1,492 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: Use faiss to search clip embeddings +""" +import base64 +import json +import os +from io import BytesIO +from typing import Sequence, List, Optional, 
Union
+
+import faiss
+import fire
+import numpy as np
+import pandas as pd
+import requests
+from PIL import Image
+from loguru import logger
+from pydantic import BaseModel, Field
+
+from similarities.clip_module import ClipModule
+from similarities.utils.util import cos_sim
+
+
+def load_data(data, header=None, columns=('image_path', 'text'), delimiter='\t'):
+    """
+    Load (image_path, text) data into a DataFrame.
+    :param data: list of (image_path, text) or DataFrame or file path
+    :param header: read_csv header
+    :param columns: read_csv names
+    :param delimiter: read_csv sep
+    :return: data_df
+    """
+    if isinstance(data, list):
+        data_df = pd.DataFrame(data, columns=columns)
+    elif isinstance(data, str) and os.path.exists(data):
+        data_df = pd.read_csv(data, header=header, delimiter=delimiter, names=columns)
+    elif isinstance(data, pd.DataFrame):
+        data_df = data
+    else:
+        raise TypeError('data should be a list, DataFrame or file path')
+    return data_df
+
+
+def preprocess_image(image_input: Union[str, np.ndarray, bytes]) -> Image.Image:
+    """
+    Convert an image input (path/URL, ndarray, or base64 bytes) to a PIL Image.Image object
+    """
+    if isinstance(image_input, str):
+        if image_input.startswith('http'):
+            return Image.open(requests.get(image_input, stream=True).raw)
+        elif image_input.endswith((".png", ".jpg", ".jpeg", ".bmp")) and os.path.isfile(image_input):
+            return Image.open(image_input)
+        else:
+            raise ValueError("Unsupported image input: expected a URL or an existing image file path")
+    elif isinstance(image_input, np.ndarray):
+        return Image.fromarray(image_input)
+    elif isinstance(image_input, bytes):
+        img_data = base64.b64decode(image_input)
+        return Image.open(BytesIO(img_data))
+    else:
+        raise ValueError("Unsupported image input type")
+
+
+def clip_embedding(
+        input_data_or_path: str,
+        columns: Optional[Sequence[str]] = ('image_path', 'text'),
+        header: Optional[int] = None,
+        delimiter: str = '\t',
+        image_embeddings_dir: Optional[str] = 'tmp_image_embeddings_dir/',
+        text_embeddings_dir: Optional[str] = 'tmp_text_embeddings_dir/',
+        embeddings_name: str = 'emb.npy',
+        corpus_file: str = 'tmp_data_dir/corpus.csv',
+        model_name: str = "OFA-Sys/chinese-clip-vit-base-patch16",
+        batch_size: int = 32,
+        enable_image: bool = True,
+        enable_text: bool = False,
+        device: Optional[str] = None,
+        normalize_embeddings: bool = False,
+):
+    """Embed images and texts with a CLIP model"""
+    df = load_data(input_data_or_path, header=header, columns=columns, delimiter=delimiter)
+    logger.info(f'Load data success. data num: {len(df)}, top3: {df.head(3)}')
+    images = df['image_path'].tolist()
+    texts = df['text'].tolist()
+    model = ClipModule(model_name=model_name, device=device)
+    logger.info(f'Load model success. model: {model_name}')
+
+    # Encode images and optionally texts in batches
+    if enable_image:
+        os.makedirs(image_embeddings_dir, exist_ok=True)
+        images = [preprocess_image(img) for img in images]
+        image_emb = model.encode(
+            images,
+            batch_size=batch_size,
+            show_progress_bar=True,
+            normalize_embeddings=normalize_embeddings,
+        )
+        logger.info(f"Embeddings computed. Shape: {image_emb.shape}")
+        image_embeddings_file = os.path.join(image_embeddings_dir, embeddings_name)
+        np.save(image_embeddings_file, image_emb)
+        logger.debug(f"Embeddings saved to {image_embeddings_file}")
+    if enable_text:
+        os.makedirs(text_embeddings_dir, exist_ok=True)
+        text_emb = model.encode(
+            texts,
+            batch_size=batch_size,
+            show_progress_bar=True,
+            normalize_embeddings=normalize_embeddings,
+        )
+        logger.info(f"Embeddings computed. Shape: {text_emb.shape}")
Shape: {text_emb.shape}") + text_embeddings_file = os.path.join(text_embeddings_dir, embeddings_name) + np.save(text_embeddings_file, text_emb) + logger.debug(f"Embeddings saved to {text_embeddings_file}") + + # Save corpus + if corpus_file: + os.makedirs(os.path.dirname(corpus_file), exist_ok=True) + df.to_csv(corpus_file, index=False) + logger.debug(f"data saved to {corpus_file}") + + +def clip_index( + image_embeddings_dir: Optional[str] = None, + text_embeddings_dir: Optional[str] = None, + image_index_dir: Optional[str] = "tmp_image_index_dir/", + text_index_dir: Optional[str] = "tmp_text_index_dir/", + index_name: str = "faiss.index", + max_index_memory_usage: str = "4G", + current_memory_available: str = "16G", + use_gpu: bool = False, + nb_cores: Optional[int] = None, +): + """indexes text embeddings using autofaiss""" + from autofaiss import build_index # pylint: disable=import-outside-toplevel + + logger.debug(f"Starting build index from {image_embeddings_dir}") + if image_embeddings_dir and os.path.exists(image_embeddings_dir): + logger.debug( + f"Embedding path exist, building index " + f"using embeddings {image_embeddings_dir} ; saving in {image_index_dir}" + ) + index_file = os.path.join(image_index_dir, index_name) + index_infos_path = os.path.join(image_index_dir, index_name + ".json") + try: + build_index( + embeddings=image_embeddings_dir, + index_path=index_file, + index_infos_path=index_infos_path, + max_index_memory_usage=max_index_memory_usage, + current_memory_available=current_memory_available, + nb_cores=nb_cores, + use_gpu=use_gpu, + ) + logger.info(f"Index {image_embeddings_dir} done, saved in {index_file}, index infos in {index_infos_path}") + except Exception as e: # pylint: disable=broad-except + logger.error(f"Index {image_embeddings_dir} failed, {e}") + raise e + else: + logger.warning(f"Embeddings dir {image_embeddings_dir} not exist") + + logger.debug(f"Starting build index from {text_embeddings_dir}") + if text_embeddings_dir and os.path.exists(text_embeddings_dir): + logger.debug( + f"Embedding path exist, building index " + f"using embeddings {text_embeddings_dir} ; saving in {text_index_dir}" + ) + index_file = os.path.join(text_index_dir, index_name) + index_infos_path = os.path.join(text_index_dir, index_name + ".json") + try: + build_index( + embeddings=text_embeddings_dir, + index_path=index_file, + index_infos_path=index_infos_path, + max_index_memory_usage=max_index_memory_usage, + current_memory_available=current_memory_available, + nb_cores=nb_cores, + use_gpu=use_gpu, + ) + logger.info(f"Index {text_embeddings_dir} done, saved in {index_file}, index infos in {index_infos_path}") + except Exception as e: # pylint: disable=broad-except + logger.error(f"Index {text_embeddings_dir} failed, {e}") + raise e + else: + logger.warning(f"Embeddings dir {text_embeddings_dir} not exist") + + +def batch_search_index( + queries, + model, + faiss_index, + df, + num_results, + threshold, + debug=True, +): + """Search index with image inputs or image paths (batch search)""" + assert queries is not None, "queries should not be None" + result = [] + if isinstance(queries, np.ndarray): + query_features = queries + else: + query_features = model.encode(queries, normalize_embeddings=True) + + for query, query_feature in zip(queries, query_features): + query_feature = query_feature.reshape(1, -1) + if threshold is not None: + _, d, i = faiss_index.range_search(query_feature, threshold) + if debug: + logger.debug(f"Found {i.shape} items with query '{query}' and 
threshold {threshold}") + else: + d, i = faiss_index.search(query_feature, num_results) + i = i[0] + d = d[0] + if debug: + logger.debug(f"Found {num_results} items with query '{query}'") + logger.debug(f"The minimum distance is {min(d):.2f} and the maximum is {max(d):.2f}") + logger.debug( + "You may want to increase your result, use --num_results parameter. " + "Or use the --threshold parameter." + ) + # Sorted faiss search result with distance + text_scores = [] + for ed, ei in zip(d, i): + item = df.iloc[ei].to_dict() + if debug: + logger.debug(f"Found: {item}, similarity: {ed}, id: {ei}") + text_scores.append((item, float(ed), int(ei))) + # Sort by score desc + query_result = sorted(text_scores, key=lambda x: x[1], reverse=True) + result.append(query_result) + return result + + +def clip_filter( + texts: Optional[List[str]] = None, + images: Optional[List[str]] = None, + embeddings: Optional[Union[np.ndarray, List[str]]] = None, + output_file: str = "tmp_outputs/result.json", + model_name: str = "OFA-Sys/chinese-clip-vit-base-patch16", + index_dir: str = 'tmp_image_index_dir/', + index_name: str = "faiss.index", + corpus_file: str = 'tmp_data_dir/corpus.csv', + num_results: int = 10, + threshold: Optional[float] = None, + device: Optional[str] = None, +): + """Entry point of clip filter""" + if texts is None and images is None and embeddings is None: + raise ValueError("must fill one of texts, images and embeddings input") + queries = None + if texts is not None and len(texts) > 0: + queries = texts + elif images is not None and len(images) > 0: + queries = [preprocess_image(img) for img in images] + elif embeddings is not None: + queries = embeddings + if isinstance(queries, list): + queries = np.array(queries, dtype=np.float32) + if len(queries.shape) == 1: + queries = np.expand_dims(queries, axis=0) + + index_file = os.path.join(index_dir, index_name) + assert os.path.exists(index_file), f"index file {index_file} not exist" + faiss_index = faiss.read_index(index_file) + model = ClipModule(model_name=model_name, device=device) + df = pd.read_csv(corpus_file) + logger.info(f'Load model success. 
+
+    result = batch_search_index(queries, model, faiss_index, df, num_results, threshold)
+    # Save results
+    if output_file:
+        os.makedirs(os.path.dirname(output_file), exist_ok=True)
+        with open(output_file, 'w', encoding='utf-8') as f:
+            if texts:
+                for q, sorted_text_scores in zip(texts, result):
+                    json.dump(
+                        {'text': q,
+                         'results': [{'sentence': i, 'similarity': j, 'id': k} for i, j, k in sorted_text_scores]},
+                        f,
+                        ensure_ascii=False,
+                    )
+                    f.write('\n')
+                logger.info(f"Query texts size: {len(texts)}, saved result to {output_file}")
+            elif images:
+                for q, sorted_text_scores in zip(images, result):
+                    json.dump(
+                        {'image': q,
+                         'results': [{'sentence': i, 'similarity': j, 'id': k} for i, j, k in sorted_text_scores]},
+                        f,
+                        ensure_ascii=False,
+                    )
+                    f.write('\n')
+                logger.info(f"Query images size: {len(images)}, saved result to {output_file}")
+            elif embeddings is not None:
+                for q, sorted_text_scores in zip(queries, result):
+                    json.dump(
+                        {'emb': q.tolist(),
+                         'results': [{'sentence': i, 'similarity': j, 'id': k} for i, j, k in sorted_text_scores]},
+                        f,
+                        ensure_ascii=False,
+                    )
+                    f.write('\n')
+                logger.info(f"Query embeddings size: {len(embeddings)}, saved result to {output_file}")
+    return result
+
+
+class Item(BaseModel):
+    input: str = Field(..., max_length=512)
+
+
+class ClipItem(BaseModel):
+    text: Optional[str] = Field(None, max_length=512)
+    image: Optional[str] = None
+
+
+class SearchItem(BaseModel):
+    text: Optional[str] = Field(None, max_length=512)
+    image: Optional[str] = None
+    emb: Optional[List[float]] = None
+
+
+def clip_server(
+        model_name: str = "OFA-Sys/chinese-clip-vit-base-patch16",
+        index_dir: str = 'tmp_image_index_dir/',
+        index_name: str = "faiss.index",
+        corpus_file: str = 'tmp_data_dir/corpus.csv',
+        num_results: int = 10,
+        threshold: Optional[float] = None,
+        device: Optional[str] = None,
+        port: int = 8002,
+        debug: bool = False,
+):
+    """Main entry point of the clip search backend; starts the endpoints."""
+    import uvicorn
+    from fastapi import FastAPI
+    from starlette.middleware.cors import CORSMiddleware
+
+    logger.info("Starting clip server")
+    index_file = os.path.join(index_dir, index_name)
+    assert os.path.exists(index_file), f"index file {index_file} does not exist"
+    faiss_index = faiss.read_index(index_file)
+    model = ClipModule(model_name=model_name, device=device)
+    df = pd.read_csv(corpus_file)
+    logger.info(f'Load model success. 
model: {model_name}, index: {faiss_index}, data size: {len(df)}') + + # define the app + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"]) + + @app.get('/') + async def index(): + return {"message": "index, docs url: /docs"} + + @app.post('/emb') + async def emb(item: ClipItem): + try: + if item.text is not None: + q = [item.text] + elif item.image is not None: + q = [preprocess_image(item.image)] + else: + raise ValueError("item should have text or image") + embeddings = model.encode(q) + result_dict = {'emb': embeddings.tolist()[0]} + logger.debug(f"Successfully get embeddings, res shape: {embeddings.shape}") + return result_dict + except Exception as e: + logger.error(e) + return {'status': False, 'msg': e}, 400 + + @app.post('/similarity') + async def similarity(item1: ClipItem, item2: ClipItem): + try: + if item1.text is not None: + q1 = item1.text + elif item1.image is not None: + q1 = preprocess_image(item1.image) + else: + raise ValueError("item1 should have text or image") + if item2.text is not None: + q2 = item2.text + elif item2.image is not None: + q2 = preprocess_image(item2.image) + else: + raise ValueError("item2 should have text or image") + emb1 = model.encode(q1) + emb2 = model.encode(q2) + sim_score = cos_sim(emb1, emb2).tolist()[0][0] + result_dict = {'similarity': sim_score} + logger.debug(f"Successfully get similarity score, res: {sim_score}") + return result_dict + except Exception as e: + logger.error(e) + return {'status': False, 'msg': e}, 400 + + @app.post('/search') + async def search(item: SearchItem): + try: + if item.text is not None: + q = [item.text] + elif item.image is not None: + q = [preprocess_image(item.image)] + elif item.emb is not None: + q = item.emb + if isinstance(q, list): + q = np.array(q, dtype=np.float32) + if len(q.shape) == 1: + q = np.expand_dims(q, axis=0) + else: + raise ValueError("item should have text or image or emb") + results = batch_search_index(q, model, faiss_index, df, num_results, threshold, debug=debug) + sorted_text_scores = results[0] + result_dict = {'result': sorted_text_scores} + logger.debug(f"Successfully search done, res size: {len(sorted_text_scores)}") + return result_dict + except Exception as e: + logger.error(f"search error: {e}") + return {'status': False, 'msg': e}, 400 + + logger.info("Server starting!") + uvicorn.run(app, host="0.0.0.0", port=port) + + +class ClipClient: + def __init__(self, base_url: str = "http://0.0.0.0:8002"): + self.base_url = base_url + + def _post(self, endpoint: str, data: dict) -> dict: + response = requests.post(f"{self.base_url}/{endpoint}", json=data) + response.raise_for_status() + return response.json() + + def get_emb(self, text: Optional[str] = None, image: Optional[str] = None) -> List[float]: + try: + data = { + "text": text, + "image": image, + } + response = self._post("emb", data) + return response.get("emb", []) + except Exception as e: + logger.error(e) + return [] + + def get_similarity(self, item1: ClipItem, item2: ClipItem) -> float: + try: + data = {"item1": item1.dict(), "item2": item2.dict()} + response = self._post("similarity", data) + return response.get("similarity", 0.0) + except Exception as e: + logger.error(f"Error: {e}") + return 0.0 + + def search( + self, + text: Optional[str] = None, + image: Optional[str] = None, + emb: Optional[List[float]] = None + ): + try: + data = { + "text": text, + "image": image, + "emb": emb + } + response = 
self._post("search", data) + return response.get("result", []) + except Exception as e: + logger.error(f"Error: {e}") + return [] + + +def main(): + """Main entry point""" + fire.Fire( + { + "clip_embedding": clip_embedding, + "clip_index": clip_index, + "clip_filter": clip_filter, + "clip_server": clip_server, + } + ) + + +if __name__ == "__main__": + main() diff --git a/similarities/fastsim.py b/similarities/fast_bert_similarity.py similarity index 94% rename from similarities/fastsim.py rename to similarities/fast_bert_similarity.py index 78a2a0a..0238919 100644 --- a/similarities/fastsim.py +++ b/similarities/fast_bert_similarity.py @@ -8,10 +8,10 @@ from loguru import logger -from similarities.similarity import Similarity +from similarities.bert_similarity import BertSimilarity -class AnnoySimilarity(Similarity): +class AnnoySimilarity(BertSimilarity): """ Computes cosine similarities between word embeddings and retrieves most similar query for a given docs with Annoy. @@ -63,7 +63,7 @@ def save_index(self, index_path: str = "annoy_index.bin"): self.build_index() self.index.save(index_path) corpus_emb_json_path = index_path + ".json" - super().save_index(corpus_emb_json_path) + super().save_embeddings(corpus_emb_json_path) logger.info(f"Saving Annoy index to: {index_path}, corpus embedding to: {corpus_emb_json_path}") else: logger.warning("No index path given. Index not saved.") @@ -73,7 +73,7 @@ def load_index(self, index_path: str = "annoy_index.bin"): if index_path and os.path.exists(index_path): corpus_emb_json_path = index_path + ".json" logger.info(f"Loading index from: {index_path}, corpus embedding from: {corpus_emb_json_path}") - super().load_index(corpus_emb_json_path) + super().load_embeddings(corpus_emb_json_path) if self.index is None: self.create_index() self.index.load(index_path) @@ -97,7 +97,7 @@ def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int queries = {id: query for id, query in enumerate(queries)} result = {qid: {} for qid, query in queries.items()} queries_texts = list(queries.values()) - queries_embeddings = self._get_vector(queries_texts) + queries_embeddings = self.get_embeddings(queries_texts) # Annoy get_nns_by_vector can only search for one vector at a time for idx, (qid, query) in enumerate(queries.items()): corpus_ids, distances = self.index.get_nns_by_vector(queries_embeddings[idx], topn, include_distances=True) @@ -108,7 +108,7 @@ def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int return result -class HnswlibSimilarity(Similarity): +class HnswlibSimilarity(BertSimilarity): """ Computes cosine similarities between word embeddings and retrieves most similar query for a given docs with Hnswlib. @@ -168,7 +168,7 @@ def save_index(self, index_path: str = "hnswlib_index.bin"): self.build_index() self.index.save_index(index_path) corpus_emb_json_path = index_path + ".json" - super().save_index(corpus_emb_json_path) + super().save_embeddings(corpus_emb_json_path) logger.info(f"Saving hnswlib index to: {index_path}, corpus embedding to: {corpus_emb_json_path}") else: logger.warning("No index path given. 
Index not saved.") @@ -178,7 +178,7 @@ def load_index(self, index_path: str = "hnswlib_index.bin"): if index_path and os.path.exists(index_path): corpus_emb_json_path = index_path + ".json" logger.info(f"Loading index from: {index_path}, corpus embedding from: {corpus_emb_json_path}") - super().load_index(corpus_emb_json_path) + super().load_embeddings(corpus_emb_json_path) if self.index is None: self.create_index() self.index.load_index(index_path) @@ -202,7 +202,7 @@ def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int queries = {id: query for id, query in enumerate(queries)} result = {qid: {} for qid, query in queries.items()} queries_texts = list(queries.values()) - queries_embeddings = self._get_vector(queries_texts) + queries_embeddings = self.get_embeddings(queries_texts) # We use hnswlib knn_query method to find the top_k_hits corpus_ids, distances = self.index.knn_query(queries_embeddings, k=topn) # We extract corpus ids and scores for each query diff --git a/similarities/imagesim.py b/similarities/image_similarity.py similarity index 67% rename from similarities/imagesim.py rename to similarities/image_similarity.py index d83adb1..a03524d 100644 --- a/similarities/imagesim.py +++ b/similarities/image_similarity.py @@ -15,147 +15,9 @@ from loguru import logger from tqdm import tqdm -from similarities.clip_model import CLIPModel from similarities.similarity import SimilarityABC from similarities.utils.distance import hamming_distance from similarities.utils.imagehash import phash, dhash, whash, average_hash -from similarities.utils.util import cos_sim, semantic_search, dot_score - - -class ClipSimilarity(SimilarityABC): - """ - Compute CLIP similarity between two images and retrieves most - similar image for a given image corpus. - - CLIP: https://github.com/openai/CLIP.git - english model: openai/clip-vit-base-patch32 - chinese model: OFA-Sys/chinese-clip-vit-base-patch16 - """ - - def __init__( - self, - corpus: Union[List[Image.Image], Dict[str, Image.Image]] = None, - model_name_or_path='OFA-Sys/chinese-clip-vit-base-patch16' - ): - self.clip_model = CLIPModel(model_name_or_path) # load the CLIP model - self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score} - self.corpus = {} - self.corpus_embeddings = [] - if corpus is not None: - self.add_corpus(corpus) - - def __len__(self): - """Get length of corpus.""" - return len(self.corpus) - - def __str__(self): - base = f"Similarity: {self.__class__.__name__}, matching_model: {self.clip_model.__class__.__name__}" - if self.corpus: - base += f", corpus size: {len(self.corpus)}" - return base - - def _convert_to_rgb(self, img): - """Convert image to RGB mode.""" - if img.mode != 'RGB': - img = img.convert('RGB') - return img - - def _get_vector( - self, - text_or_img: Union[List[Image.Image], Image.Image, str, List[str]], - batch_size: int = 128, - show_progress_bar: bool = False, - ): - """ - Returns the embeddings for a batch of images. 
- :param text_or_img: list of str or Image.Image or image list - :param batch_size: batch size - :param show_progress_bar: show progress bar - :return: np.ndarray, embeddings for the given images - """ - if isinstance(text_or_img, str): - text_or_img = [text_or_img] - if isinstance(text_or_img, Image.Image): - text_or_img = [text_or_img] - if isinstance(text_or_img, list) and isinstance(text_or_img[0], Image.Image): - text_or_img = [self._convert_to_rgb(i) for i in text_or_img] - return self.clip_model.encode(text_or_img, batch_size=batch_size, show_progress_bar=show_progress_bar) - - def add_corpus(self, corpus: Union[List[Image.Image], Dict[str, Image.Image]]): - """ - Extend the corpus with new documents. - - Parameters - ---------- - corpus : list of str or dict - """ - corpus_new = {} - start_id = len(self.corpus) if self.corpus else 0 - if isinstance(corpus, list): - for id, doc in enumerate(corpus): - if doc not in list(self.corpus.values()): - corpus_new[start_id + id] = doc - else: - for id, doc in corpus.items(): - if doc not in list(self.corpus.values()): - corpus_new[id] = doc - self.corpus.update(corpus_new) - logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}") - corpus_embeddings = self._get_vector(list(corpus_new.values()), show_progress_bar=True).tolist() - if self.corpus_embeddings: - self.corpus_embeddings += corpus_embeddings - else: - self.corpus_embeddings = corpus_embeddings - logger.info(f"Add {len(corpus)} docs, total: {len(self.corpus)}, emb size: {len(self.corpus_embeddings)}") - - def similarity( - self, - a: Union[List[Image.Image], Image.Image, str, List[str]], - b: Union[List[Image.Image], Image.Image, str, List[str]], - score_function: str = "cos_sim" - ): - """ - Compute similarity between two texts. - :param a: list of str or str - :param b: list of str or str - :param score_function: function to compute similarity, default cos_sim - :return: similarity score, torch.Tensor, Matrix with res[i][j] = cos_sim(a[i], b[j]) - """ - if score_function not in self.score_functions: - raise ValueError(f"score function: {score_function} must be either (cos_sim) for cosine similarity" - " or (dot) for dot product") - score_function = self.score_functions[score_function] - text_emb1 = self._get_vector(a) - text_emb2 = self._get_vector(b) - - return score_function(text_emb1, text_emb2) - - def distance(self, a: Union[str, List[str]], b: Union[str, List[str]]): - """Compute cosine distance between two texts.""" - return 1 - self.similarity(a, b) - - def most_similar(self, queries, topn: int = 10): - """ - Find the topn most similar texts to the queries against the corpus. 
- :param queries: text or image - :param topn: int - :return: Dict[str, Dict[str, float]], {query_id: {corpus_id: similarity_score}, ...} - """ - if isinstance(queries, str) or not hasattr(queries, '__len__'): - queries = [queries] - if isinstance(queries, list): - queries = {id: query for id, query in enumerate(queries)} - result = {qid: {} for qid, query in queries.items()} - queries_ids_map = {i: id for i, id in enumerate(list(queries.keys()))} - queries_texts = list(queries.values()) - queries_embeddings = self._get_vector(queries_texts) - corpus_embeddings = np.array(self.corpus_embeddings, dtype=np.float32) - all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn) - for idx, hits in enumerate(all_hits): - for hit in hits[0:topn]: - result[queries_ids_map[idx]][hit['corpus_id']] = hit['score'] - - return result class ImageHashSimilarity(SimilarityABC): diff --git a/similarities/literalsim.py b/similarities/literal_similarity.py similarity index 99% rename from similarities/literalsim.py rename to similarities/literal_similarity.py index f057f8d..52e2b68 100644 --- a/similarities/literalsim.py +++ b/similarities/literal_similarity.py @@ -16,7 +16,7 @@ import jieba.posseg import numpy as np from loguru import logger -from text2vec import Word2Vec + from tqdm import tqdm from similarities.similarity import SimilarityABC @@ -375,9 +375,13 @@ class WordEmbeddingSimilarity(SimilarityABC): def __init__(self, corpus: Union[List[str], Dict[str, str]] = None, model_name_or_path="w2v-light-tencent-chinese"): """ Init WordEmbeddingSimilarity. - :param model_name_or_path: ~text2vec.Word2Vec model name or path to model file. + :param model_name_or_path: Word2Vec model name or path to model file. :param corpus: list of str """ + try: + from text2vec import Word2Vec + except ImportError: + raise ImportError("Please install text2vec first, `pip install text2vec`") if isinstance(model_name_or_path, str): self.keyedvectors = Word2Vec(model_name_or_path) elif hasattr(model_name_or_path, "encode"): diff --git a/similarities/similarity.py b/similarities/similarity.py index 80fe98b..afc6e2f 100644 --- a/similarities/similarity.py +++ b/similarities/similarity.py @@ -8,19 +8,8 @@ 2. Retrieves most similar sentence of a query against a corpus of documents. """ -import json -import os from typing import List, Union, Dict -import numpy as np -from loguru import logger -from text2vec import SentenceModel - -from similarities.utils.util import cos_sim, semantic_search, dot_score - -os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" - class SimilarityABC: """ @@ -64,184 +53,11 @@ def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int """ raise NotImplementedError("cannot instantiate Abstract Base Class") - -class Similarity(SimilarityABC): - """ - Sentence Similarity: - 1. Compute the similarity between two sentences - 2. Retrieves most similar sentence of a query against a corpus of documents. - - The index supports adding new documents dynamically. - """ - - def __init__( - self, - corpus: Union[List[str], Dict[str, str]] = None, - model_name_or_path="shibing624/text2vec-base-chinese", - encoder_type="MEAN", - max_seq_length=128, - device=None, - ): - """ - Initialize the similarity object. - :param model_name_or_path: Transformer model name or path, like: - 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', 'bert-base-uncased', 'bert-base-chinese', - 'shibing624/text2vec-base-chinese', ... 
- model in HuggingFace Model Hub and release from https://github.com/shibing624/text2vec - :param corpus: Corpus of documents to use for similarity queries. - :param max_seq_length: Max sequence length for sentence model. - """ - if isinstance(model_name_or_path, str): - self.sentence_model = SentenceModel( - model_name_or_path, - encoder_type=encoder_type, - max_seq_length=max_seq_length, - device=device - ) - elif hasattr(model_name_or_path, "encode"): - self.sentence_model = model_name_or_path - else: - raise ValueError("model_name_or_path is transformers model name or path") - self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score} - self.corpus = {} - self.corpus_embeddings = [] - if corpus is not None: - self.add_corpus(corpus) - - def __len__(self): - """Get length of corpus.""" - return len(self.corpus) - - def __str__(self): - base = f"Similarity: {self.__class__.__name__}, matching_model: {self.sentence_model}" - if self.corpus: - base += f", corpus size: {len(self.corpus)}" - return base - - def get_sentence_embedding_dimension(self): - """ - Get the dimension of the sentence embeddings. - - Returns - ------- - int or None - The dimension of the sentence embeddings, or None if it cannot be determined. - """ - if hasattr(self.sentence_model, "get_sentence_embedding_dimension"): - return self.sentence_model.get_sentence_embedding_dimension() - else: - return getattr(self.sentence_model.bert.pooler.dense, "out_features", None) - - def add_corpus(self, corpus: Union[List[str], Dict[str, str]]): + def search(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10): """ - Extend the corpus with new documents. - :param corpus: corpus of documents to use for similarity queries. - :return: self.corpus, self.corpus embeddings - """ - new_corpus = {} - start_id = len(self.corpus) if self.corpus else 0 - for id, doc in enumerate(corpus): - if isinstance(corpus, list): - if doc not in self.corpus.values(): - new_corpus[start_id + id] = doc - else: - if doc not in self.corpus.values(): - new_corpus[id] = doc - self.corpus.update(new_corpus) - logger.info(f"Start computing corpus embeddings, new docs: {len(new_corpus)}") - corpus_embeddings = self._get_vector(list(new_corpus.values()), show_progress_bar=True).tolist() - self.corpus_embeddings = self.corpus_embeddings + corpus_embeddings \ - if self.corpus_embeddings else corpus_embeddings - logger.info(f"Add {len(new_corpus)} docs, total: {len(self.corpus)}, emb len: {len(self.corpus_embeddings)}") - - def _get_vector( - self, - sentences: Union[str, List[str]], - batch_size: int = 64, - show_progress_bar: bool = False, - ) -> np.ndarray: - """ - Returns the embeddings for a batch of sentences. - :param sentences: - :return: - """ - return self.sentence_model.encode(sentences, batch_size=batch_size, show_progress_bar=show_progress_bar) - - def similarity(self, a: Union[str, List[str]], b: Union[str, List[str]], score_function: str = "cos_sim"): - """ - Compute similarity between two texts. 
- :param a: list of str or str - :param b: list of str or str - :param score_function: function to compute similarity, default cos_sim - :return: similarity score, torch.Tensor, Matrix with res[i][j] = cos_sim(a[i], b[j]) - """ - if score_function not in self.score_functions: - raise ValueError(f"score function: {score_function} must be either (cos_sim) for cosine similarity" - " or (dot) for dot product") - score_function = self.score_functions[score_function] - text_emb1 = self._get_vector(a) - text_emb2 = self._get_vector(b) - - return score_function(text_emb1, text_emb2) - - def distance(self, a: Union[str, List[str]], b: Union[str, List[str]]): - """Compute cosine distance between two texts.""" - return 1 - self.similarity(a, b) - - def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10, - score_function: str = "cos_sim"): - """ - Find the topn most similar texts to the queries against the corpus. - :param queries: str or list of str + Find the topn most similar texts to the query against the corpus. + :param queries: Dict[str(query_id), str(query_text)] or List[str] or str :param topn: int - :param score_function: function to compute similarity, default cos_sim :return: Dict[str, Dict[str, float]], {query_id: {corpus_id: similarity_score}, ...} """ - if isinstance(queries, str) or not hasattr(queries, '__len__'): - queries = [queries] - if isinstance(queries, list): - queries = {id: query for id, query in enumerate(queries)} - if score_function not in self.score_functions: - raise ValueError(f"score function: {score_function} must be either (cos_sim) for cosine similarity" - " or (dot) for dot product") - score_function = self.score_functions[score_function] - result = {qid: {} for qid, query in queries.items()} - queries_ids_map = {i: id for i, id in enumerate(list(queries.keys()))} - queries_texts = list(queries.values()) - queries_embeddings = self._get_vector(queries_texts) - corpus_embeddings = np.array(self.corpus_embeddings, dtype=np.float32) - all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn, score_function=score_function) - for idx, hits in enumerate(all_hits): - for hit in hits[0:topn]: - result[queries_ids_map[idx]][hit['corpus_id']] = hit['score'] - - return result - - def save_index(self, index_path: str = "corpus_emb.json"): - """ - Save corpus embeddings to json file. - :param index_path: json file path - :return: - """ - corpus_emb = {id: {"doc": self.corpus[id], "doc_emb": emb} for id, emb in - zip(self.corpus.keys(), self.corpus_embeddings)} - with open(index_path, "w", encoding="utf-8") as f: - json.dump(corpus_emb, f, ensure_ascii=False) - logger.debug(f"Save corpus embeddings to file: {index_path}.") - - def load_index(self, index_path: str = "corpus_emb.json"): - """ - Load corpus embeddings from json file. 
- :param index_path: json file path - :return: list of corpus embeddings, dict of corpus ids map, dict of corpus - """ - try: - with open(index_path, "r", encoding="utf-8") as f: - corpus_emb = json.load(f) - corpus_embeddings = [] - for id, corpus_dict in corpus_emb.items(): - self.corpus[int(id)] = corpus_dict["doc"] - corpus_embeddings.append(corpus_dict["doc_emb"]) - self.corpus_embeddings = corpus_embeddings - except (IOError, json.JSONDecodeError): - logger.error("Error: Could not load corpus embeddings from file.") + return self.most_similar(queries, topn=topn) diff --git a/similarities/utils/image_util.py b/similarities/utils/image_util.py new file mode 100644 index 0000000..155c4f4 --- /dev/null +++ b/similarities/utils/image_util.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com) +@description: +""" + +import base64 +import sys +from io import BytesIO + +import cv2 +import numpy as np +import requests +from PIL import Image +from loguru import logger +from tqdm import tqdm + + +def is_link(s): + return s is not None and s.startswith('http') + + +def img_decode(content: bytes): + np_arr = np.frombuffer(content, dtype=np.uint8) + return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED) + + +def download_with_progressbar(url, save_path): + response = requests.get(url, stream=True) + if response.status_code == 200: + total_size_in_bytes = int(response.headers.get('content-length', 1)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm( + total=total_size_in_bytes, unit='iB', unit_scale=True) + with open(save_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + else: + logger.error("Something went wrong while downloading models") + sys.exit(0) + + +def check_img(img): + if isinstance(img, bytes): + img = img_decode(img) + if isinstance(img, str): + # download net image + if is_link(img): + download_with_progressbar(img, 'tmp.jpg') + img = 'tmp.jpg' + image_file = img + with open(image_file, 'rb') as f: + img_str = f.read() + img = img_decode(img_str) + if img is None: + try: + buf = BytesIO() + image = BytesIO(img_str) + im = Image.open(image) + rgb = im.convert('RGB') + rgb.save(buf, 'jpeg') + buf.seek(0) + image_bytes = buf.read() + data_base64 = str(base64.b64encode(image_bytes), + encoding="utf-8") + image_decode = base64.b64decode(data_base64) + img_array = np.frombuffer(image_decode, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + except: + logger.error("error in loading image:{}".format(image_file)) + return None + if img is None: + logger.error("error in loading image:{}".format(image_file)) + return None + if isinstance(img, np.ndarray) and len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + return img + + +def alpha_to_color(img, alpha_color=(255, 255, 255)): + if len(img.shape) == 3 and img.shape[2] == 4: + B, G, R, A = cv2.split(img) + alpha = A / 255 + + R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8) + G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8) + B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8) + + img = cv2.merge((B, G, R)) + return img + + +def preprocess_image(img, alpha_color=(255, 255, 255)): + """ + preprocess image + :param img: + :param alpha_color: + :return: + """ + img = check_img(img) + if img is None: + return None + img = alpha_to_color(img, alpha_color) + return img diff --git a/similarities/version.py b/similarities/version.py index eb0ced8..1a72d32 
100644 --- a/similarities/version.py +++ b/similarities/version.py @@ -1,7 +1 @@ -# -*- coding: utf-8 -*- -""" -@author:XuMing(xuming624@qq.com) -@description: -""" - -__version__ = '1.0.6' +__version__ = '1.1.0' diff --git a/tests/test_fastsim.py b/tests/test_fastsim.py index 0f534cd..7cd19cd 100644 --- a/tests/test_fastsim.py +++ b/tests/test_fastsim.py @@ -8,8 +8,8 @@ import unittest sys.path.append('..') -from similarities.fastsim import AnnoySimilarity -from similarities.fastsim import HnswlibSimilarity +from similarities.fast_bert_similarity import AnnoySimilarity +from similarities.fast_bert_similarity import HnswlibSimilarity class FastTestCase(unittest.TestCase): @@ -52,7 +52,7 @@ def test_annoy_model(self): corpus_new = [i + str(id) for id, i in enumerate(list_of_docs * 10)] m = AnnoySimilarity(corpus=list_of_docs * 10) print(m) - v = m._get_vector("This is test1") + v = m.get_embeddings("This is test1") print(v[:10], v.shape) print(m.similarity("This is a test1", "that is a test5")) print(m.distance("This is a test1", "that is a test5")) diff --git a/tests/test_imagesim.py b/tests/test_imagesim.py index 6d7fbac..4d0ac1a 100644 --- a/tests/test_imagesim.py +++ b/tests/test_imagesim.py @@ -11,7 +11,7 @@ sys.path.append('..') -from similarities.imagesim import ClipSimilarity, ImageHashSimilarity, SiftSimilarity +from similarities.image_similarity import ClipSimilarity, ImageHashSimilarity, SiftSimilarity pwd_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/tests/test_literalsim.py b/tests/test_literalsim.py index 621c5e0..5c2e452 100644 --- a/tests/test_literalsim.py +++ b/tests/test_literalsim.py @@ -9,7 +9,7 @@ sys.path.append('..') -from similarities.literalsim import ( +from similarities.literal_similarity import ( SimHashSimilarity, TfidfSimilarity, BM25Similarity, diff --git a/tests/test_sim_score.py b/tests/test_sim_score.py index e932393..4435b35 100644 --- a/tests/test_sim_score.py +++ b/tests/test_sim_score.py @@ -23,7 +23,7 @@ def test_sim_diff(self): self.assertTrue(abs(r - 0.4098) < 0.001) def test_empty(self): - v = m._get_vector("This is test1") + v = m.get_embeddings("This is test1") print(v[:10], v.shape) print(m.similarity("This is a test1", "that is a test5")) print(m.distance("This is a test1", "that is a test5"))
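
For reference, here is a minimal end-to-end sketch of the `faiss_bert_similarity` pipeline introduced above (the `clip_embedding`/`clip_index`/`clip_filter`/`clip_server` pipeline mirrors it). The directory names are hypothetical placeholders; the function signatures are the ones defined in this diff:

```python
from similarities.faiss_bert_similarity import (
    bert_embedding, bert_index, bert_filter, BertClient,
)

# 1) Embed every non-empty line of data/**/*.txt (hypothetical input dir);
#    writes emb/emb.npy and data/corpus.npy.
bert_embedding(input_dir='data/', embeddings_dir='emb/',
               corpus_file='data/corpus.npy', normalize_embeddings=True)

# 2) Build a faiss index from the saved embeddings with autofaiss.
bert_index(embeddings_dir='emb/', index_dir='index/')

# 3) Query the index offline ...
res = bert_filter(queries=['如何更换花呗绑定银行卡'],
                  index_dir='index/', corpus_file='data/corpus.npy')
print(res[0][:3])  # top hits as (sentence, score, id) tuples

# ... or over HTTP against a running bert_server process (default port 8001).
client = BertClient(base_url='http://0.0.0.0:8001')
print(client.search('如何更换花呗绑定银行卡'))
```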