一、写在前面
最近想做一个"数据库智能体"——就是用自然语言问数据库问题,AI 自动生成 SQL 并返回结果。
看起来挺简单的,但仔细想想,要做好还真不容易:
- 怎么把自然语言转成准确的 SQL?
- 怎么保证 SQL 安全(防注入、防误删数据)?
- 怎么支持多轮对话("他们的平均年龄是多少"这种上下文问题)?
- 怎么接入 MCP 协议让其他系统调用?
二、问题定义
2.1 具体场景
假设有这样一个数据库:
-- 用户表
users (id, name, email, age)
-- 订单表
orders (id, user_id, product_name, amount, order_date)
-- 商品表
products (id, name, price, stock, category)
用户想问:
- "查询所有用户"
- "张三买了什么?"
- "哪些用户买过外设类商品?"
- "每个用户的订单总金额是多少?"
2.2 想法1:写死 SQL 模板
最简单的想法是预定义一堆 SQL 模板:
templates = {
"查询所有用户": "SELECT * FROM users",
"查询订单": "SELECT * FROM orders WHERE user_id = ?",
...
}
问题:
- 模板数量爆炸(每种问法都要写一个)
- 无法处理复杂查询(跨表、聚合)
- 不支持自然语言变化("所有用户" vs "全部用户")
2.3 想法2:让 LLM 直接生成 SQL
直接把问题扔给 GPT:
prompt = f"把这个问题转成 SQL: {question}"
sql = llm.generate(prompt)
问题:
- LLM 不知道数据库结构(表名、列名)
- 容易生成错误的 SQL
- 没有安全控制
2.4 正确做法:NL2SQL + Schema + 安全控制
把问题建模成这样:
用户问题 + 数据库 Schema → LLM → SQL → 安全校验 → 执行 → 结果
关键点:
- Schema 提供:告诉 LLM 数据库有哪些表、哪些字段
- Prompt 工程:设计好的 Prompt 让 LLM 生成准确的 SQL
- 安全校验:防止 SQL 注入、防止误删数据
- 对话管理:维护上下文,支持多轮对话
三、技术栈
| 组件 | 选型 | 理由 |
|---|---|---|
| 框架 | Spring Boot 3.2 | 成熟稳定 |
| LLM 编排 | LangChain4j 0.35 | API 简洁,Java 生态 |
| 大模型 | 通义千问 qwen-max | 国内访问快,效果好 |
| 数据库 | MySQL 8.0 | 通用 |
| 连接池 | HikariCP | 高性能 |
| SQL 解析 | JSqlParser | 安全校验用 |
四、核心实现
4.1 架构设计
整体分层:
HTTP API (Controller)
↓
对话管理 (ConversationService)
↓
NL2SQL Agent (LangChain4j + 通义千问)
↓
SQL 执行器 (DalExecutor)
↓
数据库 (MySQL)
4.2 Schema 自动获取
第一步是让 LLM 知道数据库结构。用 JDBC 的 DatabaseMetaData 自动读取:
@Component
public class SchemaProvider {
private final DataSource dataSource;
private final ConcurrentHashMap<String, String> schemaCache = new ConcurrentHashMap<>();
public String getSchema() {
return schemaCache.computeIfAbsent("schema", k -> loadSchema());
}
private String loadSchema() {
StringBuilder schema = new StringBuilder();
try (Connection conn = dataSource.getConnection()) {
DatabaseMetaData metaData = conn.getMetaData();
ResultSet tables = metaData.getTables(catalog, null, "%", new String[]{"TABLE"});
while (tables.next()) {
String tableName = tables.getString("TABLE_NAME");
schema.append("表名: ").append(tableName).append("\n");
ResultSet columns = metaData.getColumns(catalog, null, tableName, "%");
while (columns.next()) {
String columnName = columns.getString("COLUMN_NAME");
String columnType = columns.getString("TYPE_NAME");
schema.append(" - ").append(columnName)
.append(" (").append(columnType).append(")\n");
}
}
}
return schema.toString();
}
}
生成的 Schema 长这样:
表名: users
- id (INT)
- name (VARCHAR)
- email (VARCHAR)
- age (INT)
表名: orders
- id (INT)
- user_id (INT)
- product_name (VARCHAR)
- amount (DECIMAL)
- order_date (DATE)
关键点:
- 用
ConcurrentHashMap缓存,避免每次都查数据库 - 格式化成 LLM 容易理解的文本
4.3 Prompt 工程
Prompt 是核心,直接影响 SQL 生成质量:
public static final String NL2SQL_TEMPLATE = """
你是一个专业的SQL专家。根据用户的自然语言问题和数据库Schema,生成准确的SQL查询。
数据库Schema:
{schema}
历史对话:
{history}
用户问题: {question}
要求:
1. 只生成SELECT语句,不要生成INSERT/UPDATE/DELETE/DROP等修改数据的语句
2. 只返回SQL语句本身,不要包含任何解释或markdown格式
3. SQL语句要准确、高效
4. 如果问题无法用SQL回答,返回: ERROR: 无法生成SQL
SQL:
""";
关键点:
- 明确告诉 LLM 只生成 SELECT(安全第一)
- 提供完整的 Schema 信息
- 包含历史对话(支持多轮对话)
- 要求只返回 SQL,不要废话
4.4 NL2SQL Agent
核心逻辑很简洁:
@Component
public class NL2SQLAgent {
private final ChatLanguageModel chatModel;
private final SchemaProvider schemaProvider;
private final PromptTemplate promptTemplate;
private final DalExecutor dalExecutor;
public QueryResult processQuestion(String question, String conversationHistory) {
try {
// 1. 获取 Schema
String schema = schemaProvider.getSchema();
// 2. 构建 Prompt
String prompt = promptTemplate.buildPrompt(schema, conversationHistory, question);
// 3. 调用 LLM 生成 SQL
String response = chatModel.generate(prompt);
String sql = extractSql(response);
// 4. 执行 SQL
return dalExecutor.executeQuery(sql);
} catch (Exception e) {
return QueryResult.builder()
.success(false)
.errorMessage("处理失败: " + e.getMessage())
.build();
}
}
private String extractSql(String response) {
// 清理 LLM 返回的 markdown 格式
String cleaned = response.trim();
if (cleaned.startsWith("```sql")) {
cleaned = cleaned.substring(6);
}
if (cleaned.endsWith("```")) {
cleaned = cleaned.substring(0, cleaned.length() - 3);
}
return cleaned.trim();
}
}
4.5 SQL 安全校验
这是最重要的部分——防止 SQL 注入和误操作:
@Component
public class SqlValidator {
@Value("${security.allowed-operations}")
private List<String> allowedOperations;
public void validate(String sql) {
// 1. 检查危险关键字
String[] dangerousKeywords = {"DROP", "TRUNCATE", "ALTER", "DELETE", "UPDATE"};
String upperSql = sql.trim().toUpperCase();
for (String keyword : dangerousKeywords) {
if (upperSql.contains(keyword)) {
throw new SecurityException("不允许的SQL操作: 包含危险关键字 " + keyword);
}
}
// 2. 解析 SQL 类型
try {
Statement stmt = CCJSqlParserUtil.parse(sql);
if (stmt instanceof Delete) {
throw new SecurityException("不允许的SQL操作: DELETE");
}
if (stmt instanceof Drop) {
throw new SecurityException("不允许的SQL操作: DROP");
}
} catch (SecurityException e) {
throw e;
} catch (Exception e) {
// 解析失败,只允许 SELECT
if (!upperSql.startsWith("SELECT")) {
throw new SecurityException("SQL解析失败,只允许SELECT语句");
}
}
}
}
两层防护:
- 关键字检测:直接拦截 DROP、DELETE 等危险操作
- SQL 解析:用 JSqlParser 解析 SQL 类型,确保只有 SELECT
4.6 对话管理
支持多轮对话的关键是维护上下文:
@Service
public class ConversationService {
private final ConcurrentHashMap<String, Session> sessions = new ConcurrentHashMap<>();
@Value("${conversation.max-history}")
private int maxHistory; // 保留最近 5 轮
public String getConversationHistory(String sessionId) {
Session session = sessions.get(sessionId);
if (session == null) {
return "";
}
List<Message> recentMessages = session.getRecentMessages(maxHistory);
StringBuilder history = new StringBuilder();
for (Message msg : recentMessages) {
history.append(msg.getRole()).append(": ").append(msg.getContent());
if (msg.getSql() != null) {
history.append(" [SQL: ").append(msg.getSql()).append("]");
}
history.append("\n");
}
return history.toString();
}
}
这样就能处理这种对话:
用户: 查询所有用户
AI: SELECT * FROM users
用户: 他们的平均年龄是多少?
AI: SELECT AVG(age) FROM users ← 知道"他们"指的是 users 表
五、运行效果
5.1 单表查询
用户问:"查询所有用户"
生成的 SQL:
SELECT * FROM users
返回结果:
{
"sessionId": "abc-123",
"sql": "SELECT * FROM users",
"success": true,
"columns": ["id", "name", "email", "age"],
"rows": [
{"id": 1, "name": "张三", "email": "zhangsan@example.com", "age": 25},
{"id": 2, "name": "李四", "email": "lisi@example.com", "age": 30},
{"id": 3, "name": "王五", "email": "wangwu@example.com", "age": 28}
],
"rowCount": 3,
"executionTimeMs": 45
}
5.2 跨表查询
用户问:"查询每个用户的订单总金额,按金额从高到低排序"
生成的 SQL:
SELECT u.name, SUM(o.amount) AS total_amount
FROM users u
JOIN orders o ON u.id = o.user_id
GROUP BY u.id, u.name
ORDER BY total_amount DESC
返回结果:
张三 6098.00
王五 1299.00
李四 299.00
LLM 自动识别出需要 JOIN 两张表,还加了 GROUP BY 和 ORDER BY。
5.3 多轮对话
第一轮:
用户: 查询所有用户
AI: SELECT * FROM users
第二轮:
用户: 他们的平均年龄是多少?
AI: SELECT AVG(age) FROM users
LLM 通过历史对话知道"他们"指的是 users 表的用户。
5.4 安全拦截
用户问:"删除所有用户"
返回:
{
"success": false,
"error": "ERROR: 无法生成SQL"
}
成功拦截危险操作。
六、MCP 协议支持
6.1 什么是 MCP
MCP (Model Context Protocol) 是一个标准化的工具协议,让 AI 能调用外部工具。
简单说就是定义一套接口:
GET /mcp/tools- 获取可用工具列表POST /mcp/tools/call- 调用工具
6.2 实现 MCP Server
定义两个工具:
@Component
public class McpServer {
public List<McpToolDefinition> getTools() {
List<McpToolDefinition> tools = new ArrayList<>();
// 工具1:查询数据库
tools.add(McpToolDefinition.builder()
.name("query_database")
.description("使用自然语言查询数据库")
.inputSchema(Map.of(
"type", "object",
"properties", Map.of(
"question", Map.of(
"type", "string",
"description", "用户的自然语言问题"
)
),
"required", List.of("question")
))
.build());
// 工具2:获取 Schema
tools.add(McpToolDefinition.builder()
.name("get_schema")
.description("获取数据库Schema信息")
.inputSchema(Map.of("type", "object"))
.build());
return tools;
}
public Object executeTool(String toolName, Map<String, Object> arguments) {
return switch (toolName) {
case "query_database" -> executeQueryDatabase(arguments);
case "get_schema" -> executeGetSchema();
default -> throw new IllegalArgumentException("未知的工具: " + toolName);
};
}
}
6.3 测试 MCP 接口
获取工具列表:
curl http://localhost:8080/mcp/tools
返回:
[
{
"name": "query_database",
"description": "使用自然语言查询数据库",
"inputSchema": {...}
},
{
"name": "get_schema",
"description": "获取数据库Schema信息",
"inputSchema": {...}
}
]
调用工具:
curl -X POST http://localhost:8080/mcp/tools/call \
-H "Content-Type: application/json" \
-d '{
"tool": "query_database",
"arguments": {
"question": "有多少个用户?"
}
}'
返回:
{
"success": true,
"content": {
"sql": "SELECT COUNT(*) FROM users",
"rows": [{"COUNT(*)": 3}],
"rowCount": 1
}
}