使用嵌入的建议
本文将介绍如何使用嵌入和最近邻搜索来进行推荐,以找到与给定文章最相似的其他文章。在这里,我们将使用AG新闻文章数据集作为示例。本文将引入以下几个方面:
- 引入必要的包和函数;
- 加载AG新闻数据集并对其进行预览;
- 通过嵌入计算文章之间的距离,并可视化距离;
- 使用最近邻搜索寻找与给定文章最相似的其他文章。
本文将演示如何使用嵌入来发现类似文章并进行推荐。
正文
推荐在网络上很普遍:
- ‘买了那个东西? 尝试这些类似的项目。
- 喜欢那本书吗? 试试这些类似的标题。
- ‘不是您正在寻找的帮助页面? 试试这些类似的页面。
此笔记本演示了如何使用嵌入来查找要推荐的相似项目。 特别是,我们使用 AG 的新闻文章语料库 作为我们的数据集。
我们的模型将回答这个问题:给定一篇文章,还有哪些其他文章与它最相似?
1.导入
首先,让我们导入稍后需要的包和函数。 如果您没有这些,则需要安装它们。 您可以通过运行 pip install {package_name}
通过终端安装它们,例如 pip install pandas
。
# imports import pandas as pd import pickle from openai.embeddings_utils import ( get_embedding, distances_from_embeddings, tsne_components_from_embeddings, chart_from_components, indices_of_nearest_neighbors_from_distances, ) # constants EMBEDDING_MODEL = "text-embedding-ada-002"
2.加载数据
接下来我们加载AG新闻数据,看看长什么样子。
# load data (full dataset available at http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html) dataset_path = "data/AG_news_samples.csv" df = pd.read_csv(dataset_path) # print dataframe n_examples = 5 df.head(n_examples)
标题 | 描述 | label_int | 标签 | |
---|---|---|---|---|
0 | 世界简报会 | 英国:布莱尔警告气候威胁 总理… | 1 | 世界 |
1 | 英伟达在主板上安装了防火墙(PC wo… | PC世界 – 即将推出的芯片组将包括… | 4 | 科学/技术 |
2 | 希腊文奥运喜悦,中国媒体 | 希腊的报纸反映了各种兴奋… | 2 | 体育 |
3 | U2 带图片的易拉罐 iPod | 加利福尼亚州圣何塞 – 苹果电脑(引用,Cha… | 4 | 科学/技术 |
4 | 梦工厂 | 任何产品,任何形状,任何尺寸 – 制造… | 4 | 科学/技术 |
让我们看一下这些相同的示例,但没有被省略号截断。
# 打印每个示例的标题、描述和标签 for idx, row in df.head(n_examples).iterrows(): print("") print(f"书名: {row['title']}") print(f"描述: {row['description']}") print(f"标签: {row['label']}")
书名:世界简报 描述:英国:布莱尔警告气候威胁 英国首相托尼·布莱尔敦促国际社会将全球变暖视为一个可怕的威胁,并商定一项行动计划以遏制"令人担忧的"气候变化。 温室气体的增长。 标签:世界 标题:Nvidia 在主板上安装防火墙(PC 世界) 描述:PC World - 即将推出的芯片组将为您的 PC 提供内置安全功能。 标签:科技 标题:希腊语中的奥林匹克喜悦,中文媒体 描述:希腊的报纸既对雅典奥运会取得成功感到欣喜,又对奥运会没有遭受任何重大挫折而松了一口气。 标签:运动 标题:U2 Can iPod with Pictures 描述:加利福尼亚州圣何塞——Apple Computer(报价、图表)推出了一批新的 iPod、iTunes 软件和宣传片,旨在使其在数字音乐播放器中保持领先地位。 标签:科技 书名:梦工厂 描述:任何产品,任何形状,任何尺寸——在您的桌面上制造! 未来是制造者。 来自 Wired 杂志的 Bruce Sterling。 标签:科技
3.构建缓存以保存嵌入
在获取这些文章的嵌入之前,让我们设置一个缓存来保存我们生成的嵌入。 通常,保存嵌入是个好主意,以便以后可以重新使用它们。 如果您不保存它们,则每次重新计算它们时都会再次付费。
缓存是一个字典,将 (text, model
) 的元组映射到嵌入,这是一个浮点数列表。 缓存保存为 Python pickle 文件。
# 建立嵌入缓存以避免重新计算 # 缓存是元组字典 (text, model) -> embedding,保存为 pickle 文件 # 设置嵌入缓存的路径 embedding_cache_path = "data/recommendations_embeddings_cache.pkl" # 加载缓存(如果存在),并将副本保存到磁盘 try: embedding_cache = pd.read_pickle(embedding_cache_path) except FileNotFoundError: embedding_cache = {} with open(embedding_cache_path, "wb") as embedding_cache_file: pickle.dump(embedding_cache, embedding_cache_file) # 定义一个函数以从缓存中检索嵌入(如果存在),否则通过 API 请求 def embedding_from_string( string: str, model: str = EMBEDDING_MODEL, embedding_cache=embedding_cache ) -> list: """返回给定字符串的嵌入,使用缓存避免重新计算。""" if (string, model) not in embedding_cache.keys(): embedding_cache[(string, model)] = get_embedding(string, model) with open(embedding_cache_path, "wb") as embedding_cache_file: pickle.dump(embedding_cache, embedding_cache_file) return embedding_cache[(string, model)]
让我们通过嵌入来检查它是否有效。
# 例如,取数据集中的第一个描述 example_string = df["description"].values[0] print(f"\nExample string: {example_string}") # print the first 10 dimensions of the embedding example_embedding = embedding_from_string(example_string) print(f"\n示例嵌入: {example_embedding[:10]}...")
示例字符串: 英国:布莱尔警告气候威胁 英国首相托尼·布莱尔敦促国际社会将全球变暖视为一个可怕的威胁,并就遏制"令人担忧的"气候变化的行动计划达成一致。 温室气体的增长。 Example embedding: [-0.01071077398955822, -0.022362446412444115, -0.00883542187511921, -0.0254171434789896, 0.031423427164554596, 0.010723662562668324, -0.016717055812478065, 0.004195375367999077, -0.008074969984591007, -0.02142154797911644]...
4.基于embeddings推荐相似文章
要查找类似的文章,让我们遵循一个三步计划:
- 获取所有文章描述的相似度嵌入
- 计算来源标题与所有其他文章之间的距离
- 打印出最接近源标题的其他文章
def print_recommendations_from_strings( strings: list[str], index_of_source_string: int, k_nearest_neighbors: int = 1, model=EMBEDDING_MODEL, ) -> list[int]: """打印出给定字符串的 k 个最近邻居。""" # 获取所有字符串的嵌入 embeddings = [embedding_from_string(string, model=model) for string in strings] # 获取源字符串的嵌入 query_embedding = embeddings[index_of_source_string] # 获取源嵌入和其他嵌入之间的距离(function from embeddings_utils.py) distances = distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine") # 获取最近邻居的索引 (function from embeddings_utils.py) indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances) # 打印出源字符串 query_string = strings[index_of_source_string] print(f"源字符串: {query_string}") # 打印出它的k个最近邻居 k_counter = 0 for i in indices_of_nearest_neighbors: # 跳过与起始字符串完全匹配的任何字符串 if query_string == strings[i]: continue # 打印出k篇文章后停止 if k_counter >= k_nearest_neighbors: break k_counter += 1 # 打印出相似的字符串及其距离 print( f""" --- Recommendation #{k_counter} (nearest neighbor {k_counter} of {k_nearest_neighbors}) --- String: {strings[i]} Distance: {distances[i]:0.3f}""" ) return indices_of_nearest_neighbors
5.示例建议
让我们寻找与第一篇相似的文章,那篇文章是关于托尼·布莱尔的。
article_descriptions = df["description"].tolist() tony_blair_articles = print_recommendations_from_strings( strings=article_descriptions, # 让我们根据文章描述建立相似性 index_of_source_string=0, # 让我们看看类似于第一篇关于托尼·布莱尔的文章 k_nearest_neighbors=5, # 让我们看看最相似的 5 篇文章 )
来源字符串:英国:布莱尔警告气候威胁 英国首相托尼·布莱尔敦促国际社会将全球变暖视为一个可怕的威胁,并就遏制"令人担忧的"气候变化的行动计划达成一致。 温室气体的增长。 --- 建议 #1(最近邻居 5 个中的第一个)--- 字符串:首相约翰霍华德今天表示,英国首相托尼布莱尔的连任将被视为对伊拉克军事行动的认可。 距离:0.153 --- 建议 #2(最近邻 2,共 5 个)--- 字符串:英国伦敦——据报道,一位美国科学家观察到主要温室气体二氧化碳的数量出现了惊人的跳跃。 距离:0.160 --- 建议 #3(最近邻 3,共 5 个)--- 字符串:今天,在伊拉克,人质肯尼思·比格利 (Kenneth Bigley) 的痛苦笼罩着英国首相托尼·布莱尔 (Tony Blair),他面临着地方选举和工党就分裂战争展开辩论的双重考验。 距离:0.160 --- 建议 #4(最近邻 4,共 5 个)--- 字符串:以色列准备支持明年初由托尼·布莱尔召集的中东会议,尽管它表示担心英国的计划过于雄心勃勃且设计过头 距离:0.171 --- 建议 #5(最近邻 5 的 5)--- 字符串:法新社 - 一支英国军队战斗群从伊拉克南部撤出,执行美国要求的任务,前往巴格达附近更致命的地区,这是英国首相托尼布莱尔的一场重大政治赌博。 距离:0.173
不错! 5 条建议中有 4 条明确提到了托尼·布莱尔,第五条是来自伦敦的一篇关于气候变化的文章,这些话题可能经常与托尼·布莱尔相关。
让我们看看我们的推荐器在第二篇关于 NVIDIA 的新芯片组的安全性更高的示例文章中是如何做的。
chipset_security_articles = print_recommendations_from_strings( strings=article_descriptions, # 让我们根据文章描述建立相似性 index_of_source_string=1, # 让我们看看类似于第二篇关于更安全芯片组的文章 k_nearest_neighbors=5, # 让我们看看最相似的 5 篇文章 )
来源字符串:PC World - 即将推出的芯片组将为您的 PC 提供内置安全功能。 --- 建议 #1(最近邻居 5 个中的第一个)--- 字符串:PC World - 为企业更新的防病毒软件增加了入侵防御功能。 距离:0.112 --- 建议 #2(最近邻 2,共 5 个)--- 字符串:PC World - 曾经的年度世界级产品 PDA 获得了急需的升级。 距离:0.145 --- 建议 #3(最近邻 3,共 5 个)--- 字符串:PC World - 使用新的网关和媒体适配器以无线方式将您的视频发送到您的整个房间。 距离:0.153 --- 建议 #4(最近邻 4,共 5 个)--- 字符串:PC World - 赛门铁克和 McAfee 希望提高病毒定义费能将用户转移到套件。 距离:0.157 --- 建议 #5(最近邻 5 的 5)--- 字符串:Gateway 计算机将在 Office Depot 更广泛地销售,这是 PC 制造商 #39;自今年收购竞争对手 eMachines 以来扩大零售店分销的最新举措。 距离:0.168
从打印的距离,您可以看到 #1 推荐比所有其他推荐更接近(0.11 对 0.14+)。 #1 推荐看起来与起始文章非常相似 – 这是 PC World 的另一篇关于提高计算机安全性的文章。 不错!
附录:在更复杂的推荐系统中使用嵌入
构建推荐系统的一种更复杂的方法是训练一个机器学习模型,该模型接收数十或数百个信号,例如项目受欢迎程度或用户点击数据。 即使在这个系统中,嵌入对于推荐系统来说也是一个非常有用的信号,特别是对于那些还没有用户数据的“冷启动”项目(例如,一个全新的产品添加到目录中还没有任何点击)。
附录:使用嵌入可视化相似文章
为了了解我们最近的邻居推荐器在做什么,让我们可视化文章嵌入。 虽然我们无法绘制每个嵌入向量的 2048 维,但我们可以使用 t-SNE 或 PCA 等技术将嵌入压缩为 2 或 3 维,我们可以绘制图表。
在可视化最近邻之前,让我们使用 t-SNE 可视化所有文章描述。 请注意,t-SNE 不是确定性的,这意味着结果可能因运行而异。
# 获取所有文章描述的嵌入 embeddings = [embedding_from_string(string) for string in article_descriptions] # 使用 t-SNE 将 2048 维嵌入压缩为二维 tsne_components = tsne_components_from_embeddings(embeddings) # 获取为图表着色的文章标签 labels = df["label"].tolist() chart_from_components( components=tsne_components, labels=labels, strings=article_descriptions, width=600, height=500, title="文章描述的 t-SNE 组件", )
正如您在上图中所见,即使是高度压缩的嵌入也能很好地按类别对文章描述进行聚类。 值得强调的是:这种聚类是在不知道标签本身的情况下完成的!
此外,如果您仔细观察最令人震惊的异常值,它们通常是由于错误标记而不是嵌入不良造成的。 例如,绿色体育集群中的大部分蓝色世界点似乎都是体育故事。
接下来,让我们根据它们是源文章、最近的邻居还是其他来重新着色这些点。
# 为推荐文章创建标签 def nearest_neighbor_labels( list_of_indices: list[int], k_nearest_neighbors: int = 5 ) -> list[str]: """返回标签列表以对 k 个最近的邻居进行着色。""" labels = ["Other" for _ in list_of_indices] source_index = list_of_indices[0] labels[source_index] = "Source" for i in range(k_nearest_neighbors): nearest_neighbor_index = list_of_indices[i + 1] labels[nearest_neighbor_index] = f"Nearest neighbor (top {k_nearest_neighbors})" return labels tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5) chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5 )
# Tony Blair 文章最近邻的二维图表 chart_from_components( components=tsne_components, labels=tony_blair_labels, strings=article_descriptions, width=600, height=500, title="托尼布莱尔文章的最近邻居", category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]}, )
查看上面的二维图表,我们可以看到有关托尼·布莱尔的文章在世界新闻集群中有些靠得很近。 有趣的是,虽然 5 个最近的邻居(红色)在高维空间中最近,但它们并不是这个压缩二维空间中的最近点。 将嵌入压缩到 2 维会丢弃它们的大部分信息,并且 2D 空间中的最近邻似乎不如完整嵌入空间中的那些相关。
# 芯片组安全文章的最近邻二维图 chart_from_components( components=tsne_components, labels=chipset_security_labels, strings=article_descriptions, width=600, height=500, title="Nearest neighbors of the chipset security article", category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]}, )
对于芯片组安全性示例,完整嵌入空间中的 4 个最近邻在这个压缩的 2D 可视化中仍然是最近邻。 第五个显示为更远,尽管在完整嵌入空间中更近。
如果需要,您还可以使用函数 chart_from_components_3D
制作嵌入的交互式 3D 图。 (这样做将需要使用 n_components=3
重新计算 t-SNE 组件。)
评论 (0)