通过 QwenCoder + PostgreSQL 数据库实现文本转 SQL
QwenCoder（Qwen2.5-Coder）是阿里巴巴开源的通义千问系列 AI 编程大模型。基于它微调的 XiYanSQL-QwenCoder 模型可以将自然语言问题转换为 SQL 语句，支持 SQLite、PostgreSQL、MySQL 等常用数据库。
以下是我编写的测试代码，连接本地的 PostgreSQL 数据库，使用 XGenerationLab/XiYanSQL-QwenCoder-7B-2504 模型。
import psycopg2
import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer
from psycopg2 import sql
from psycopg2.extras import RealDictCursor
def get_db_schema(conn, table_name):
    """Return a plain-text description of a table's columns for the LLM prompt.

    The text starts with ``表名:<table_name>`` and ``字段:``, followed by one
    ``- <column> (<type>,<comment>)`` line per column in column order. The
    ``,<comment>`` part is omitted when the column has no COMMENT.

    Args:
        conn: An open psycopg2 connection.
        table_name: Table name, optionally schema-qualified (e.g.
            ``"public.t_lice"``). It is resolved with PostgreSQL's ``regclass``
            cast, so it follows the same lookup/quoting rules as a table name
            written in SQL (unqualified names are searched via ``search_path``).

    Returns:
        The multi-line schema description.

    Raises:
        psycopg2.Error: If the table cannot be resolved (the ``regclass``
            cast fails) or the query otherwise fails.
    """
    # Read pg_attribute directly rather than information_schema.columns.
    # - The regclass cast accepts a schema-qualified name such as
    #   "public.t_lice"; information_schema.table_name never includes the
    #   schema.
    # - col_description() fetches the column COMMENT by (table oid, attnum),
    #   so there is no ambiguous pg_class lookup by bare relname.
    # - attnum > 0 skips system columns, and attisdropped skips dropped ones.
    # The table name is bound as a parameter, never formatted into the SQL.
    query = """
        SELECT a.attname,
               format_type(a.atttypid, a.atttypmod),
               col_description(a.attrelid, a.attnum)
        FROM pg_catalog.pg_attribute a
        WHERE a.attrelid = %s::regclass
          AND a.attnum > 0
          AND NOT a.attisdropped
        ORDER BY a.attnum
    """
    with conn.cursor() as cur:
        cur.execute(query, (table_name,))
        columns = cur.fetchall()

    lines = [f"表名:{table_name}", "字段:"]
    for col_name, data_type, description in columns:
        # Do not print the literal "None" for uncommented columns; it would
        # only add noise to the prompt.
        detail = f"{data_type},{description}" if description else data_type
        lines.append(f"- {col_name} ({detail})")
    return "\n".join(lines)
def get_sample_data(conn, table_name):
    """Fetch one row of *table_name* as a dict of strings, for use as a prompt example.

    Args:
        conn: An open psycopg2 connection.
        table_name: Table name, optionally schema-qualified with a dot (e.g.
            ``"public.t_lice"``). Each dot-separated part is quoted as an SQL
            identifier, so the name cannot inject SQL. Quoted identifiers are
            case-sensitive, so pass names in their stored (usually lowercase)
            form.

    Returns:
        ``{column: str(value)}`` for the first row returned, or ``None`` if no
        row is returned. SQL NULLs appear as the string ``"None"``.
    """
    # Compose the identifier instead of interpolating it with an f-string.
    # "public.t_lice" becomes "public"."t_lice" rather than one quoted name.
    query = sql.SQL("SELECT * FROM {} LIMIT 1").format(
        sql.Identifier(*table_name.split("."))
    )
    with conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(query)
        row = cur.fetchone()
    if not row:
        return None
    return {column: str(value) for column, value in row.items()}
# Database connection settings.
db_config = {
    "host": "localhost",
    "database": "test",  # replace with your database name
    "user": "postgres",  # replace with your user name
    "password": "123456",  # replace with your password
    "port": "5432",
}
# Tables whose structure and sample rows are embedded in the prompt.
lic_table = "public.t_lice"
dic_table = "public.t_dict"

# Read each table's structure and one sample row. The connection is closed
# afterwards, even if a query fails, so it is not left open for the rest of
# the script (e.g. during model loading).
conn = psycopg2.connect(**db_config)
try:
    lic_schema = get_db_schema(conn, lic_table)
    dic_schema = get_db_schema(conn, dic_table)
    lic_data = get_sample_data(conn, lic_table)
    dic_data = get_sample_data(conn, dic_table)
finally:
    conn.close()
# {dialect} in the XiYanSQL prompt template names the SQL dialect the model
# should write: one of 'SQLite', 'PostgreSQL', 'MySQL'. The data lives in
# PostgreSQL, so ask for PostgreSQL syntax. That way date/time expressions
# and functions in the generated SQL actually run against this database.
dialect = 'PostgreSQL'
# Reference information: one sample row per table, so the model sees real
# value formats (e.g. date or status encodings).
evidence = f'''
数据示例
{lic_table}表:{lic_data}
{dic_table}表:{dic_data}
'''
# question = '查询首次制证日期在2025-01-01以后的零售户总数'
# question = '近一个月到期的零售户有多少?'
# question = '查询近三天到期的零售户许可证号、店名、地址、联系电话'
question = '查询近三天到期的正常经营的零售户许可证号、店名、地址、联系电话'
# Prompt layout: schema, reference info, user question. It ends with an
# opening ```sql fence so the model continues with the SQL statement.
nl2sqlite_template_cn = f"""你是一名{dialect}专家,现在需要阅读并理解下面的【数据库schema】描述,以及可能用到的【参考信息】,并运用{dialect}知识生成sql语句回答【用户问题】。
【数据库schema】
{lic_schema}
{dic_schema}
【参考信息】
{evidence}
【用户问题】
{question}
```sql"""
model_name = "XGenerationLab/XiYanSQL-QwenCoder-7B-2504"
# Smaller checkpoint that can be swapped in:
# model_name = "XGenerationLab/XiYanSQL-QwenCoder-3B-2502"

# Load the checkpoint in bfloat16; device_map="auto" leaves device placement
# to the loader.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The template's {dialect} is meant to be one of: 'SQLite', 'PostgreSQL', 'MySQL'.
prompt = nl2sqlite_template_cn

# Send the prompt as a single user turn, rendered through the tokenizer's
# chat template with the generation prompt appended.
message = [dict(role="user", content=prompt)]
text = tokenizer.apply_chat_template(
    message, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Low-temperature nucleus sampling settings.
sampling_options = dict(
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=4096,
    temperature=0.1,
    top_p=0.8,
    do_sample=True,
)
generated_ids = model.generate(**model_inputs, **sampling_options)

# Keep only the tokens after the prompt in each output sequence, so only the
# newly generated text is decoded.
generated_ids = [
    sequence[len(prompt_ids):]
    for prompt_ids, sequence in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
核心在于拼接提示词：我编写了 get_db_schema 和 get_sample_data 两个函数，分别查询数据库的表结构和一条示例数据，作为提示词的一部分。
提示词主要包括以下三个部分: