简述
在自然语言处理(NLP)领域,Prompt工程是提升大语言模型(LLM)性能的重要技术。本文分析了一段Python代码,它用于自动优化“总结服务商支持工单根因”的Prompt,展示了如何通过自动化迭代优化Prompt来提高模型输出的准确性与相关性。本文将从技术背景、创新点、取得的效果、代码优点等方面介绍这一实用的工程化手段。
技术背景
Prompt优化是近年来随着大型语言模型的广泛应用而兴起的一项技术。传统的NLP任务通常需要对模型进行微调(fine-tuning),但这需要大量标注数据和计算资源。而Prompt工程通过设计高质量的输入指令(Prompt),引导模型生成更符合预期的输出,显著降低了开发成本。文中代码的目标是优化一个用于总结工单根因的Prompt,通过结合句向量余弦相似度和链式推理(Chain-of-Thought, CoT)生成新的Prompt变异,逐步提升模型性能。
该代码依赖以下关键技术:
SentenceTransformer:用于生成文本嵌入,计算预测总结与参考总结的余弦相似度,作为奖励函数。
OpenAI API:调用gpt-4o-mini模型,生成总结或新的Prompt变异。
日志与进度管理:通过logging模块和JSONL文件记录优化过程,便于调试和恢复。
自动化Prompt变异:通过CoT生成新的Prompt变异,结合随机选择和评估策略优化Prompt。
创新点
自动化Prompt优化框架:
代码通过迭代生成和评估Prompt变异,实现了自动化的Prompt优化流程,避免了人工设计的低效性。使用CoT提示模型分析当前Prompt的不足并生成变异提示,增强了优化的针对性。
结合句向量与余弦相似度:
采用SentenceTransformer计算预测总结与参考总结的余弦相似度,作为奖励函数,量化Prompt性能。这种方法相比传统的手动评估更加客观且可扩展。
鲁棒的错误处理与进度保存:
代码在API调用、数据加载和Prompt变异等环节加入了多重错误处理机制(如指数退避重试),确保稳定性。使用JSONL格式追加保存每一步的最优Prompt和得分,中断后可据此手动恢复(代码本身只写入进度文件,不会自动读取),适合长时间运行的优化任务。
早停机制与动态调整:
引入早停机制(max_no_improvement=3),在连续三次无改进时停止优化,节省计算资源。每一步都会调用CoT生成新变异,并根据前一步是否改进,在CoT提示中反馈不同的候选得分,以便模型分析上一次变异未能提升的原因。
prompt优化代码
import json
import random
import re
import requests
import time
import os
import logging
from functools import lru_cache
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
# === Logging setup ===
# Root logging at INFO so per-step progress and scores appear on the console.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # shared by every helper in this script
# === Model API configuration ===
# Full chat-completions endpoint URL: call_model POSTs directly to it and reads
# choices[0].message.content from the reply. The original literal was an
# unterminated string (a SyntaxError); the value is now overridable from the
# environment, like the API key, with a placeholder default.
base_url = os.getenv("OPENAI_BASE_URL", "https://xx")
api_key = os.getenv("OPENAI_API_KEY", "xx")  # read from the environment rather than hard-coding the secret
model = "gpt-4o-mini"  # must match the deployment/model name served by base_url
# === Initialisation ===
# Sentence-embedding model used by the reward function, loaded from a local
# directory. A load failure is fatal: it is logged with a hint and re-raised.
embedding_model_path = "../model/all-MiniLM-L6-v2"
try:
    embedding_model = SentenceTransformer(embedding_model_path)
except Exception as e:
    logger.error(f"Failed to load SentenceTransformer: {e}")
    logger.info(f"Please ensure the model is available at '{embedding_model_path}'")
    raise
else:
    logger.info("SentenceTransformer loaded successfully from local path")

data_file = "final_tickets.jsonl"  # input tickets, parsed one JSON object per line
progress_file = "prompt_progress.jsonl"  # append-only JSONL log of per-step best prompt/score
# === Save progress ===
def save_progress(best_prompt, best_score, step):
    """Append a checkpoint ``{"step", "best_prompt", "best_score"}`` to ``progress_file``.

    Each call writes one JSON object followed by a newline (JSON Lines), so
    earlier checkpoints are kept. Any error is logged and swallowed, so a
    failed checkpoint never aborts the optimisation run.
    """
    checkpoint = {"step": step, "best_prompt": best_prompt, "best_score": best_score}
    try:
        with open(progress_file, "a", encoding="utf-8") as fh:
            json.dump(checkpoint, fh)
            fh.write("\n")
            fh.flush()  # push the line to the OS right away rather than at close
            logger.info(f"Appended progress to {progress_file} for step {step}")
    except Exception as e:
        logger.error(f"Error saving progress: {e}")
