基于Chinese-CLIP与ChromaDB的中文图像检索功能实现

发布于:2025-07-15 ⋅ 阅读:(13) ⋅ 点赞:(0)

本文按“原理 → 代码 → 讲解”三层展开,读者只需具备 Python 基础即可跟随完成一个可落地的以文搜图应用。

一、整体思路

  1. 把图片和文字都转成固定长度的向量(768 维)。
  2. 把图片向量提前存入向量数据库。
  3. 查询时把文字转成向量,再找出最相似的图片向量。

实现依赖两个核心组件:

  • Chinese-CLIP:中文多模态模型,负责向量化。
  • ChromaDB:轻量级向量数据库,负责存储与检索。

二、准备工作

软件

  • Python ~=3.10
  • 显卡可选,如有 NVIDIA GPU 请提前装好 CUDA 驱动。

安装依赖

pip install torch 
pip install transformers  chromadb pillow numpy

数据
在任意位置新建文件夹,例如 D:/photos,把待检索的 .jpg.png 图片全部放进去。

三、分步实现

1. 载入模型

使用transformers 加载需要的模型。

from transformers import ChineseCLIPModel, ChineseCLIPProcessor
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "OFA-Sys/chinese-clip-vit-large-patch14-336px"
model = ChineseCLIPModel.from_pretrained(model_name).to(device)
processor = ChineseCLIPProcessor.from_pretrained(model_name)

要点

  • processor 负责把图片或文本转成模型所需的张量。
  • 首次运行会自动下载约 1 GB 权重,后续离线可用。

2. 图片预处理

模型输入要求 336×336 像素,并保持原始比例居中裁剪。

from PIL import Image

def load_image(image_path: str, out_size=(336, 336)) -> Image.Image:
    target_w, target_h = out_size
    with Image.open(image_path) as img:
        img = img.convert("RGB")
        ow, oh = img.size
        scale = max(target_w / ow, target_h / oh)
        new_w, new_h = int(ow * scale + 0.5), int(oh * scale + 0.5)
        img = img.resize((new_w, new_h), Image.LANCZOS)
        left = (new_w - target_w) // 2
        top = (new_h - target_h) // 2
        img = img.crop((left, top, left + target_w, top + target_h))
        return img

3. 特征提取

把图片或文本变成 768 维向量,并对向量做 L2 归一化,使后续相似度计算简化为点积。

import numpy as np

def images_to_vectors(images):
    inputs = processor(images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        vec = model.get_image_features(**inputs)
    vec = vec / vec.norm(p=2, dim=-1, keepdim=True)
    return vec.cpu().numpy()

def texts_to_vectors(texts):
    inputs = processor(text=texts, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        vec = model.get_text_features(**inputs)
    vec = vec / vec.norm(p=2, dim=-1, keepdim=True)
    return vec.cpu().numpy()

4. 构建向量数据库

ChromaDB 会在本地目录保存数据,支持增量写入。

import chromadb, uuid
from pathlib import Path

DATA_DIR = Path("D:/photos")
DB_PATH  = "images.chroma_db"

def build_database(data_dir=DATA_DIR):
    client = chromadb.PersistentClient(DB_PATH)
    collection = client.get_or_create_collection(
        name="photos",
        metadata={"hnsw:space": "cosine"}
    )

    existing = set(collection.get()["uris"])     # 已入库图片
    paths = [p for p in data_dir.rglob("*.jpg") if str(p) not in existing]

    if not paths:
        print("没有新图片需要入库")
        return

    batch_size = 32
    for i in range(0, len(paths), batch_size):
        batch_paths = paths[i:i+batch_size]
        images = [load_image(p) for p in batch_paths]
        vectors = images_to_vectors(images)

        ids  = [str(uuid.uuid4()) for _ in batch_paths]
        uris = [str(p) for p in batch_paths]

        collection.add(embeddings=vectors.tolist(), ids=ids, uris=uris)
        print(f"已入库 {len(batch_paths)} 张")

运行一次即可:

build_database()

5. 文字查询

def search(text, top_k=5):
    client = chromadb.PersistentClient(DB_PATH)
    collection = client.get_collection("photos")

    vec = texts_to_vectors([text])[0]
    hits = collection.query(
        query_embeddings=[vec.tolist()],
        n_results=top_k,
        include=["uris", "distances"]
    )
    return list(zip(hits["uris"][0], hits["distances"][0]))

返回示例

results = search("海棠", top_k=5)
for path, dist in results:
    print(f"{dist:.3f}  {path}")

6. 结果可视化(可选)

import matplotlib.pyplot as plt

def show_results(results, cols=5):
    n = len(results)
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols*2.5, rows*2.5))
    axes = axes.flatten() if n > 1 else [axes]
    for ax, (path, dist) in zip(axes, results):
        img = Image.open(path)
        ax.imshow(img)
        ax.set_title(f"{dist:.2f}")
        ax.axis("off")
    plt.tight_layout(); plt.show()

调用:

results = search("海棠", top_k=10)
show_results(results)

四、常见问题

  • CPU 运行慢?单张图片约 200 ms;GPU 可降至 50 ms。
  • 内存不足?把 batch_size 降到 8 或更小。
  • 想支持更多格式?把 rglob("*.jpg") 改成 rglob("*") 并自行过滤扩展名。

五、下一步可扩展

  1. 混合查询:同时输入文字 + 参考图片,把两个向量平均后再搜索。
  2. 过滤条件:在 collection.add 时附加元数据(时间、标签),查询时加 where 条件。
  3. 分布式部署:把 ChromaDB 换成 Milvus / Weaviate,即可横向扩展。

至此,你已拥有一个完整、可维护、易扩展的中文图文检索应用。


网站公告

今日签到

点亮在社区的每一天
去签到