通过QwenCoder+pg数据库实现文本转sql

4天前学习32

QwenCoder 是阿里巴巴开源的通义千问系列 AI 编程大模型,可以将文字转换为sql语句,支持MySQL、postgresql等常用的数据库。

以下是我编写的一个测试代码,使用的本地pg库。使用的  XGenerationLab/XiYanSQL-QwenCoder-7B-2504  模型。

import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer
import psycopg2
from psycopg2.extras import RealDictCursor

def get_db_schema(conn, table_name):
    """获取指定表的结构信息"""
    schema_info = []
    with conn.cursor() as cur:
        # 查询表结构
        cur.execute(f"""
            SELECT 
    c.column_name,c.data_type,pgd.description as description
FROM 
    information_schema.columns c
LEFT JOIN 
    pg_description pgd ON 
        pgd.objoid = (SELECT oid FROM pg_class WHERE relname = c.table_name)
        AND pgd.objsubid = c.ordinal_position
WHERE 
    c.table_name = 't_license'  -- 替换为你的表名
ORDER BY 
    c.ordinal_position;
        """, (table_name,))
        
        columns = cur.fetchall()
        schema_info.append(f"表名:{table_name}")
        schema_info.append("字段:")
        for col in columns:
            col_name, data_type, description = col
            schema_info.append(f"- {col_name} ({data_type},{description})")
    
    return "\n".join(schema_info)

def get_sample_data(conn, table_name):
    """获取表中的一条示例数据"""
    with conn.cursor(cursor_factory=RealDictCursor) as cur:
        cur.execute(f"SELECT * FROM {table_name} LIMIT 1")
        data = cur.fetchone()
        if data:
            return {k: str(v) for k, v in data.items()}
    return None

# 数据库连接配置
db_config = {
    "host": "localhost",
    "database": "test",  # 替换为你的数据库名
    "user": "postgres",      # 替换为你的用户名
    "password": "123456",  # 替换为你的密码
    "port": "5432"
}

lic_table = "public.t_lice"
dic_table = "public.t_dict"

conn = psycopg2.connect(**db_config)

lic_schema = get_db_schema(conn, lic_table)
dic_schema = get_db_schema(conn, dic_table)

lic_data = get_sample_data(conn, lic_table)
dic_data = get_sample_data(conn, dic_table)

dialect = '数据分析邻域'
evidence = f'''
数据示例
{lic_table}表:{lic_data}
{dic_table}表:{dic_data}
'''

# question = '查询首次制证日期在2025-01-01以后的零售户总数'
# question = '近一个月到期的零售户有多少?'
# question = '查询近三天到期的零售户许可证号、店名、地址、联系电话'
question = '查询近三天到期的正常经营的零售户许可证号、店名、地址、联系电话'

nl2sqlite_template_cn = f"""你是一名{dialect}专家,现在需要阅读并理解下面的【数据库schema】描述,以及可能用到的【参考信息】,并运用{dialect}知识生成sql语句回答【用户问题】。
【数据库schema】
{lic_schema}
{dic_schema}

【参考信息】
{evidence}

【用户问题】
{question}

```sql"""

model_name = "XGenerationLab/XiYanSQL-QwenCoder-7B-2504"
# model_name = "XGenerationLab/XiYanSQL-QwenCoder-3B-2502"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

## dialects -> ['SQLite', 'PostgreSQL', 'MySQL']
prompt = nl2sqlite_template_cn
message = [{'role': 'user', 'content': prompt}]

text = tokenizer.apply_chat_template(
    message,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=4096,
    temperature=0.1,
    top_p=0.8,
    do_sample=True,
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

 

主要是拼接提示词,我使用了两个方法来查询数据库的表结构和示例数据,作为提示词的一部分。

提示词主要包括以下三个部分:

【数据库schema】数据库描述
【参考信息】参考信息
【用户问题】用户提问
最终运行结果:
 
可以看到运行结果还是挺稳定的。

 

扫描二维码推送至手机访问。

版权声明:本文由星光下的赶路人发布,如需转载请注明出处。

本文链接:https://forstyle.cc/zblog/post/89.html

分享给朋友:
返回列表

上一篇:对funasr微调的模型进行压缩

没有最新的文章了...