# === Load data ===
def load_data(file_path, limit=None, seed=None):
    """Load ticket records from a JSON Lines file and return them shuffled.

    Args:
        file_path: Path to a UTF-8 JSONL file; each non-blank line is parsed
            with ``json.loads``.
        limit: If not None, keep only the first ``limit`` records in file
            order (applied before shuffling). ``0`` now means "no records"
            rather than being treated as "no limit".
        seed: Optional seed for the shuffle so the train/validation split can
            be reproduced; ``None`` uses the global ``random`` state.

    Returns:
        The parsed records in shuffled order.

    Raises:
        Any error from opening or parsing the file, after logging it.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. stray empty lines) instead of failing on json.loads("").
            records = [json.loads(line) for line in f if line.strip()]
        if limit is not None:
            records = records[:limit]
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(records)  # shuffle so the caller's train/validation slice is random
        logger.info(f"Loaded {len(records)} tickets from {file_path}")
        return records
    except Exception as e:
        logger.error(f"Error loading data from {file_path}: {e}")
        raise
# Split dataset into training and validation (about 80% / 20%).
train_ratio = 0.8  # fraction of the loaded tickets used for prompt optimisation
try:
    dataset = load_data(data_file, limit=50)
    # Derive the cut from the number of tickets actually loaded: a fixed index
    # of 40 leaves validation empty whenever the file has 40 tickets or fewer.
    # Keep at least one training ticket when any exist (50 tickets -> 40/10).
    split_idx = min(len(dataset), max(1, int(len(dataset) * train_ratio)))
    train_data = dataset[:split_idx]
    val_data = dataset[split_idx:]
    if not train_data:
        logger.warning("Training split is empty")
    if not val_data:
        logger.warning("Validation split is empty")
except Exception as e:
    logger.error(f"Failed to split dataset: {e}")
    raise
# === Model call ===
class _ModelCallError(RuntimeError):
    """Raised internally when every attempt of a model request has failed."""


@lru_cache(maxsize=2000)
def _call_model_cached(prompt: str, ticket_content: str, retries: int) -> str:
    """Send the chat-completion request with retries; memoise successes only.

    ``lru_cache`` does not store calls that raise, so raising
    ``_ModelCallError`` on failure (instead of returning "") keeps a transient
    outage from being cached as a permanent empty answer.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    messages = [
        {"role": "system", "content": "You are a professional support engineer."},
        {"role": "user", "content": prompt + "\n\n ticket content:" + ticket_content},
    ]
    payload = {"model": model, "messages": messages, "temperature": 0.3}
    for attempt in range(retries):
        try:
            response = requests.post(base_url, headers=headers, json=payload, timeout=30)
            response.raise_for_status()
            content = response.json()["choices"][0]["message"]["content"].strip()
            logger.info(f"Successful API call for prompt: {prompt[:50]}...")
            return content
        except requests.exceptions.HTTPError as e:
            status = e.response.status_code if e.response is not None else None
            if status == 429:
                # Rate limited: exponential backoff 15s, 30s, 60s, then capped at 120s.
                wait_time = min(15 * (2 ** attempt), 120)
                logger.warning(f"Rate limit hit (attempt {attempt + 1}/{retries}): Waiting {wait_time}s")
            elif status is not None and 400 <= status < 500 and status != 408:
                # Other client errors (bad request, auth, not found) fail the same
                # way on every attempt, so retrying only wastes time.
                logger.error(f"Non-retryable HTTP error: {e}")
                break
            else:
                wait_time = min(2 ** attempt, 60)  # 1s, 2s, 4s, ... capped at 60s
                logger.error(f"HTTP error (attempt {attempt + 1}/{retries}): {e}")
        except Exception as e:
            wait_time = min(2 ** attempt, 60)  # 1s, 2s, 4s, ... capped at 60s
            logger.error(f"Error calling model (attempt {attempt + 1}/{retries}): {e}")
        if attempt < retries - 1:
            time.sleep(wait_time)  # no point sleeping after the final attempt
    raise _ModelCallError("Failed to get response after retries")


