复现|PlugIR:基于对话的图像检索系统
论文:Interactive Text-to-ImageRetrieval with Large Language Models: A Plug-and-Play Approach
1. Introduction
ChatIR是一种基于聊天的文本到图像的方法,提出了利用LLM进行多轮对话以提高检索效率。(我的阅读笔记见复现|ChatIR)
这个方法具有以下不足:
- 需要进行微调文本编码器,来适应多轮对话数据
- 微调耗费资源、可扩展性差。
- 解决方法:将对话重构为可以直接输入到预训练的视觉语言模型的格式,不需要对模型进行微调
- LLM提问者G只知道对话历史,无法查看候选图像
- 可能生成图像中不存在的属性的询问
- LLM提问者基于候选集提问,确保问题与图像属性相关
本文提出的PlugIR系统包含2个部分:上下文重构(Context Reformulation)和上下文感知对话生成(Context-aware Dialogue Generation)
本文贡献:
- 实证0样本或微调的大模型难以理解对话数据
- 提出了一种LLM提问者,解决了因冗余问题和噪声问题带来的性能瓶颈
- 提出新指标BRI(最佳对数排名积分,Best log Rank Integral)
- 比Recall@K和Hit@K更接近人的评价,更全面地评估交互式检索系统
- PlugIR具有即插即用的特性,且具有实用性
2. Related Work
- 文本到图像检索
- 视觉语言模型:BLIP、CLIP
- 大语言模型LLM
3. Method
3.1 Preliminaries:InteractiveText-to-Image Retrieval 交互式文本到图像检索
对话记录表示:$D_i = (C, Q_1, A_1, …, Q_i)$
- $C$:目标图像的初始文本描述(标题)
- $Q_i$:第i个问题
- $A_i$:第i个回答
检索系统将数据库中的所有图片与文本进行匹配,根据相似度进行排序,根据目标排名评估系统性能。
- Recall@K:本轮交互检索到的前K张图片包含目标的概率
- Hit@K:本轮以及任意一轮的交互检索到的前K张图片包含目标的概率
3.2 Context Reformulation 上下文重构
作者测试了0样本的CLIP、BLIP、BLIP-2和一个黑箱模型ATM
- Hit@K逐步提升,但这是由其定义决定的
- Recall@K在仅包含最初的文本描述时最高,随着对话轮次增加而下降
- 对话在0样本模型上可能没有贡献、产生了噪声
- 0样本模型无法理解对话数据
为了解决这个问题,一种方法是像ChatIR一样对模型进行微调,但这样做有以下限制:
- 不能使用黑箱模型,比如ATM
- 需要大量的训练数据
本文不直接使用对话作为输入进行查询,而是将对话重构为可以直接输入到预训练的视觉语言模型的格式,不需要对模型进行微调(即所谓的Plug-and-Play)。
3.3 Context-aware Dialogue Generation 上下文感知对话生成
仅靠对话历史生成问题具有以下问题:
- 生成的问题可能与图像属性无关
- 可能询问历史对话中已有信息
提问过程(用于解决问题1):
- 使用重构后的查询语句进行检索,找出高相似度的“检索候选”图像集
- 对候选图像Embeddings进行K-means聚类,得到每个候选图像与其他图像的相似度得分分布
- 对于每个聚类,选择
相似度分布熵
最小的图像作为代表- 熵越小,属性越真实、越容易区分
- 例如,同一组图像对“一张配有2台电脑显示器和一副键盘的桌子”的描述熵更低,对“办公室”的描述熵更高
- 将这K副图像通过image2text模型生成caption,作为附加信息提供给LLM提问者
提问(算法1)伪代码:
- 输入:对话上下文$c$、图像库$I$、“检索候选”图像数$n$、聚类数$m$、相似度函数$sim$、$KMeans$、i2t模型$Captioning$
- 从$I$中选出前$n$个和$c$最相似的图像,作为$S_R$
- 初始化$S_R \leftarrow {}$
- $while S_R.size() < n do$
- 将和$c$最相似的图像$x$加入$S_R$
- 将$x$从$I$中移除
- 对$S_R$进行$KMeans$聚类,得到$m$个聚类$S_R^{(1)}, S_R^{(2)}, …, S_R^{(m)}$
- 计算每个图像相对$S_R$的概率,使用Softmax得到$P_c(x)=\frac{exp(sim(c, x))}{\sum_{x’ \in S_R} exp(sim(c, x’))}$
- 从每个簇$S_R^{(i)}$中选择最优的图像,并对这$m$个图像进行$Captioning$,得到$T$
- $for i in range(1,m+1) do$
- 计算当前簇$S_R^{(i)}$中所有图像的熵,并找出最小熵的图像$\hat x^{(i)}$
- 对$\hat x^{(i)}$进行$Captioning$,并加入$T$
- $for i in range(1,m+1) do$
- 返回:$T$
采用思维链(Chain of Thought)的方法,提示词位于原文18~19页,获取与图像相关的问题。
这样生成的问题仍然可能冗余(已经知道答案),还需要经过过滤。
过滤过程(用于解决问题2):
- 通过上下文回答函数,判断问题是否“确定”,选取“不确定”的问题
- 选择“不确定”的问题中KL散度最小的问题
- KL散度:$KL(P_c||P_{c,q})=\sum_{x \in T} P_c(x)log\frac{P_c(x)}{P_{c,q}(x)}$
- 用于防止不合适的问题导致相似度骤变
过滤(算法2)伪代码:
- 输入:对话上下文$c$、问题集合$Q$、检索候选集$T$、相似度函数$sim$、上下文回答函数$Answer$
- 定义计算上下文概率分布的函数
- 图像$x$在上下文$c$下的分布:$P_c(x)=\frac{exp(sim(c, x))}{\sum_{x’ \in T} exp(sim(c, x’))}$
- 加入问题$q$后图像$x$在上下文$c$下的分布:$P_{c,q}(x)=\frac{exp(sim(concat(c, q), x))}{\sum_{x’ \in T} exp(sim(concat(c, q), x’))}$
- 筛选出答案“不确定”的问题,作为$Q’$
- 初始化$Q’ \leftarrow {}$
- $for q in Q do$
- 如果$Answer(c, q)$为“不确定”,则加入$Q’$
- 选择KL散度最小的问题$\hat q$
- 返回:$\hat q$
3.4 The Best Log Rank Integral (BRI) Metric 最佳对数排名积分
作者指出,在评估交互式检索系统时,有3个关键点:
- 用户满意度:在多少次交互中至少找到了一次目标图像算满意
- 效率:成功检索所需轮次越少越好
- 排名提升意义:排名靠前时提升排名的意义更大,如从2到1比从100到99更有意义
Recall@K用于非交互式检索;Hit@K只考虑了用户满意度
作者提出了BRI指标,综合了用户满意度、效率和排名提升意义
记:$Q$问题集合、$T$最大轮次
$\pi(q_t)$:表示具有$t$轮对话的查询$q_t$,在这$t$轮查询中,目标图像的历史最佳排名,用于衡量用户满意度
BRI:$\mathbb E_{q \in Q}\left[ \dfrac{1}{2T}\log\pi(q_0)\pi(q_T)+\dfrac{1}{T}\sum\limits_{t=1}^{T-1}\log\pi(q_t)\right]$
- 边界项:$\dfrac{1}{2T}\log\pi(q_0)\pi(q_T)$
- 权重较小,反映初始查询$q_0$到最终查询$q_T$的排名改善情况
- 平均查询排名项:$\dfrac{1}{T}\sum\limits_{t=1}^{T-1}\log\pi(q_t)$
- 计算了查询$q$的所有$t$轮中,目标的历史最佳排名的对数的均值
- 对数函数使得低排名的进一步降低对BRI变化的影响更大
- BRI越小,性能越好
- BRI不依赖于具体的K值,更全面、统一
- 实验表明,BRI与人类评价更接近
4. Experiments
- 数据集:Visdail、COCO、Flickr30k
- 文本到图像检索模型:默认BLIP,也有BLIP-2、ATM
- LLM提问者:ChatGPT
- 测试集回答者:BLIP-2
- 聚类数m:10
Baseline:0-shot、ChatIR
同时进行了Ablation Study,测试了不同组件的加入对结果的影响
总结和实现
- PlugIR系统也是一个基于对话的图像检索系统,在ChatIR的基础上进行了改进
- 主要优化了提问过程,使得提问的有效性提升
- 使用了新的评估指标BRI,能够更全面地评估交互式检索系统
由于系统代码量较大,在实现时划分到多个文件
config.py
1 | import torch |
OpenAI.py
所有函数统一返回response对象,包含了所有信息。
messages.py中定义了消息格式,包含prompt信息,具体文本见论文附录。
1 | import openai |
utils.py
实现了特征提取、K-means聚类、KL散度计算、获取簇中心caption、熵计算等函数功能,使主程序代码简洁易读。
1 | import torch |
系统实现:PlugIR_exec.py
PlugIR的运行版,实现利用多轮对话进行图像检索的功能,描述以及每次问答后显示当前最相关的图片。
改写为PlugIR_func.py
,实现了函数化,便于后续批量生成对话数据用于evaluation。
1 | import torch |
运行效果:
对话数据生成:test_gen.py
为了自动获取测试数据,免除人工回答,使用了BLIP2模型回答问题。
1 | import torch |
Debug日志
测试使用了项目仓库的eval.py
源代码。
可能是我使用Windows系统的缘故,eval.py
代码中有一些报错,具体问题和调整如下:
AttributeError: Can't pickle local object 'BLIP_ZERO_SHOT_BASELINE.<locals>.<lambda>'
- 原因:在Windows上使用多进程(
num_workers>0
)时,需要pickle对象,但是其中的lambda函数或局部函数不能被pickle - 解决:将lambda函数改为全局函数,再使用
functools.partial
进行参数绑定
- 原因:在Windows上使用多进程(
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
- 原因:函数
text_encode_fn
中的processor
没有to(device)
- 解决:多做一步,将
processor
的输出``to(device)`
- 原因:函数
尝试跑通generate_dialog.py
的过程中也遇到了上述问题,解决方法同。
测试结果
正常情况下3~5分钟可以生成1条数据(大约80次请求)。
但由于我生成数据短时间内大量调用OpenAI API,导致被限流,20~30分钟才能产生1条数据,故本次实现在eval阶段仅有253条数据。
在使用ChatGPT-4o-mini作为提问模型G,BLIP2作为回答模型A,使用同一测试代码的情况下,测试结果和仓库的对话数据Hit@K对比如下:
length | 仓库数据Hit@10(2064 testcases) | 实现数据Hit@10(253 testcases) |
---|---|---|
0 | 71.12% | 72.33% |
1 | 79.02% | 81.42% |
2 | 83.09% | 83.40% |
3 | 85.85% | 85.38% |
4 | 87.55% | 86.56% |
5 | 88.71% | 87.75% |
6 | 89.39% | 88.14% |
7 | 90.12% | 88.54% |
8 | 90.70% | 88.93% |
9 | 91.09% | 89.33% |
10 | 91.47% | 90.12% |
BRI对比(越低越好):
- 仓库对话的BRI:10.195615768432617
- 实现对话的BRI:10.252569198608398