def call_model(prompt: str, ticket_content: str, retries=10):
    """Send ``prompt`` followed by ``ticket_content`` to the chat model.

    Args:
        prompt: Instruction text; the ticket text is appended after it.
        ticket_content: Ticket text (may be "" when the prompt is self-contained).
        retries: Maximum number of request attempts.

    Returns:
        The stripped message content, or "" if every attempt failed.
        Successful responses are memoised per (prompt, ticket_content, retries);
        failures are not, so a later call retries the request.
    """
    try:
        return _call_model_cached(prompt, ticket_content, retries)
    except _ModelCallError:
        logger.error("Failed to get response after retries")
        return ""


# Keep the cache-inspection helpers the previous lru_cache-decorated function exposed.
call_model.cache_info = _call_model_cached.cache_info
call_model.cache_clear = _call_model_cached.cache_clear
# === Generate new prompt variations ===
# Matches a fenced code block such as ```json ... ``` (language tag optional).
_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL)


def _parse_json_response(text):
    """Extract and parse the JSON object from a model reply.

    Uses the contents of the first fenced code block if there is one,
    otherwise the whole reply. If that does not parse, retries on the
    substring from the first "{" to the last "}". Raises ValueError
    (json.JSONDecodeError is a subclass) when nothing parses.
    """
    if not text:
        raise ValueError("empty model response")
    fenced = _JSON_FENCE_RE.search(text)
    candidate = fenced.group(1) if fenced else text
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        start, end = candidate.find("{"), candidate.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(candidate[start:end + 1])


def generate_prompt_variations(best_prompt, best_score, sample_ticket, candidate_score, retries=3):
    """Ask the model, chain-of-thought style, for new variations of ``best_prompt``.

    Args:
        best_prompt: Current best prompt.
        best_score: Its average similarity score.
        sample_ticket: Example ticket shown to the model (first 500 characters).
        candidate_score: Score of the most recent candidate prompt.
        retries: Request attempts passed through to ``call_model``.

    Returns:
        A non-empty list of ``(type, prompt)`` tuples with type "append" or
        "replace", or None if the request/parse failed or no entry was usable.
    """
    coT_prompt = (
        f"You are a prompt engineering expert. I’m optimizing a prompt for summarizing the root cause of support tickets. "
        f"The current prompt is: \"{best_prompt}\". Here’s a sample ticket: \"{sample_ticket[:500]}\". "
        f"The prompt’s average similarity score (cosine similarity of sentence embeddings) is {best_score:.4f}. "
        f"The latest candidate prompt scored {candidate_score:.4f}, which did not improve over the current best. "
        f"Follow the requirements below:\n"
        f"1. Analyze why the current prompt and recent candidate may not be yielding higher similarity scores.\n"
        f"2. Suggest 3 new prompt variations (either rephrasings or additions) to improve clarity, conciseness, or relevance.\n"
        f"3. For each variation, explain why it might improve performance.\n"
        f"Return the response in JSON format:\n"
        f"{{\n \"analysis\": \"...\",\n \"variations\": [\n {{\"type\": \"append|replace\", \"prompt\": \"...\", \"reason\": \"...\"}},\n ...\n ]\n}}"
    )
    try:
        logger.debug("CoT prompt: %s", coT_prompt)
        response = call_model(coT_prompt, "", retries=retries)
        logger.debug("CoT response: %s", response)
        result = _parse_json_response(response)
        variations = result.get("variations") or []
        # Keep only well-formed entries of a supported mutation type.
        valid_variations = [
            (v["type"], v["prompt"])
            for v in variations
            if isinstance(v, dict) and v.get("type") in ("append", "replace") and v.get("prompt")
        ]
    except Exception as e:
        logger.error(f"Error generating prompt variations: {e}")
        return None
    if not valid_variations:
        logger.warning("CoT response contained no usable prompt variations")
        return None
    logger.info("Generated new prompt variations via CoT")
    return valid_variations
# === Reward function: sentence-embedding cosine similarity ===
def compute_reward(predicted: str, reference: str):
    """Score a predicted summary by cosine similarity to the reference summary.

    Both texts are embedded with ``embedding_model``. An empty/None input
    scores 0.0 with a warning; any embedding or similarity error is logged
    and also scores 0.0, so callers always get a float back.
    """
    if predicted and reference:
        try:
            pred_vec, ref_vec = (
                embedding_model.encode(text, convert_to_tensor=True)
                for text in (predicted, reference)
            )
            return util.pytorch_cos_sim(pred_vec, ref_vec).item()
        except Exception as e:
            logger.error(f"Error computing reward: {e}")
            return 0.0
    logger.warning("Empty predicted or reference summary; returning score 0.0")
    return 0.0
# === Prompt mutation strategy ===
def mutate_prompt(prompt, sample_ticket, best_score, candidate_score):
    """Return a new candidate prompt derived from ``prompt``.

    One variation from ``generate_prompt_variations`` is chosen at random:
    an "append" variation is added after the current prompt, a "replace"
    variation substitutes it. If no variation is available, or any error
    occurs, ``prompt`` is returned unchanged.
    """
    try:
        variations = generate_prompt_variations(prompt, best_score, sample_ticket, candidate_score)
        if not variations:
            # generate_prompt_variations signals failure with None (or an empty
            # list); random.choice would raise a confusing TypeError/IndexError.
            logger.warning("No prompt variations available; keeping current prompt")
            return prompt
        mutation_type, variation = random.choice(variations)
        if mutation_type == "append":
            new_prompt = prompt + " " + variation
        elif mutation_type == "replace":
            new_prompt = variation
        else:
            new_prompt = prompt  # unknown mutation type: keep the prompt as is
        logger.info(f"Mutated prompt: {new_prompt[:50]}...")
        return new_prompt
    except Exception as e:
        logger.error(f"Error mutating prompt: {e}")
        return prompt
# === Evaluate the current prompt ===
def evaluate_prompt(prompt, dataset, subset_size=3, delay=5, seed=None):
    """Return the mean reward of ``prompt`` over a random sample of ``dataset``.

    Args:
        prompt: Instruction sent to the model together with each ticket.
        dataset: List of ticket dicts; "ticket_content" and
            "root_cause_summary" are read from each.
        subset_size: Number of tickets to sample (default 3 keeps the request
            rate low); capped at ``len(dataset)``.
        delay: Seconds to wait between consecutive model calls, to stay under
            the API rate limit.
        seed: Optional seed for the sample so several prompts can be scored
            on the same tickets; ``None`` uses the global ``random`` state.

    Returns:
        Average score over the tickets that could be scored, or 0.0 if none
        could (also on unexpected errors, which are logged).
    """
    try:
        sampler = random if seed is None else random.Random(seed)
        dataset_subset = sampler.sample(dataset, min(subset_size, len(dataset)))
        scores = []
        called_before = False
        for item in tqdm(dataset_subset, desc="Evaluating prompt"):
            input_text = item.get("ticket_content", "")
            reference_summary = item.get("root_cause_summary", "")
            if not input_text or not reference_summary:
                logger.warning("Skipping ticket with missing content or summary")
                continue
            if called_before and delay:
                time.sleep(delay)  # pace requests between calls, not after the last one
            called_before = True
            try:
                predicted_summary = call_model(prompt, input_text)
                if not predicted_summary:
                    # call_model returns "" when every retry failed; scoring that as
                    # 0.0 would penalise the prompt for an API outage.
                    logger.warning("Empty model response; excluding ticket from the average")
                    continue
                scores.append(compute_reward(predicted_summary, reference_summary))
            except Exception as e:
                logger.error(f"Error processing ticket: {e}")
        avg_score = sum(scores) / len(scores) if scores else 0.0
        logger.info(f"Average score for prompt: {avg_score:.4f}")
        return avg_score
    except Exception as e:
        logger.error(f"Error evaluating prompt: {e}")
        return 0.0
# === Initial prompt ===
# Seed instruction the optimiser starts from; mutations later append to or replace it.
initial_prompt = (
    'You are a support engineer. Given the following ticket, summarize in the '
    'shortest possible sentence, beginning with "The root cause of the issue is..."'
)
# === Main loop: prompt optimisation ===
# Bind every name the interrupt/error handlers use *before* the first slow
# call: a Ctrl-C during the initial evaluation used to raise NameError
# (best_score/step unbound) inside the handler instead of exiting cleanly.
best_prompt = initial_prompt
best_score = 0.0
step = -1  # -1 means "stopped before the first optimisation step"
try:
    best_score = evaluate_prompt(best_prompt, train_data)
    no_improvement_count = 0
    max_no_improvement = 3  # early-stopping threshold (consecutive steps without improvement)
    num_candidates = 1  # candidates per step; kept at 1 to limit the request rate
    candidate_scores = []  # previous step's candidate scores, fed back into the CoT prompt
    # Ticket shown to the model when it generates CoT variations.
    sample_ticket = train_data[0].get("ticket_content", "") if train_data else ""
    for step in range(10):  # max iterations
        logger.info(f"\n🔄 Step {step + 1}")
        try:
            # Every step asks the model (via CoT) for new variations. After a step
            # without improvement, the previous candidate's score is reported so
            # the model can analyse why that variation did not help.
            if no_improvement_count > 0:
                logger.info("No improvement in previous step; generating new variations via CoT")
                feedback_score = max(candidate_scores) if candidate_scores else best_score
            else:
                feedback_score = best_score
            candidate_prompts = [
                mutate_prompt(best_prompt, sample_ticket, best_score, feedback_score)
                for _ in range(num_candidates)
            ]
            # A candidate identical to the current best cannot improve on it;
            # re-scoring it on a fresh random subset would only let sampling
            # noise masquerade as an improvement.
            candidate_prompts = [p for p in candidate_prompts if p != best_prompt]
            candidate_scores = [evaluate_prompt(p, train_data) for p in candidate_prompts]
            max_score = max(candidate_scores) if candidate_scores else 0.0
            max_idx = candidate_scores.index(max_score) if candidate_scores else 0
            logger.info(f"Candidate Scores: {[f'{s:.4f}' for s in candidate_scores]}")
            if candidate_scores and max_score > best_score:
                logger.info(f"✅ Improved prompt! New Score: {max_score:.4f}")
                best_prompt = candidate_prompts[max_idx]
                best_score = max_score
                no_improvement_count = 0
            else:
                logger.info(f"❌ No improvement (Best Candidate Score: {max_score:.4f})")
                no_improvement_count += 1
            logger.info(f"Current best_prompt: {best_prompt[:50]}...")
            save_progress(best_prompt, best_score, step)
        except Exception as e:
            logger.error(f"Error in step {step + 1}: {e}")
            continue
        if no_improvement_count >= max_no_improvement:
            logger.info(f"🛑 Early stopping: No improvement for {max_no_improvement} consecutive steps")
            break
    # === Validate the final prompt on held-out tickets ===
    if val_data:
        val_score = evaluate_prompt(best_prompt, val_data, subset_size=len(val_data))
    else:
        # evaluate_prompt would return 0.0 for an empty set, which looks like a real score.
        logger.warning("Validation set is empty; skipping validation")
        val_score = 0.0
    logger.info("\n🏁 Final Optimized Prompt:")
    logger.info(best_prompt)
    logger.info(f"🏆 Best Training Score: {best_score:.4f}")
    logger.info(f"📊 Validation Score: {val_score:.4f}")
except KeyboardInterrupt:
    logger.info("KeyboardInterrupt detected; saving progress and exiting")
    if step >= 0:  # before step 0 best_score is only a placeholder; don't checkpoint it
        save_progress(best_prompt, best_score, step)
    logger.info("\n🏁 Partial Results:")
    logger.info(f"Best Prompt: {best_prompt}")
    logger.info(f"Best Training Score: {best_score:.4f}")
    # SystemExit rather than the site-module exit(), which is meant for the
    # interactive shell and is absent when the site module is not loaded.
    raise SystemExit(0)
except Exception as e:
    logger.error(f"Unexpected error: {e}")
    if step >= 0:
        save_progress(best_prompt, best_score, step)
    raise
运行
根据代码日志输出,优化过程在训练集上逐步提升了Prompt的平均余弦相似度得分(best_score),最终通过验证集得分(val_score)检验了Prompt的泛化能力。具体效果包括:
性能提升:通过多次迭代,Prompt的平均相似度得分从初始值逐步提高,表明生成的总结更接近参考答案。如日志所示:✅ Improved prompt! New Score: 0.5748
可恢复性:通过progress_file保存每一步的最优Prompt和得分,中断后可据此查看结果或手动继续优化(代码本身不会自动读取该文件)。
泛化能力:最终Prompt在验证集上的得分(val_score)反映了其在未见过数据上的表现,验证了优化的有效性。
执行结果
{"step": 0, "best_prompt": "You are a senior support engineer. Given the following ticket content including subject, description, requester, and message history, summarize the root cause in clear, natural, and technical language. Keep the explanation concise and professional.", "best_score": 0.5748027596473694}
{"step": 1, "best_prompt": "You are a senior support engineer. Given the following ticket content including subject, description, requester, and message history, summarize the root cause in clear, natural, and technical language. keep the explanation concise and professional.summarize in the shortest possible sentence, beginning with \"The root cause of the issue is...\"", "best_score": 0.6720715522766113}
{"step": 2, "best_prompt": "\"You are an experienced support engineer. Analyze the following ticket content, which includes the subject, description, requester, and message history. Provide a concise and precise summary of the root cause of the issue, using clear and technical language. Focus on identifying the underlying problem that led to the user's experience, ensuring that your summary aligns closely with the reference summary in terms of clarity and relevance.\"", "best_score": 0.690971531867981}