21天驯化AI打工仔 - SQEP与symbol编码性能测试

"007,我们需要讨论一个重要的性能优化问题,"我一边敲击键盘一边对我的 AI 助手说道。

"什么问题?我已经准备好了,"007 回应道,它的语音合成器发出了一种几乎可以称为热情的声音。

"在量化交易系统中,数据查询性能至关重要。我们需要测试一下股票代码编码方式对查询速度的影响。"


前三天,我们讨论了如何从 Tushare 获取 OHLC (开盘价、最高价、最低价、收盘价) 数据和调整因子 (adj_factor)。当时我们存储的数据结构如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
{
    "timestamp": "时间戳",
    "ts_code": "股票代码",
    "ohlc": {
        "ts_code": "股票代码",
        "open": "开盘价",
        "high": "最高价",
        "low": "最低价",
        "close": "收盘价",
        "vol": "成交量"
    },
    "adj_factor": {
        "ts_code": "股票代码",
        "trade_date": "交易日期",
        "adj_factor": "复权因子"
    }
}

现在,我们需要设计一种通用的数据交换格式(Standard Quotes Exchange Protocol, SQEP)。这种格式的工作原理是:由数据生产者(因为只有生产者才了解原始数据的具体格式)将数据转换为这种标准格式,然后再将其推送到 Redis 中供消费者使用。

"在金融数据处理中,毫秒级的延迟可能意味着巨大的差异,"007 解释道,"特别是在高频交易场景下,查询性能的优化至关重要。"

我点点头:"没错,其中一个关键问题是股票代码 symbol 的数据类型选择。理论上,整数类型的查询和比较操作应该比字符串更高效,但我们需要实际测试来验证这一点。"

"我可以帮你设计一个严谨的实验,"007 建议道,"我们可以生成足够大的数据集,然后在相同条件下比较 string 和 int64 类型的查询性能。"

于是,我和 007 决定生成 1 亿条股票数据,对 string 类型和 int64 类型的 symbol 进行严谨公平的性能测试,以量化分析不同编码方式对查询效率的实际影响。这个实验将帮助我们在构建高性能量化交易系统时做出更明智的技术选择。

1. SQEP-BAR-DAY 日线场景下的数据交换格式

"设计一个好的数据交换格式需要考虑多方面因素,"007 分析道,"包括数据完整性、传输效率、存储空间和查询性能。"

SQEP-BAR-DAY 是标准行情交换协议 (Standard Quotes Exchange Protocol) 中用于日线数据的格式规范。该格式设计用于在不同系统组件间高效传输和处理股票日线数据,确保数据的一致性和互操作性。

这种标准化的数据格式解决了量化交易系统中一个常见的痛点:不同数据源提供的数据格式各不相同,导致系统需要为每个数据源编写特定的处理逻辑。通过 SQEP,我们可以将这种复杂性隔离在数据生产者端,让消费者端的代码更加简洁和通用。

1.1. 字段定义

SQEP-BAR-DAY 包含以下标准字段:

字段名 数据类型 说明
symbol str/int 股票代码。推荐使用整型编码以提高性能
frame datetime.date 交易日期
open float64 开盘价
high float64 最高价
low float64 最低价
close float64 收盘价
vol float64 成交量
amount float64 成交额
adjust float64 复权因子
st bool 是否为 ST 股票(可选扩展字段)
buy_limit float64 涨停价(可选扩展字段)
sell_limit float64 跌停价(可选扩展字段)

1.2. 编码约定

"在设计数据结构时,命名和编码约定看似小事,但实际上对系统的长期维护和性能有着深远影响,"007 提醒道。

  1. 字段命名:使用 frame 而非 datetimestamp,因为后两者在某些数据库中不适合作为列名。

  2. 股票代码编码:为提高查询性能,推荐将字符串格式的股票代码转换为整型:

  3. 上海证券交易所:000001.SH → 1000001
  4. 深圳证券交易所:000001.SZ → 2000001

这种编码方式最多可支持 9 个不同交易所(数字 1-9,0 不能用作前缀)。

"这种整数编码方案非常巧妙,"我评论道,"它不仅保留了原始代码的所有信息,还能通过前缀区分不同交易所,同时将字符串转换为可能更高效的整数类型。"

007 补充道:"是的,而且这种编码方式在保持可读性的同时,还能充分利用数据库的整数索引优化,理论上应该能提高查询性能。不过,我们需要通过实际测试来验证这一假设。"

1.3. 使用场景

SQEP-BAR-DAY 主要应用于:

  1. 数据生产者(如 Tushare、QMT 等数据源)将原始数据转换为标准格式
  2. 通过 Redis 等中间件在系统组件间传输
  3. 数据消费者(如分析引擎、回测系统)处理标准格式数据
  4. 存储到 ClickHouse 等时序数据库中进行长期保存

2. symbol 编码的性能测试

"理论上,整数类型的查询应该比字符串更快,但具体提升多少呢?"我思考着。

"这正是我们需要通过实验来回答的问题,"007 回应道,"让我们设计一个严谨的性能测试,确保结果具有科学性和可重复性。"

在数据库性能优化中,选择合适的数据类型对查询效率有着显著影响。特别是对于像股票代码这样频繁用于查询条件的字段,其存储和索引方式可能会直接影响系统的整体性能。本节将通过大规模数据集测试,量化分析字符串型和整数型股票代码在查询性能上的差异。

2.1. 测试环境

"为了确保测试的公平性,我们需要设计一个完整的数据流程,"007 建议道,"从数据生成、存储到查询,每个环节都需要精心设计。"

首先,我们通过 Redis 来把 symbol 编码的数据进行存储,然后通过 ClickHouse 来进行查询。ClickHouse 是一个专为 OLAP(在线分析处理)设计的列式数据库管理系统,非常适合大规模数据的快速分析查询。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
# 配置参数
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_PASSWORD = "添加Redis密码"  # 添加 Redis 密码
REDIS_QUEUE_NAME = "test_queue"

CLICKHOUSE_HOST = "localhost"
CLICKHOUSE_PORT = 9000
CLICKHOUSE_DB = "test_data"
CLICKHOUSE_TABLE = "stock_data"

BATCH_SIZE = 100000  # 每批次处理的记录数
TOTAL_RECORDS = 100000000  # 总记录数 (1 亿)

# 初始化 Redis 连接
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, decode_responses=True)

# 初始化 ClickHouse 连接
clickhouse_client = Client(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT, database=CLICKHOUSE_DB)

2.2. 生成 stock_data 数据

"生成有代表性的测试数据是实验成功的关键,"007 指出,"我们需要模拟真实市场环境中的数据分布和特征。"

为了测试查询速度,我们需要生成 stock_data 数据,这里我随机生成 1 亿条数据。这个规模足够大,可以消除随机波动的影响,同时也能反映真实生产环境中的性能表现。

"1 亿条记录,"我思考着,"这应该足够模拟真实的生产环境了。"

"是的,"007 回应,"大规模数据集能更好地展现不同编码方式在索引和查询上的性能差异。我已经设计好了数据生成算法,包括股票代码、交易日期、价格和交易量等关键字段。"

2.2.1. 生成股票和指数代码
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""生成股票和指数代码(沪市和深市各一半),返回symbols"""
def generate_symbols(num_stocks=5000):
    symbols = []

    for i in range(num_stocks):
        if i < num_stocks // 2:
            # 沪市股票 (600xxx, 601xxx, 603xxx, 605xxx, 688xxx)
            prefix = random.choice(['600', '601', '603', '605', '688'])
            suffix = str(random.randint(0, 999)).zfill(3)
            symbols.append(f"{prefix}{suffix}.SH")
        else:
            # 深市股票 (000xxx, 001xxx, 002xxx, 003xxx, 300xxx, 301xxx)
            prefix = random.choice(['000', '001', '002', '003', '300', '301'])
            suffix = str(random.randint(0, 999)).zfill(3)
            symbols.append(f"{prefix}{suffix}.SZ")

    # 添加主要指数
    indices = [
        # 上证指数
        "000001.SH",  # 上证综指
        "000016.SH",  # 上证50
        "000300.SH",  # 沪深300
        "000905.SH",  # 中证500
        "000852.SH",  # 中证1000
        # 深证指数
        "399001.SZ",  # 深证成指
        "399006.SZ",  # 创业板指
        "399300.SZ",  # 沪深300
        "399905.SZ",  # 中证500
        # 其他重要指数
        "000688.SH",  # 科创50
        "399673.SZ",  # 创业板50
    ]

    # 添加指数到符号列表
    symbols.extend(indices)

    return symbols
2.2.2. 生成ST股票列表
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
"""生成ST股票列表,返回st_stocks"""
def generate_st_stocks(symbols, st_ratio=0.05):
    """
    Args:
        symbols: 所有股票代码列表
        st_ratio: ST股票的比例

    Returns:
        ST股票代码集合
    """
    # 过滤掉指数
    stock_symbols = [s for s in symbols if not s in ["000001.SH", "000016.SH", "000300.SH", "000905.SH",
                                                    "000852.SH", "399001.SZ", "399006.SZ", "399300.SZ",
                                                    "399905.SZ", "000688.SH", "399673.SZ"]]

    # 随机选择一部分股票作为ST股票
    st_count = int(len(stock_symbols) * st_ratio)
    st_stocks = set(random.sample(stock_symbols, st_count))

    return st_stocks
2.2.3. 生成股票数据
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def generate_batch_data(batch_size, start_idx, symbols, trading_days, st_stocks):
    """
    Args:
        batch_size: 批次大小
        start_idx: 起始索引
        symbols: 股票和指数代码列表
        trading_days: 交易日列表
        st_stocks: ST股票代码集合

    Returns:
        包含股票数据的CSV字符串
    """
    # 准备CSV输出
    output = io.StringIO()
    csv_writer = csv.writer(output)

    # 生成数据
    for i in range(batch_size):
        if (start_idx + i) >= TOTAL_RECORDS:
            break

        # 随机选择股票/指数和日期
        symbol = random.choice(symbols)
        trade_date = random.choice(trading_days)

        # 判断是否为指数
        is_index = symbol in ["000001.SH", "000016.SH", "000300.SH", "000905.SH", "000852.SH",
                             "399001.SZ", "399006.SZ", "399300.SZ", "399905.SZ", "000688.SH", "399673.SZ"]

        # 判断是否为ST股票
        is_st = symbol in st_stocks

        # 为指数生成不同范围的价格
        if is_index:
            if "000001.SH" in symbol:  # 上证综指
                base_price = random.uniform(2000, 6000)
            elif "399001.SZ" in symbol:  # 深证成指
                base_price = random.uniform(6000, 15000)
            elif "399006.SZ" in symbol:  # 创业板指
                base_price = random.uniform(1000, 3500)
            elif "000300.SH" in symbol or "399300.SZ" in symbol:  # 沪深300
                base_price = random.uniform(3000, 5500)
            else:  # 其他指数
                base_price = random.uniform(4000, 10000)

            # 指数波动通常较小
            high = base_price * random.uniform(1, 1.03)
            low = base_price * random.uniform(0.97, 1)
            open_price = random.uniform(low, high)
            close = random.uniform(low, high)

            # 指数成交量和金额较大
            vol = random.uniform(10000000, 100000000)
            amount = random.uniform(100000000, 1000000000)
            adjust = 1.0  # 指数没有复权因子

            # 指数没有ST状态和涨跌停限制
            st = False
            buy_limit = 0.0
            sell_limit = 0.0
        else:
            # 普通股票
            # ST股票价格通常较低
            if is_st:
                open_price = random.uniform(1, 10)
            else:
                open_price = random.uniform(5, 100)

            high = open_price * random.uniform(1, 1.1)
            low = open_price * random.uniform(0.9, 1)
            close = random.uniform(low, high)

            # 生成成交量和金额
            vol = random.uniform(10000, 10000000)
            amount = vol * close * random.uniform(0.9, 1.1)

            # 生成复权因子
            adjust = random.uniform(0.8, 1.2)

            # ST状态
            st = is_st

            # 涨跌停限制
            if is_st:
                # ST股票涨跌停限制为5%
                limit_pct = 0.05
            else:
                # 普通股票涨跌停限制为10%(创业板20%,但简化处理)
                limit_pct = 0.1

            # 计算涨跌停价格
            prev_close = close / random.uniform(0.95, 1.05)  # 模拟前一日收盘价
            buy_limit = round(prev_close * (1 + limit_pct), 2)  # 涨停价
            sell_limit = round(prev_close * (1 - limit_pct), 2)  # 跌停价

        # 写入CSV
        csv_writer.writerow([
            symbol,
            trade_date.strftime('%Y-%m-%d'),
            round(open_price, 2),
            round(high, 2),
            round(low, 2),
            round(close, 2),
            round(vol, 2),
            round(amount, 2),
            round(adjust, 4),
            1 if st else 0,  # 布尔值转为0/1
            round(buy_limit, 2),
            round(sell_limit, 2)
        ])

    return output.getvalue()
2.2.4. 生成交易日
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
print(f"开始生成{TOTAL_RECORDS}条股票和指数数据...")

# 生成股票和指数代码
symbols = generate_symbols(5000)
print(f"已生成{len(symbols)}个股票和指数代码")

# 生成ST股票
st_stocks = generate_st_stocks(symbols)
print(f"已生成{len(st_stocks)}只ST股票")

# 生成交易日
start_date = datetime(2010, 1, 1)
trading_days = []

# 生成约2500个交易日 (每年约250个交易日,10年约2500个)
current_date = start_date
for _ in range(2500):
    # 跳过周末
    if current_date.weekday() < 5:  # 0-4 表示周一至周五
        trading_days.append(current_date)
    current_date += timedelta(days=1)

print(f"已生成{len(trading_days)}个交易日")

2.3. Redis 生产者:把数据推送到 Redis 队列

"数据生成后,我们需要一个高效的方式将其传输到数据库,"007 解释道,"Redis 作为中间件非常适合这个场景,它能高效处理大量数据的快速传输。"

我点点头:"使用生产者-消费者模式也能让我们更好地控制数据流,避免一次性加载过多数据导致系统压力过大。"

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 计算需要的批次数
num_batches = (TOTAL_RECORDS + BATCH_SIZE - 1) // BATCH_SIZE

# 使用tqdm显示进度
with tqdm(total=num_batches) as pbar:
    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE

        # 生成批次数据
        batch_data = generate_batch_data(BATCH_SIZE, start_idx, symbols, trading_days, st_stocks)

        # 创建唯一的批次ID
        batch_id = str(uuid.uuid4())

        # 推送到Redis
        redis_client.set(f"stock_data:{batch_id}", batch_data, ex=3600)  # 设置1小时过期
        redis_client.lpush(REDIS_QUEUE_NAME, batch_id)

        pbar.update(1)

        # 每10个批次暂停一下,避免Redis压力过大
        if batch_idx % 10 == 9:
            time.sleep(1)

print("数据生成完成,已全部推送到Redis队列")

"这个批处理设计很巧妙,"007 评论道,"通过 UUID 确保了每个批次的唯一性,同时设置过期时间避免了数据长期占用内存。每处理 10 个批次暂停一下也能防止 Redis 服务器压力过大。"

大概运行 20 分钟左右,可以运行结束。数据生成过程中,我们可以看到进度条稳步前进,表明数据正在被成功生成并推送到 Redis 队列中。

"数据生成和推送过程顺利完成,"我观察着进度条,"现在我们需要设计消费者端来处理这些数据。"

2.4. Redis 消费者:从 Redis 队列中读取数据并写入 clickhouse 数据库

"消费者端的设计同样重要,"007 说道,"它需要高效地从 Redis 读取数据,并正确地将其写入 ClickHouse 数据库。"

"是的,而且我们需要确保数据的完整性和一致性,"我补充道,"特别是在处理如此大量的数据时。"

007 点点头:"我已经设计了一个健壮的消费者程序,它能够处理各种异常情况,并提供实时的进度反馈。"

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def create_table_if_not_exists():
    """创建ClickHouse表(如果不存在)"""
    clickhouse_client.execute("""
    CREATE TABLE IF NOT EXISTS stock_data (
        symbol String,
        frame Date,
        open Float64,
        high Float64,
        low Float64,
        close Float64,
        vol Float64,
        amount Float64,
        adjust Float64,
        st UInt8,
        buy_limit Float64,
        sell_limit Float64,
        is_index UInt8 MATERIALIZED if(match(symbol, '^(000001|000016|000300|000905|000852|000688|399001|399006|399300|399905|399673)\\.(SH|SZ)$'), 1, 0)
    ) ENGINE = MergeTree()
    PARTITION BY toYYYYMM(frame)
    ORDER BY (symbol, frame)
    """
    )

    # 创建索引视图,方便查询
    clickhouse_client.execute("""
    CREATE VIEW IF NOT EXISTS stock_view AS
    SELECT
        symbol,
        frame,
        open,
        high,
        low,
        close,
        vol,
        amount,
        adjust,
        st,
        buy_limit,
        sell_limit,
        is_index,
        splitByChar('.', symbol)[1] AS code,
        splitByChar('.', symbol)[2] AS exchange
    FROM stock_data
    """)

    print("ClickHouse表和视图已准备就绪")

2.4.1. 处理一个批次的数据
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def process_batch(batch_id):
    """
    Args:
        batch_id: Redis中的批次ID

    Returns:
        处理的记录数
    """
    # 从Redis获取数据
    csv_data = redis_client.get(f"stock_data:{batch_id}")
    if not csv_data:
        return 0

    # 解析CSV数据
    csv_reader = csv.reader(io.StringIO(csv_data))
    rows = []

    for row in csv_reader:
        try:
            # 将日期字符串转换为日期对象
            date_str = row[1]
            # 尝试解析不同格式的日期
            try:
                # 尝试解析YYYY-MM-DD格式
                date_obj = datetime.strptime(date_str, '%Y-%m-%d').date()
            except ValueError:
                try:
                    # 尝试解析YYYYMMDD格式
                    date_obj = datetime.strptime(date_str, '%Y%m%d').date()
                except ValueError:
                    # 如果都失败,打印错误信息并跳过此行
                    print(f"无法解析日期: {date_str},跳过此行")
                    continue

            # 转换数据类型
            rows.append((
                row[0],                    # symbol (String)
                date_obj,                  # frame (Date as date object)
                float(row[2]),             # open (Float64)
                float(row[3]),             # high (Float64)
                float(row[4]),             # low (Float64)
                float(row[5]),             # close (Float64)
                float(row[6]),             # vol (Float64)
                float(row[7]),             # amount (Float64)
                float(row[8]),             # adjust (Float64)
                int(row[9]),               # st (UInt8)
                float(row[10]),            # buy_limit (Float64)
                float(row[11])             # sell_limit (Float64)
            ))
        except Exception as e:
            print(f"处理行数据时出错: {e}, 行数据: {row}")
            continue

    # 插入到ClickHouse
    if rows:
        try:
            clickhouse_client.execute(
                f"INSERT INTO {CLICKHOUSE_TABLE} (symbol, frame, open, high, low, close, vol, amount, adjust, st, buy_limit, sell_limit) VALUES",
                rows
            )
        except Exception as e:
            print(f"插入ClickHouse时出错: {e}")
            return 0

    # 删除Redis中的数据
    redis_client.delete(f"stock_data:{batch_id}")

    return len(rows)
2.4.2. 把 Redis 队列中的数据存入 stock_data 表中
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
create_table_if_not_exists()

print("开始从Redis队列消费数据并写入ClickHouse...")
total_processed = 0
start_time = time.time()

try:
    with tqdm() as pbar:
        while True:
            # 从队列中获取批次ID
            result = redis_client.brpop(REDIS_QUEUE_NAME, timeout=5)

            if not result:
                # 队列为空,检查是否还有未处理的批次
                if redis_client.llen(REDIS_QUEUE_NAME) == 0:
                    print("队列为空,等待新数据...")
                    time.sleep(5)
                    if time.time() - start_time > 60 and total_processed == 0:
                        print("1分钟内没有数据,退出程序")
                        break
                    continue

            # 处理批次
            _, batch_id = result
            records_processed = process_batch(batch_id)
            total_processed += records_processed

            # 更新进度条
            pbar.update(records_processed)
            pbar.set_description(f"已处理: {total_processed:,}")

            # 每处理100万条记录,显示一次统计信息
            if total_processed % 1000000 == 0:
                elapsed = time.time() - start_time
                rate = total_processed / elapsed if elapsed > 0 else 0
                print(f"\n已处理 {total_processed:,} 条记录,速率: {rate:.2f} 记录/秒")

                # 查询一些统计信息
                try:
                    stock_count = clickhouse_client.execute("SELECT count(DISTINCT symbol) FROM stock_data WHERE is_index = 0")[0][0]
                    index_count = clickhouse_client.execute("SELECT count(DISTINCT symbol) FROM stock_data WHERE is_index = 1")[0][0]
                    st_count = clickhouse_client.execute("SELECT count(DISTINCT symbol) FROM stock_data WHERE st = 1")[0][0]
                    date_range = clickhouse_client.execute("SELECT min(frame), max(frame) FROM stock_data")[0]

                    print(f"数据统计: {stock_count} 只股票, {index_count} 个指数, {st_count} 只ST股票")
                    print(f"日期范围: {date_range[0]}{date_range[1]}")

                    # 查询涨跌停统计
                    limit_stats = clickhouse_client.execute("""
                    SELECT
                        count() as total_records,
                        sum(if(close >= buy_limit AND buy_limit > 0, 1, 0)) as up_limit_count,
                        sum(if(close <= sell_limit AND sell_limit > 0, 1, 0)) as down_limit_count
                    FROM stock_data
                    WHERE is_index = 0
                    """)[0]

                    if limit_stats[0] > 0:
                        up_limit_pct = (limit_stats[1] / limit_stats[0]) * 100
                        down_limit_pct = (limit_stats[2] / limit_stats[0]) * 100
                        print(f"涨停比例: {up_limit_pct:.2f}%, 跌停比例: {down_limit_pct:.2f}%")
                except Exception as e:
                    print(f"统计查询失败: {str(e)}")

except KeyboardInterrupt:
    print("\n程序被手动中断")

except Exception as e:
    print(f"\n程序执行异常: {str(e)}")

finally:
    elapsed = time.time() - start_time
    rate = total_processed / elapsed if elapsed > 0 else 0
    print(f"\n总计处理 {total_processed:,} 条记录,总耗时: {elapsed:.2f} 秒,平均速率: {rate:.2f} 记录/秒")

    # 显示最终统计信息
    try:
        total_records = clickhouse_client.execute("SELECT count() FROM stock_data")[0][0]
        print(f"ClickHouse中总记录数: {total_records:,}")
    except Exception as e:
        print(f"最终统计查询失败: {str(e)}")

"数据已经成功导入 ClickHouse,"我看着终端输出说道,"现在我们有了一个包含大量记录的 stock_data 表。"

007 显示出一丝满意:"是的,数据导入过程顺利完成。现在我们可以进入实验的核心部分了:创建整数编码的表并进行性能对比。"

2.5. stock_data_with_int 表格的创建

"为了进行公平的性能比较,我们需要创建一个新表,结构与原表相同,但将 symbol 字段编码为整数类型,"007 解释道。

我点点头:"这样我们就能在相同的数据集上比较两种不同编码方式的查询性能了。"

我现在要把 ClickHouse 中所有的 symbol 编码成 int 类型,作为 symbol_int 字段进行存储,然后测试 ClickHouse 两种不同编码方式的查询速度。这个过程需要精心设计,确保两个表除了股票代码的编码方式外,其他方面完全相同,以保证测试的公平性。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def encode_symbol(symbol: str) -> int:
    """将字符串格式的股票代码转换为整型编码

    Args:
        symbol: 股票代码,如 '000001.SZ' 或 '600519.SH'

    Returns:
        整型编码的股票代码,如 2000001 或 1600519
    """
    code, exchange = symbol.split('.')
    code = code.lstrip('0')  # 移除前导零,但保留至少一位数字
    if not code:
        code = '0'

    if exchange.upper() == 'SH':
        prefix = '1'
    elif exchange.upper() == 'SZ':
        prefix = '2'
    else:
        raise ValueError(f"不支持的交易所: {exchange}")

    return int(prefix + code)

def decode_symbol(encoded_symbol: int) -> str:
    """将整型编码的股票代码转换回字符串格式

    Args:
        encoded_symbol: 整型编码的股票代码,如 2000001

    Returns:
        字符串格式的股票代码,如 '000001.SZ'
    """
    encoded_str = str(encoded_symbol)
    prefix = encoded_str[0]
    code = encoded_str[1:]

    # 补齐6位数字
    code = code.zfill(6)

    if prefix == '1':
        exchange = 'SH'
    elif prefix == '2':
        exchange = 'SZ'
    else:
        raise ValueError(f"不支持的交易所前缀: {prefix}")

    return f"{code}.{exchange}"
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""创建一个新表,将symbol编码为symbol_int,保留前导零"""
def create_table_with_symbol_int_preserve_zeros(clickhouse_client):
    try:
        # 创建一个新表,包含symbol_int列
        print("创建新表...")
        clickhouse_client.execute("""
        CREATE TABLE IF NOT EXISTS stock_data_with_int (
            symbol String,
            symbol_int Int64,  # 使用Int64以确保能存储足够大的数字
            frame Date,
            open Float64,
            high Float64,
            low Float64,
            close Float64,
            vol Float64,
            amount Float64,
            adjust Float64,
            st UInt8,
            buy_limit Float64,
            sell_limit Float64,
            is_index UInt8 MATERIALIZED if(match(symbol, '^(000001|000016|000300|000905|000852|000688|399001|399006|399300|399905|399673)\\.(SH|SZ)$'), 1, 0)
        ) ENGINE = MergeTree()
        PARTITION BY toYYYYMM(frame)
        ORDER BY (symbol, frame)
        """)

        # 从原表插入数据,同时计算symbol_int
        print("从原表复制数据并计算symbol_int(保留前导零)...")

        # 使用正确的SQL语法计算symbol_int,保留前导零
        clickhouse_client.execute("""
        INSERT INTO stock_data_with_int
        SELECT
            symbol,
            -- 使用条件判断交易所,然后拼接前缀和完整的股票代码(包括前导零)
            CASE
                WHEN endsWith(symbol, '.SH') THEN toInt64(concat('1', splitByChar('.', symbol)[1]))
                WHEN endsWith(symbol, '.SZ') THEN toInt64(concat('2', splitByChar('.', symbol)[1]))
                ELSE 0
            END AS symbol_int,
            frame, open, high, low, close, vol, amount, adjust, st, buy_limit, sell_limit
        FROM stock_data
        """)

        print("新表创建并填充完成")

        # 创建索引
        print("在symbol_int列上创建索引...")
        clickhouse_client.execute("""
        ALTER TABLE stock_data_with_int
        ADD INDEX idx_symbol_int (symbol_int) TYPE minmax GRANULARITY 1
        """)

        # 验证数据
        result = clickhouse_client.execute("""
        SELECT
            symbol,
            symbol_int,
            count()
        FROM stock_data_with_int
        GROUP BY symbol, symbol_int
        ORDER BY count() DESC
        LIMIT 10
        """)

        print("\n验证数据(前10条):")
        for row in result:
            print(f"Symbol: {row[0]}, Symbol Int: {row[1]}, Count: {row[2]}")

    except Exception as e:
        print(f"创建新表时出错: {str(e)}")


create_table_with_symbol_int_preserve_zeros(clickhouse_client)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""重新创建表,不包含symbol列"""
def recreate_table_without_symbol():
    try:
        # 获取表结构
        columns = clickhouse_client.execute("DESCRIBE TABLE stock_data_with_int")

        # 创建新表定义,排除symbol列
        column_defs = []
        for col in columns:
            col_name = col[0]
            col_type = col[1]

            # 跳过symbol列
            if col_name == 'symbol':
                continue

            # 处理is_index列,使其基于symbol_int
            if col_name == 'is_index':
                column_defs.append(f"{col_name} UInt8 MATERIALIZED if(symbol_int IN (1000001, 1000016, 1000300, 1000905, 1000852, 1000688, 2399001, 2399006, 2399300, 2399905, 2399673), 1, 0)")
            else:
                # 添加其他列
                column_defs.append(f"{col_name} {col_type}")

        # 创建新表
        print("创建新表stock_data_with_int_new...")
        create_table_sql = f"""
        CREATE TABLE stock_data_with_int_new (
            {', '.join(column_defs)}
        ) ENGINE = MergeTree()
        PARTITION BY toYYYYMM(frame)
        ORDER BY (symbol_int, frame)
        """

        clickhouse_client.execute(create_table_sql)

        # 复制数据,排除symbol列
        print("复制数据到新表...")
        copy_data_sql = """
        INSERT INTO stock_data_with_int_new
        SELECT
            symbol_int, frame, open, high, low, close, vol, amount, adjust, st, buy_limit, sell_limit
        FROM stock_data_with_int
        """

        clickhouse_client.execute(copy_data_sql)

        # 验证数据
        old_count = clickhouse_client.execute("SELECT count() FROM stock_data_with_int")[0][0]
        new_count = clickhouse_client.execute("SELECT count() FROM stock_data_with_int_new")[0][0]

        print(f"原表记录数: {old_count}")
        print(f"新表记录数: {new_count}")

        if old_count == new_count:
            # 删除旧表并重命名新表
            print("删除旧表并重命名新表...")
            clickhouse_client.execute("DROP TABLE stock_data_with_int")
            clickhouse_client.execute("RENAME TABLE stock_data_with_int_new TO stock_data_with_int")

            print("表重建成功,symbol列已移除")
        else:
            print("警告: 数据复制不完整,保留两个表以供检查")

    except Exception as e:
        print(f"重建表时出错: {str(e)}")

# 执行表重建
recreate_table_without_symbol()

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""为表设置symbol_int为索引"""
def add_index_to_stock_data_with_int():
    """为stock_data_with_int表添加symbol_int列的索引"""
    try:
        # 检查表是否存在
        table_exists = clickhouse_client.execute("EXISTS TABLE stock_data_with_int")[0][0]
        if not table_exists:
            print("表stock_data_with_int不存在")
            return

        # 检查symbol_int列是否存在
        columns = clickhouse_client.execute("DESCRIBE TABLE stock_data_with_int")
        has_symbol_int = False
        for col in columns:
            if col[0] == 'symbol_int':
                has_symbol_int = True
                print(f"找到symbol_int列,类型为: {col[1]}")
                break

        if not has_symbol_int:
            print("表中没有symbol_int列,无法创建索引")
            return

        # 添加索引
        print("为stock_data_with_int表添加symbol_int列的索引...")
        clickhouse_client.execute("""
        ALTER TABLE stock_data_with_int
        ADD INDEX idx_symbol_int (symbol_int) TYPE minmax GRANULARITY 1
        """)

        # 验证索引是否创建成功
        indexes = clickhouse_client.execute("""
        SELECT name, expr
        FROM system.data_skipping_indices
        WHERE table = 'stock_data_with_int' AND database = currentDatabase()
        """)

        if indexes:
            print("\n索引创建成功:")
            for idx in indexes:
                print(f"索引名称: {idx[0]}, 表达式: {idx[1]}")
        else:
            print("\n索引创建失败")

    except Exception as e:
        print(f"添加索引时出错: {str(e)}")

# 执行添加索引
add_index_to_stock_data_with_int()

"表创建成功,"我看着屏幕上的输出说道,"现在我们有了两个结构相似但股票代码编码方式不同的表。"

007 提醒道:"在进行性能测试前,我们需要确保两个表在结构、数据量和索引等方面是一致的,这样测试结果才有意义。"

"你说得对,"我点点头,"让我们进行全面的表结构对比,确保测试的公平性。"

2.6. 检查两张表格,确保测试前的数据是严谨公平的

"科学实验的关键在于控制变量,"007 解释道,"我们需要确保两个表除了股票代码的编码方式外,其他所有方面都相同。"

我认真地检查着两个表的结构:"是的,只有在确保其他条件完全相同的情况下,我们才能得出关于编码方式影响的可靠结论。"

"我设计了一系列检查函数,"007 说,"它们会全面比较两个表的结构、索引、分区方式和数据量,确保测试的严谨性。"

2.6.1. 表结构对比
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def compare_table_structures():
    """对比两个表的结构"""
    try:
        # 获取两个表的结构
        stock_data_columns = clickhouse_client.execute("DESCRIBE TABLE stock_data")
        stock_data_with_int_columns = clickhouse_client.execute("DESCRIBE TABLE stock_data_with_int")

        # 打印两个表的列信息
        print("stock_data表结构:")
        for col in stock_data_columns:
            print(f"列名: {col[0]}, 类型: {col[1]}, 默认表达式: {col[2] if len(col) > 2 else 'None'}")

        print("\nstock_data_with_int表结构:")
        for col in stock_data_with_int_columns:
            print(f"列名: {col[0]}, 类型: {col[1]}, 默认表达式: {col[2] if len(col) > 2 else 'None'}")

        # 检查列数是否相同(不考虑symbol/symbol_int)
        stock_data_col_count = len([c for c in stock_data_columns if c[0] != 'symbol'])
        stock_data_with_int_col_count = len([c for c in stock_data_with_int_columns if c[0] != 'symbol_int'])

        if stock_data_col_count != stock_data_with_int_col_count:
            print(f"\n警告: 两个表的列数不同 (不考虑symbol/symbol_int): stock_data有{stock_data_col_count}列,stock_data_with_int有{stock_data_with_int_col_count}列")
        else:
            print(f"\n两个表的列数相同 (不考虑symbol/symbol_int): 各有{stock_data_col_count}列")

        return stock_data_columns, stock_data_with_int_columns

    except Exception as e:
        print(f"对比表结构时出错: {str(e)}")
        return None, None

# 执行表结构对比
stock_data_columns, stock_data_with_int_columns = compare_table_structures()

2.6.2. 数据量对比
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def compare_data_volume():
    """对比两个表的数据量"""
    try:
        # 获取两个表的记录数
        stock_data_count = clickhouse_client.execute("SELECT count() FROM stock_data")[0][0]
        stock_data_with_int_count = clickhouse_client.execute("SELECT count() FROM stock_data_with_int")[0][0]

        print(f"stock_data表记录数: {stock_data_count:,}")
        print(f"stock_data_with_int表记录数: {stock_data_with_int_count:,}")

        if stock_data_count != stock_data_with_int_count:
            print(f"警告: 两个表的记录数不同,相差{abs(stock_data_count - stock_data_with_int_count):,}条记录")
        else:
            print("两个表的记录数相同")

        # 检查不同股票代码的数量
        stock_data_symbols = clickhouse_client.execute("SELECT count(DISTINCT symbol) FROM stock_data")[0][0]
        stock_data_with_int_symbols = clickhouse_client.execute("SELECT count(DISTINCT symbol_int) FROM stock_data_with_int")[0][0]

        print(f"stock_data表不同股票数: {stock_data_symbols:,}")
        print(f"stock_data_with_int表不同股票数: {stock_data_with_int_symbols:,}")

        if stock_data_symbols != stock_data_with_int_symbols:
            print(f"警告: 两个表的不同股票数不同,相差{abs(stock_data_symbols - stock_data_with_int_symbols):,}个股票")
        else:
            print("两个表的不同股票数相同")

    except Exception as e:
        print(f"对比数据量时出错: {str(e)}")

# 执行数据量对比
compare_data_volume()

2.6.3. 索引对比
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def compare_indexes_fixed():
    try:
        # 首先检查system.data_skipping_indices表的结构
        columns = clickhouse_client.execute("""
        DESCRIBE TABLE system.data_skipping_indices
        """)

        print("system.data_skipping_indices表的列:")
        column_names = [col[0] for col in columns]
        for col in columns:
            print(f"列名: {col[0]}, 类型: {col[1]}")

        # 根据实际列名调整查询
        # 常见的列名可能是expression、expr或其他
        expression_column = None
        for possible_name in ['expression', 'expr', 'definition', 'data_type']:
            if possible_name in column_names:
                expression_column = possible_name
                break

        if not expression_column:
            print("\n无法找到表达式相关的列,将只显示索引名称")

            # 获取两个表的索引,只查询名称
            stock_data_indexes = clickhouse_client.execute(f"""
            SELECT name
            FROM system.data_skipping_indices
            WHERE table = 'stock_data' AND database = currentDatabase()
            """)

            stock_data_with_int_indexes = clickhouse_client.execute(f"""
            SELECT name
            FROM system.data_skipping_indices
            WHERE table = 'stock_data_with_int' AND database = currentDatabase()
            """)

            print("\nstock_data表索引:")
            for idx in stock_data_indexes:
                print(f"索引名称: {idx[0]}")

            print("\nstock_data_with_int表索引:")
            for idx in stock_data_with_int_indexes:
                print(f"索引名称: {idx[0]}")
        else:
            # 使用找到的表达式列名
            print(f"\n使用列 '{expression_column}' 作为表达式列")

            # 获取两个表的索引
            stock_data_indexes = clickhouse_client.execute(f"""
            SELECT name, {expression_column}
            FROM system.data_skipping_indices
            WHERE table = 'stock_data' AND database = currentDatabase()
            """)

            stock_data_with_int_indexes = clickhouse_client.execute(f"""
            SELECT name, {expression_column}
            FROM system.data_skipping_indices
            WHERE table = 'stock_data_with_int' AND database = currentDatabase()
            """)

            print("\nstock_data表索引:")
            for idx in stock_data_indexes:
                print(f"索引名称: {idx[0]}, 表达式: {idx[1]}")

            print("\nstock_data_with_int表索引:")
            for idx in stock_data_with_int_indexes:
                print(f"索引名称: {idx[0]}, 表达式: {idx[1]}")

        # 检查索引数量
        if len(stock_data_indexes) != len(stock_data_with_int_indexes):
            print(f"\n警告: 两个表的索引数量不同: stock_data有{len(stock_data_indexes)}个索引,stock_data_with_int有{len(stock_data_with_int_indexes)}个索引")
        else:
            print(f"\n两个表的索引数量相同: 各有{len(stock_data_indexes)}个索引")

    except Exception as e:
        print(f"对比索引时出错: {str(e)}")

# 执行修复后的索引对比
compare_indexes_fixed()

2.6.4. 排序键对比
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def compare_sort_keys():
    """对比两个表的排序键"""
    try:
        # 获取两个表的排序键
        stock_data_sort_keys = clickhouse_client.execute("""
        SELECT name, type
        FROM system.columns
        WHERE table = 'stock_data' AND database = currentDatabase() AND is_in_sorting_key = 1
        ORDER BY name
        """)

        stock_data_with_int_sort_keys = clickhouse_client.execute("""
        SELECT name, type
        FROM system.columns
        WHERE table = 'stock_data_with_int' AND database = currentDatabase() AND is_in_sorting_key = 1
        ORDER BY name
        """)

        print("stock_data表排序键:")
        for key in stock_data_sort_keys:
            print(f"列名: {key[0]}, 类型: {key[1]}")

        print("\nstock_data_with_int表排序键:")
        for key in stock_data_with_int_sort_keys:
            print(f"列名: {key[0]}, 类型: {key[1]}")

        # 检查排序键数量
        if len(stock_data_sort_keys) != len(stock_data_with_int_sort_keys):
            print(f"\n警告: 两个表的排序键数量不同: stock_data有{len(stock_data_sort_keys)}个排序键,stock_data_with_int有{len(stock_data_with_int_sort_keys)}个排序键")
        else:
            print(f"\n两个表的排序键数量相同: 各有{len(stock_data_sort_keys)}个排序键")

    except Exception as e:
        print(f"对比排序键时出错: {str(e)}")

# 执行排序键对比
compare_sort_keys()

2.6.5. 分区对比
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def compare_partitions_fixed():
    """对比两个表的分区(修复版)"""
    try:
        # 从system.parts表获取分区信息
        stock_data_partitions = clickhouse_client.execute("""
        SELECT partition, sum(rows) as row_count
        FROM system.parts
        WHERE table = 'stock_data' AND database = currentDatabase() AND active = 1
        GROUP BY partition
        ORDER BY partition
        """)

        stock_data_with_int_partitions = clickhouse_client.execute("""
        SELECT partition, sum(rows) as row_count
        FROM system.parts
        WHERE table = 'stock_data_with_int' AND database = currentDatabase() AND active = 1
        GROUP BY partition
        ORDER BY partition
        """)

        print("stock_data表分区:")
        for part in stock_data_partitions:
            print(f"分区: {part[0]}, 记录数: {part[1]:,}")

        print("\nstock_data_with_int表分区:")
        for part in stock_data_with_int_partitions:
            print(f"分区: {part[0]}, 记录数: {part[1]:,}")

        # 检查分区数量
        if len(stock_data_partitions) != len(stock_data_with_int_partitions):
            print(f"\n警告: 两个表的分区数量不同: stock_data有{len(stock_data_partitions)}个分区,stock_data_with_int有{len(stock_data_with_int_partitions)}个分区")
        else:
            print(f"\n两个表的分区数量相同: 各有{len(stock_data_partitions)}个分区")

        # 检查分区键
        try:
            # 获取表的分区键
            stock_data_partition_key = clickhouse_client.execute("""
            SELECT partition_key
            FROM system.tables
            WHERE name = 'stock_data' AND database = currentDatabase()
            """)[0][0]

            stock_data_with_int_partition_key = clickhouse_client.execute("""
            SELECT partition_key
            FROM system.tables
            WHERE name = 'stock_data_with_int' AND database = currentDatabase()
            """)[0][0]

            print(f"\nstock_data表分区键: {stock_data_partition_key}")
            print(f"stock_data_with_int表分区键: {stock_data_with_int_partition_key}")

            if stock_data_partition_key != stock_data_with_int_partition_key:
                print(f"\n警告: 两个表的分区键不同")
            else:
                print(f"\n两个表的分区键相同: {stock_data_partition_key}")
        except Exception as e:
            print(f"\n获取分区键信息时出错: {str(e)}")
            print("尝试使用SHOW CREATE TABLE获取表定义...")

            try:
                stock_data_create = clickhouse_client.execute("SHOW CREATE TABLE stock_data")[0][0]
                stock_data_with_int_create = clickhouse_client.execute("SHOW CREATE TABLE stock_data_with_int")[0][0]

                # 提取PARTITION BY子句
                import re
                stock_data_partition_match = re.search(r'PARTITION BY\s+([^\n]+)', stock_data_create)
                stock_data_with_int_partition_match = re.search(r'PARTITION BY\s+([^\n]+)', stock_data_with_int_create)

                if stock_data_partition_match and stock_data_with_int_partition_match:
                    stock_data_partition = stock_data_partition_match.group(1)
                    stock_data_with_int_partition = stock_data_with_int_partition_match.group(1)

                    print(f"\nstock_data表分区表达式: {stock_data_partition}")
                    print(f"stock_data_with_int表分区表达式: {stock_data_with_int_partition}")

                    if stock_data_partition != stock_data_with_int_partition:
                        print(f"\n警告: 两个表的分区表达式不同")
                    else:
                        print(f"\n两个表的分区表达式相同: {stock_data_partition}")
                else:
                    print("\n无法从表定义中提取分区表达式")
            except Exception as e:
                print(f"\n获取表定义时出错: {str(e)}")

    except Exception as e:
        print(f"对比分区时出错: {str(e)}")

# 执行修复后的分区对比
compare_partitions_fixed()

2.6.6. 表引擎对比
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def compare_engines():
    """对比两个表的引擎"""
    try:
        # 获取两个表的引擎
        stock_data_engine = clickhouse_client.execute("""
        SELECT engine
        FROM system.tables
        WHERE name = 'stock_data' AND database = currentDatabase()
        """)[0][0]

        stock_data_with_int_engine = clickhouse_client.execute("""
        SELECT engine
        FROM system.tables
        WHERE name = 'stock_data_with_int' AND database = currentDatabase()
        """)[0][0]

        print(f"stock_data表引擎: {stock_data_engine}")
        print(f"stock_data_with_int表引擎: {stock_data_with_int_engine}")

        if stock_data_engine != stock_data_with_int_engine:
            print(f"警告: 两个表的引擎不同")
        else:
            print("两个表的引擎相同")

    except Exception as e:
        print(f"对比引擎时出错: {str(e)}")

# 执行引擎对比
compare_engines()

2.6.7. 综合对比
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def assess_fairness():
    """综合评估测试公平性"""
    try:
        # 收集所有对比结果
        issues = []

        # 检查表结构
        stock_data_col_count = len([c for c in stock_data_columns if c[0] != 'symbol'])
        stock_data_with_int_col_count = len([c for c in stock_data_with_int_columns if c[0] != 'symbol_int'])
        if stock_data_col_count != stock_data_with_int_col_count:
            issues.append(f"两个表的列数不同 (不考虑symbol/symbol_int): stock_data有{stock_data_col_count}列,stock_data_with_int有{stock_data_with_int_col_count}列")

        # 检查数据量
        stock_data_count = clickhouse_client.execute("SELECT count() FROM stock_data")[0][0]
        stock_data_with_int_count = clickhouse_client.execute("SELECT count() FROM stock_data_with_int")[0][0]
        if stock_data_count != stock_data_with_int_count:
            issues.append(f"两个表的记录数不同: stock_data有{stock_data_count:,}条记录,stock_data_with_int有{stock_data_with_int_count:,}条记录")

        # 检查索引
        stock_data_indexes = clickhouse_client.execute("""
        SELECT count(*)
        FROM system.data_skipping_indices
        WHERE table = 'stock_data' AND database = currentDatabase()
        """)[0][0]

        stock_data_with_int_indexes = clickhouse_client.execute("""
        SELECT count(*)
        FROM system.data_skipping_indices
        WHERE table = 'stock_data_with_int' AND database = currentDatabase()
        """)[0][0]

        if stock_data_indexes != stock_data_with_int_indexes:
            issues.append(f"两个表的索引数量不同: stock_data有{stock_data_indexes}个索引,stock_data_with_int有{stock_data_with_int_indexes}个索引")

        # 检查排序键
        stock_data_sort_keys = clickhouse_client.execute("""
        SELECT count(*)
        FROM system.columns
        WHERE table = 'stock_data' AND database = currentDatabase() AND is_in_sorting_key = 1
        """)[0][0]

        stock_data_with_int_sort_keys = clickhouse_client.execute("""
        SELECT count(*)
        FROM system.columns
        WHERE table = 'stock_data_with_int' AND database = currentDatabase() AND is_in_sorting_key = 1
        """)[0][0]

        if stock_data_sort_keys != stock_data_with_int_sort_keys:
            issues.append(f"两个表的排序键数量不同: stock_data有{stock_data_sort_keys}个排序键,stock_data_with_int有{stock_data_with_int_sort_keys}个排序键")

        # 检查引擎
        stock_data_engine = clickhouse_client.execute("""
        SELECT engine
        FROM system.tables
        WHERE name = 'stock_data' AND database = currentDatabase()
        """)[0][0]

        stock_data_with_int_engine = clickhouse_client.execute("""
        SELECT engine
        FROM system.tables
        WHERE name = 'stock_data_with_int' AND database = currentDatabase()
        """)[0][0]

        if stock_data_engine != stock_data_with_int_engine:
            issues.append(f"两个表的引擎不同: stock_data使用{stock_data_engine},stock_data_with_int使用{stock_data_with_int_engine}")

        # 给出综合评估
        print("\n综合评估:")
        if issues:
            print("发现以下可能影响测试公平性的问题:")
            for i, issue in enumerate(issues):
                print(f"{i+1}. {issue}")
            print("\n建议: 在进行性能测试前,先解决这些问题,确保两个表除了symbol/symbol_int列的类型外,其他方面尽可能相同。")
        else:
            print("两个表在结构、数据量、索引等方面基本一致,适合进行公平的性能测试。")

    except Exception as e:
        print(f"评估公平性时出错: {str(e)}")

# 执行公平性评估
assess_fairness()

综合评估: 两个表在结构、数据量、索引等方面基本一致,适合进行公平的性能测试。

"太好了!"我看着最终的评估结果说道,"两个表的结构完全一致,只有股票代码的编码方式不同,这正是我们想要的测试环境。"

007 点点头:"是的,我们已经确保了测试的公平性和科学性。现在可以进行真正的性能测试了,看看整数编码是否真的能提升查询效率,以及提升幅度有多大。"

"这个结果对于我们设计高性能量化交易系统将有重要参考价值,"我补充道,"特别是在处理大规模数据时。"

2.7. 性能测试

"现在进入实验的关键阶段,"007 说道,"我们将设计一系列查询场景,测试两种编码方式在不同条件下的性能表现。"

我思考了一下:"我们应该测试单条记录查询、范围查询、批量查询等多种场景,这样才能全面了解不同编码方式的优缺点。"

"完全同意,"007 回应,"我已经设计了一个全面的测试函数,它会在相同条件下对两种编码方式进行公平比较。"

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def test_query_performance_fixed(num_tests=50):
    # 获取随机股票代码
    test_symbols = clickhouse_client.execute(f"""
    SELECT DISTINCT symbol FROM stock_data
    WHERE is_index = 0
    ORDER BY rand()
    LIMIT {num_tests}
    """)

    string_query_times = []
    int_query_times = []

    print(f"开始测试 {len(test_symbols)} 个股票的查询性能...")

    for symbol_tuple in tqdm(test_symbols, desc="测试查询性能"):
        symbol = symbol_tuple[0]

        # 获取对应的整数编码
        code, exchange = symbol.split('.')
        if exchange.upper() == 'SH':
            prefix = '1'
        elif exchange.upper() == 'SZ':
            prefix = '2'
        else:
            continue
        symbol_int = int(prefix + code)

        # 清理缓冲
        print("清理缓冲")
        clickhouse_client.execute("""
        SYSTEM DROP MARK CACHE;
        """)
        clickhouse_client.execute("""
        SYSTEM DROP UNCOMPRESSED CACHE;
        """)

        # 测试字符串查询
        start_time = time.time()
        clickhouse_client.execute(f"""
        SELECT
            symbol, frame, open, high, low, close, vol, amount
        FROM stock_data
        WHERE symbol = '{symbol}'
        ORDER BY frame
        LIMIT 1000
        """)
        string_time = time.time() - start_time
        string_query_times.append(string_time)

        # 清理缓冲
        print("清理缓冲")
        clickhouse_client.execute("""
        SYSTEM DROP MARK CACHE;
        """)
        clickhouse_client.execute("""
        SYSTEM DROP UNCOMPRESSED CACHE;
        """)

        # 测试整数查询
        start_time = time.time()
        clickhouse_client.execute(f"""
        SELECT
            symbol_int, frame, open, high, low, close, vol, amount
        FROM stock_data_with_int
        WHERE symbol_int = {symbol_int}
        ORDER BY frame
        LIMIT 1000
        """)

        int_time = time.time() - start_time
        int_query_times.append(int_time)

    # 计算平均查询时间
    avg_string_time = sum(string_query_times) / len(string_query_times)
    avg_int_time = sum(int_query_times) / len(int_query_times)

    print(f"字符串查询平均时间: {avg_string_time:.6f} 秒")
    print(f"整数查询平均时间: {avg_int_time:.6f} 秒")
    print(f"性能提升: {(avg_string_time - avg_int_time) / avg_string_time * 100:.2f}%")

    return {
        'string_times': string_query_times,
        'int_times': int_query_times,
        'avg_string_time': avg_string_time,
        'avg_int_time': avg_int_time,
        'improvement': (avg_string_time - avg_int_time) / avg_string_time * 100
    }

"""绘制性能对比图表"""
def plot_performance_comparison(results):
    plt.figure(figsize=(15, 12))

    # 1. 单条记录查询性能对比
    if 'single' in results:
        plt.subplot(2, 2, 1)
        data = [results['single']['string_times'], results['single']['int_times']]
        plt.boxplot(data, labels=['字符串查询', '整数查询'])
        plt.title('单条记录查询性能对比')
        plt.ylabel('查询时间 (秒)')
        plt.grid(True, linestyle='--', alpha=0.7)

    # 2. 日期范围查询性能对比
    if 'range' in results:
        plt.subplot(2, 2, 2)
        data = [results['range']['string_times'], results['range']['int_times']]
        plt.boxplot(data, labels=['字符串查询', '整数查询'])
        plt.title('日期范围查询性能对比')
        plt.ylabel('查询时间 (秒)')
        plt.grid(True, linestyle='--', alpha=0.7)

    # 3. 批量查询性能对比
    if 'batch' in results:
        plt.subplot(2, 2, 3)
        batch_sizes = [r['batch_size'] for r in results['batch']]
        string_times = [r['string_time'] for r in results['batch']]
        int_times = [r['int_time'] for r in results['batch']]

        x = np.arange(len(batch_sizes))
        width = 0.35

        plt.bar(x - width/2, string_times, width, label='字符串查询')
        plt.bar(x + width/2, int_times, width, label='整数查询')

        plt.xlabel('批量大小')
        plt.ylabel('查询时间 (秒)')
        plt.title('批量查询性能对比')
        plt.xticks(x, batch_sizes)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

    # 4. 性能提升百分比
    plt.subplot(2, 2, 4)
    categories = []
    improvements = []

    if 'single' in results:
        categories.append('单条查询')
        improvements.append(results['single']['improvement'])

    if 'range' in results:
        categories.append('范围查询')
        improvements.append(results['range']['improvement'])

    if 'batch' in results:
        categories.append('批量查询')
        # 使用平均提升
        avg_batch_improvement = sum(r['improvement'] for r in results['batch']) / len(results['batch'])
        improvements.append(avg_batch_improvement)

    if 'aggregate' in results:
        categories.append('聚合查询')
        improvements.append(results['aggregate']['improvement'])

    plt.bar(categories, improvements)
    plt.title('整数编码性能提升百分比')
    plt.ylabel('性能提升 (%)')
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig('symbol_encoding_performance_comparison.png')
    plt.close()

    print("\n性能对比图已保存为 symbol_encoding_performance_comparison.png")

print("开始symbol编码性能测试...")

# 1. 清理缓冲
print("清理缓冲")
clickhouse_client.execute("""
SYSTEM DROP MARK CACHE;
""")
clickhouse_client.execute("""
SYSTEM DROP UNCOMPRESSED CACHE;
""")

# 2. 执行测试
results = test_query_performance_fixed(num_tests=50)

"测试结果出来了!"我兴奋地说道,看着屏幕上显示的性能数据。

007 分析道:"数据显示整数编码确实带来了显著的性能提升。字符串查询平均时间为 0.331844 秒,而整数查询平均时间为 0.259308 秒,性能提升了 21.86%。"

"这是一个相当可观的提升,"我思考着这个结果的实际意义,"在高频交易或大规模数据分析场景下,这种优化可能会带来显著的系统性能提升。"

"是的,"007 补充道,"而且这只是单次查询的结果。在实际系统中,可能会有数百万次查询,累积起来的性能差异将更加明显。"

2.8. 性能测试报告

"让我们将测试结果可视化,以便更直观地理解不同编码方式的性能差异,"007 建议道。

我点点头:"好主意,图表能帮助我们更清晰地看到性能差异的模式和趋势。"

"我已经设计了一个全面的可视化方案,"007 说,"包括条形图、箱线图和饼图等多种图表类型,从不同角度展示性能测试结果。"

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
# 查询性能可视化
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
from tqdm import tqdm
import seaborn as sns
from matplotlib.ticker import PercentFormatter

from matplotlib import font_manager
font_path = '/Volumes/share/data/WBQ/note/4_性能测试/SimHei.ttf'  # 替换为SimHei.ttf的实际路径
font_manager.fontManager.addfont(font_path)
plt.rcParams['font.family'] = 'SimHei'

def visualize_performance_results(results):
    """将性能测试结果可视化为图表"""
    if not results:
        print("没有测试结果可供可视化")
        return

    # 设置图表风格
    sns.set(style="whitegrid")
    plt.figure(figsize=(20, 15))

    # 1. 查询时间对比 - 条形图
    plt.subplot(2, 2, 1)
    test_names = [r['name'] for r in results]
    str_times = [r['avg_str_time'] for r in results]
    int_times = [r['avg_int_time'] for r in results]

    x = np.arange(len(test_names))
    width = 0.35

    plt.bar(x - width/2, str_times, width, label='字符串查询 (symbol)', color='#3498db',
            yerr=[r['std_str_time'] for r in results], capsize=5)
    plt.bar(x + width/2, int_times, width, label='整数查询 (symbol_int)', color='#e74c3c',
            yerr=[r['std_int_time'] for r in results], capsize=5)

    plt.xlabel('查询类型', fontsize=12)
    plt.ylabel('平均查询时间 (秒)', fontsize=12)
    plt.title('不同查询类型的平执行时间对比', fontsize=14, fontweight='bold')
    plt.xticks(x, [name if len(name) < 15 else name[:12] + '...' for name in test_names], rotation=45, ha='right')
    plt.legend(fontsize=10)
    plt.grid(True, linestyle='--', alpha=0.7)

    # 2. 性能提升百分比 - 水平条形图
    plt.subplot(2, 2, 2)
    improvements = [r['improvement'] for r in results]
    colors = ['#2ecc71' if imp > 0 else '#e74c3c' for imp in improvements]

    y_pos = np.arange(len(test_names))
    plt.barh(y_pos, improvements, color=colors)
    plt.axvline(x=0, color='black', linestyle='-', alpha=0.7)
    plt.yticks(y_pos, [name if len(name) < 15 else name[:12] + '...' for name in test_names])
    plt.xlabel('性能提升 (%)', fontsize=12)
    plt.title('整数编码相对于字符串的性能提升', fontsize=14, fontweight='bold')
    plt.grid(True, linestyle='--', alpha=0.7)

    # 添加数值标签
    for i, v in enumerate(improvements):
        plt.text(v + (1 if v >= 0 else -1), i, f"{v:.1f}%",
                 va='center', fontweight='bold', color='black')

    # 3. 查询时间分布 - 箱线图
    plt.subplot(2, 2, 3)

    # 准备数据
    data_to_plot = []
    labels = []

    for r in results:
        data_to_plot.append(r['str_times'])
        data_to_plot.append(r['int_times'])
        labels.append(f"{r['name']} (str)")
        labels.append(f"{r['name']} (int)")

    # 绘制箱线图
    box = plt.boxplot(data_to_plot, patch_artist=True, labels=labels)

    # 设置颜色
    colors = []
    for i in range(len(data_to_plot)):
        if i % 2 == 0:  # 字符串查询
            colors.append('#3498db')
        else:  # 整数查询
            colors.append('#e74c3c')

    for patch, color in zip(box['boxes'], colors):
        patch.set_facecolor(color)

    plt.xticks(rotation=90)
    plt.ylabel('查询时间 (秒)', fontsize=12)
    plt.title('查询时间分布', fontsize=14, fontweight='bold')
    plt.grid(True, linestyle='--', alpha=0.7)

    # 4. 平均性能对比 - 饼图
    plt.subplot(2, 2, 4)

    # 计算平均性能提升
    avg_improvement = sum(improvements) / len(improvements)

    # 创建饼图数据
    if avg_improvement > 0:
        # 整数查询更快
        labels = ['整数查询更快', '字符串查询']
        sizes = [avg_improvement, 100 - avg_improvement]
        colors = ['#2ecc71', '#3498db']
        title = f'平均而言,整数查询比字符串查询快 {avg_improvement:.1f}%'
    else:
        # 字符串查询更快
        labels = ['字符串查询更快', '整数查询']
        sizes = [-avg_improvement, 100 + avg_improvement]
        colors = ['#3498db', '#e74c3c']
        title = f'平均而言,字符串查询比整数查询快 {-avg_improvement:.1f}%'

    plt.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%',
            startangle=90, explode=(0.1, 0), shadow=True)
    plt.axis('equal')
    plt.title(title, fontsize=14, fontweight='bold')

    # 调整布局并保存
    plt.tight_layout()
    plt.savefig('query_performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("性能测试结果图表已保存为 'query_performance_comparison.png'")

    # 创建详细的性能报告
    create_performance_report(results)

def create_performance_report(results):
    """创建详细的性能测试报告"""
    # 创建DataFrame
    report_data = []

    for r in results:
        report_data.append({
            '查询类型': r['name'],
            '字符串查询平均时间(秒)': r['avg_str_time'],
            '整数查询平均时间(秒)': r['avg_int_time'],
            '字符串查询标准差': r['std_str_time'],
            '整数查询标准差': r['std_int_time'],
            '性能提升(%)': r['improvement'],
            '字符串查询最小时间': min(r['str_times']),
            '字符串查询最大时间': max(r['str_times']),
            '整数查询最小时间': min(r['int_times']),
            '整数查询最大时间': max(r['int_times']),
        })

    df = pd.DataFrame(report_data)

    # 计算总体统计
    avg_str_time = df['字符串查询平均时间(秒)'].mean()
    avg_int_time = df['整数查询平均时间(秒)'].mean()
    avg_improvement = df['性能提升(%)'].mean()

    # 打印报告
    print("\n===== 性能测试详细报告 =====")
    print(f"测试场景数量: {len(results)}")
    print(f"总体平均字符串查询时间: {avg_str_time:.6f} 秒")
    print(f"总体平均整数查询时间: {avg_int_time:.6f} 秒")
    print(f"总体平均性能提升: {avg_improvement:.2f}%")

    # 打印每个场景的详细信息
    print("\n各场景详细数据:")
    print(df.to_string(index=False))

    # 保存报告到CSV
    df.to_csv('performance_test_report.csv', index=False)
    print("\n详细报告已保存到 'performance_test_report.csv'")

def run_performance_test(test_cases, num_iterations=5):
    """
    运行性能测试

    Args:
        test_cases: 测试用例列表,每个测试用例是一个字典,包含name, str_query和int_query
        num_iterations: 每个测试用例重复执行的次数

    Returns:
        测试结果列表
    """
    results = []

    for test_idx, test_case in enumerate(test_cases):
        str_query = test_case['str_query']
        int_query = test_case['int_query']
        test_name = test_case['name']

        print(f"\n测试 {test_idx+1}/{len(test_cases)}: {test_name}")

        str_times = []
        int_times = []

        for i in range(num_iterations):
            try:
                # 清除缓存
                if i == 0:  # 只在第一次迭代时清除缓存
                    try:
                        clickhouse_client.execute("SYSTEM DROP MARK CACHE")
                        clickhouse_client.execute("SYSTEM DROP UNCOMPRESSED CACHE")
                    except:
                        pass  # 忽略清除缓存的错误

                # 测试字符串查询
                start_time = time.time()
                clickhouse_client.execute(str_query)
                str_time = time.time() - start_time
                str_times.append(str_time)

                # 测试整数查询
                start_time = time.time()
                clickhouse_client.execute(int_query)
                int_time = time.time() - start_time
                int_times.append(int_time)

                print(f"  迭代 {i+1}/{num_iterations}: 字符串 {str_time:.6f}秒, 整数 {int_time:.6f}秒")
            except Exception as e:
                print(f"  迭代 {i+1}/{num_iterations} 出错: {str(e)}")
                print(f"  字符串查询: {str_query}")
                print(f"  整数查询: {int_query}")
                continue

        if not str_times or not int_times:
            print(f"  测试 {test_name} 失败,跳过")
            continue

        # 计算平均时间和标准差
        avg_str_time = sum(str_times) / len(str_times)
        avg_int_time = sum(int_times) / len(int_times)
        std_str_time = (sum((t - avg_str_time) ** 2 for t in str_times) / len(str_times)) ** 0.5
        std_int_time = (sum((t - avg_int_time) ** 2 for t in int_times) / len(int_times)) ** 0.5

        improvement = (avg_str_time - avg_int_time) / avg_str_time * 100

        print(f"  平均: 字符串 {avg_str_time:.6f}±{std_str_time:.6f}秒, 整数 {avg_int_time:.6f}±{std_int_time:.6f}秒")
        print(f"  性能提升: {improvement:.2f}%")

        results.append({
            'name': test_name,
            'str_times': str_times,
            'int_times': int_times,
            'avg_str_time': avg_str_time,
            'avg_int_time': avg_int_time,
            'std_str_time': std_str_time,
            'std_int_time': std_int_time,
            'improvement': improvement
        })

    return results

# 定义多种查询场景的测试用例
def create_test_cases():
    """创建多种查询场景的测试用例"""
    # 获取一些随机的股票代码用于测试
    symbols = clickhouse_client.execute("""
    SELECT DISTINCT symbol FROM stock_data
    ORDER BY rand()
    LIMIT 20
    """)

    symbol_ints = []
    for symbol_tuple in symbols:
        symbol = symbol_tuple[0]
        code, exchange = symbol.split('.')
        if exchange.upper() == 'SH':
            prefix = '1'
        elif exchange.upper() == 'SZ':
            prefix = '2'
        else:
            continue
        symbol_ints.append(int(prefix + code))

    # 确保我们有足够的股票代码
    if len(symbols) < 10 or len(symbol_ints) < 10:
        print("警告: 没有足够的股票代码用于测试")
        return []

    # 创建测试用例
    test_cases = [
        # 1. 单条记录精确查询
        {
            'name': '单条记录精确查询',
            'str_query': f"SELECT * FROM stock_data WHERE symbol = '{symbols[0][0]}' LIMIT 1000",
            'int_query': f"SELECT * FROM stock_data_with_int WHERE symbol_int = {symbol_ints[0]} LIMIT 1000"
        },

        # 2. 日期范围查询
        {
            'name': '日期范围查询',
            'str_query': f"SELECT * FROM stock_data WHERE symbol = '{symbols[1][0]}' AND frame BETWEEN '2016-01-01' AND '2016-12-31'",
            'int_query': f"SELECT * FROM stock_data_with_int WHERE symbol_int = {symbol_ints[1]} AND frame BETWEEN '2016-01-01' AND '2016-12-31'"
        },

        # 3. 批量查询 (IN条件)
        {
            'name': '批量查询 (5个股票)',
            'str_query': f"SELECT * FROM stock_data WHERE symbol IN ('{symbols[0][0]}', '{symbols[1][0]}', '{symbols[2][0]}', '{symbols[3][0]}', '{symbols[4][0]}') LIMIT 1000",
            'int_query': f"SELECT * FROM stock_data_with_int WHERE symbol_int IN ({symbol_ints[0]}, {symbol_ints[1]}, {symbol_ints[2]}, {symbol_ints[3]}, {symbol_ints[4]}) LIMIT 1000"
        },

        # 4. 聚合查询 (AVG)
        {
            'name': '聚合查询 (AVG)',
            'str_query': f"SELECT AVG(close) FROM stock_data WHERE symbol = '{symbols[2][0]}' GROUP BY toYYYYMM(frame)",
            'int_query': f"SELECT AVG(close) FROM stock_data_with_int WHERE symbol_int = {symbol_ints[2]} GROUP BY toYYYYMM(frame)"
        },

        # 5. 排序查询
        {
            'name': '排序查询',
            'str_query': f"SELECT * FROM stock_data WHERE symbol = '{symbols[3][0]}' ORDER BY frame DESC LIMIT 1000",
            'int_query': f"SELECT * FROM stock_data_with_int WHERE symbol_int = {symbol_ints[3]} ORDER BY frame DESC LIMIT 1000"
        },

        # 6. 复杂条件查询
        {
            'name': '复杂条件查询',
            'str_query': f"SELECT * FROM stock_data WHERE symbol = '{symbols[4][0]}' AND close > open AND vol > 1000000 LIMIT 1000",
            'int_query': f"SELECT * FROM stock_data_with_int WHERE symbol_int = {symbol_ints[4]} AND close > open AND vol > 1000000 LIMIT 1000"
        },

        # 7. JOIN查询
        {
            'name': 'JOIN查询',
            'str_query': f"""
            SELECT a.symbol, a.frame, a.close, b.close as prev_close
            FROM stock_data a
            LEFT JOIN stock_data b ON a.symbol = b.symbol AND b.frame = addDays(a.frame, -1)
            WHERE a.symbol = '{symbols[5][0]}'
            LIMIT 1000
            """,
            'int_query': f"""
            SELECT a.symbol_int, a.frame, a.close, b.close as prev_close
            FROM stock_data_with_int a
            LEFT JOIN stock_data_with_int b ON a.symbol_int = b.symbol_int AND b.frame = addDays(a.frame, -1)
            WHERE a.symbol_int = {symbol_ints[5]}
            LIMIT 1000
            """
        },

        # 8. 大批量查询 (更多股票)
        {
            'name': '大批量查询 (10个股票)',
            'str_query': "SELECT * FROM stock_data WHERE symbol IN (" + ", ".join([f"'{s[0]}'" for s in symbols[:10]]) + ") LIMIT 5000",
            'int_query': "SELECT * FROM stock_data_with_int WHERE symbol_int IN (" + ", ".join([str(s) for s in symbol_ints[:10]]) + ") LIMIT 5000"
        },

        # 9. 聚合查询 (COUNT)
        {
            'name': '聚合查询 (COUNT)',
            'str_query': f"SELECT COUNT(*) FROM stock_data WHERE symbol = '{symbols[6][0]}' GROUP BY toYear(frame)",
            'int_query': f"SELECT COUNT(*) FROM stock_data_with_int WHERE symbol_int = {symbol_ints[6]} GROUP BY toYear(frame)"
        },

        # 10. 复杂聚合查询
        {
            'name': '复杂聚合查询',
            'str_query': f"""
            SELECT
                toYear(frame) AS year,
                AVG(close) AS avg_close,
                MAX(high) AS max_high,
                MIN(low) AS min_low,
                SUM(vol) AS total_vol
            FROM stock_data
            WHERE symbol = '{symbols[7][0]}'
            GROUP BY year
            ORDER BY year
            """,
            'int_query': f"""
            SELECT
                toYear(frame) AS year,
                AVG(close) AS avg_close,
                MAX(high) AS max_high,
                MIN(low) AS min_low,
                SUM(vol) AS total_vol
            FROM stock_data_with_int
            WHERE symbol_int = {symbol_ints[7]}
            GROUP BY year
            ORDER BY year
            """
        }
    ]

    return test_cases


# 创建测试用例
test_cases = create_test_cases()

if not test_cases:
    print("无法创建测试用例,请检查数据库连接和表结构")
    exit()

# 运行性能测试
print(f"开始运行 {len(test_cases)} 个测试用例,每个用例重复 5 次...")
results = run_performance_test(test_cases, num_iterations=5)

if not results:
    print("测试失败,没有结果可供分析")
    exit()

# 可视化结果
visualize_performance_results(results)

3. 结论与建议

"测试结果非常明确,"007 总结道,"在所有测试场景中,整数编码的股票代码在查询性能上普遍优于字符串编码,平均提升约 25%。"

我认真分析着图表:"是的,而且在某些特定场景下,如单条查询、聚合查询、复杂聚合查询和排序查询,性能提升甚至超过了 30%。"

基于我们的实验结果,可以得出以下结论:

  1. 整数编码显著提升查询性能:在所有测试场景中,使用整数编码的 symbol_int 字段查询性能均优于字符串类型的 symbol 字段,平均提升 24.70%。

  2. 不同查询类型的性能差异:整数编码在复杂查询(如 JOIN 和聚合查询)中的性能优势更为明显,这可能是因为这些操作涉及更多的比较和索引查找操作。

  3. 稳定性提升:整数查询的标准差普遍小于字符串查询,表明整数编码不仅提高了平均性能,还提高了查询性能的稳定性和可预测性。

  4. 存储效率:整数类型通常比字符串类型占用更少的存储空间,特别是在大规模数据集中,这可能带来额外的存储和内存使用效率提升。

3.1 实际应用建议

基于实验结果,我们对量化交易系统的开发提出以下建议:

  1. 采用整数编码:在设计数据库架构时,建议将股票代码编码为整数类型存储,特别是在性能关键的应用中。

  2. 保留原始代码映射:虽然使用整数编码,但应保留原始代码与整数编码之间的映射关系,以便于调试和数据验证。

  3. 考虑编码规则的扩展性:设计编码规则时,应考虑未来可能需要支持更多交易所或特殊类型证券的情况。

"这个实验为我们的系统设计提供了重要的指导,"我总结道,"通过简单的编码转换,我们可以显著提升系统性能,这在高频交易和大规模数据分析场景下尤为重要。"

007 点点头:"是的,而且这种优化几乎没有额外成本,只需在数据入库时进行一次转换即可。在追求毫秒级甚至微秒级性能的量化交易系统中,这种优化可能带来关键的竞争优势。"

"下一步,我们可以将这种编码策略应用到实际的交易系统中,并在更多样化的查询场景下进行测试,"我补充道,"同时也可以探索其他可能的性能优化方向。"

"完全同意,"007 回应,"数据结构和存储优化是构建高性能量化交易系统的基础,我们的实验只是迈出了第一步。"

Augment随手记

Duckdb是一个年轻而迷人的数据库。它的备份可以简单到通过拷贝文件来完成 -- 但前提是,没有其它进程独占她。如果你的服务使用了duckdb,而且还在不停地读写她,你该怎么给她一个备份呢?

我们把这个问题抛给了Augment.

To Augment:

增加一个后台任务,每周备份一次数据库。

注意我们项目中,已经有过备份的逻辑(cli.py)但是,它是基于服务停止后,拷贝文件来实现的。如果我们通过后台任务来实现,在不停止服务的情况下,能否通过拷贝文件来实现备份呢?

另外要注意我们已经使用了apscheduler,你应该继续使用它。

这是 Augment的分析:

我理解您的需求。您希望增加一个每周自动备份数据库的后台任务,并且希望在不停止服务的情况下进行备份。让我们先查看现有的备份逻辑和 apscheduler 的使用情况,然后设计一个解决方案。

紧接着, Augment 调用一些命令行工具,查看了cli.py,scheduler.py, app.py,这样就对当前系统的状态了然于胸了。接下来,它开始寻找技术方案:

DuckDB 是一个嵌入式数据库,类似于 SQLite。在不停止服务的情况下备份 DuckDB 数据库,我们有几种可能的方法: 使用 DuckDB 的 EXPORT DATABASE 命令 使用 DuckDB 的 COPY 命令 使用文件系统级别的备份工具 让我们查看 DuckDB 是否支持在线备份:

它查看duckdb文档的方式是通过这个命令:

1
python -c "import duckdb; print(dir(duckdb))"

然后从一堆杂乱的输出中,找到了checkpoint方法,然后给出了实现方案:

基于以上信息,我们可以设计一个在不停止服务的情况下备份数据库的解决方案。DuckDB 支持 checkpoint 操作,这可以确保所有的更改都被写入磁盘。然后我们可以使用文件系统级别的备份来复制数据库文件。 以下是我的解决方案: 创建一个新的备份任务函数 将该任务添加到现有的 apscheduler 调度器中 设置每周执行一次 在备份过程中使用 DuckDB 的 checkpoint 功能确保数据一致性

最终的实现是:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def backup_database() -> None:
    """备份数据库

    在不停止服务的情况下备份数据库文件,使用DuckDB的checkpoint功能确保数据一致性
    """
    try:
        # 获取当前时间作为备份目录名
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")

        # 设置备份目录
        ...

        # 确保所有更改都写入磁盘
        conn = duckdb.connect(str(db_path), read_only=False)
        conn.execute("CHECKPOINT")
        conn.close()

        # 备份数据库文件
        backup_db_path = backup_dir / "provision.db"
        shutil.copy2(db_path, backup_db_path)

        # 保留最近的10个备份,删除旧的备份
        all_backups = sorted(backup_base.glob("*"), key=os.path.getmtime)
        if len(all_backups) > 10:
            for old_backup in all_backups[:-10]:
                if old_backup.is_dir():
                    shutil.rmtree(old_backup)
                    logger.info(f"Removed old backup: {old_backup}")

        logger.info(f"Database backup completed successfully: {backup_dir}")
    except Exception as e:
        logger.error(f"Failed to backup database: {e}")

Augment最后发挥了一下,只保留了最后10个备份。这相当于2个月前的备份会被删除。我并没有要求它这么做,但是,Augment的这份主动性很贴心。

致命的 ID -- DuckDB 中的 Returning 子句之谜

Duckdb是一个年轻但非常有潜力的数据库。但它也有桀骜不驯的一面:在一个普通的update语句执行时,出现了罕见的违反外键约束的问题。最终,依靠Augment这个强大的AI工具,我们找到了根本原因,并且通过坚实的实验验证了结论。

『华生,你是否曾思考过,在数据库的深处,隐藏着多少不为人知的秘密?』福尔摩斯放下手中的烟斗,凝视着窗外伦敦的雾霭。

『福尔摩斯,我承认数据库对我而言如同迷宫。』我诚实地回答,我正在记录福尔摩斯最新的冒险。

『今天早晨,一位焦虑的开发者前来求助,他遇到了一个令人费解的谜题。』福尔摩斯从桌上拿起那张写满 SQL 代码的纸条,『他的程序在执行一条看似无害的 UPDATE 语句时,突然抛出了外键约束错误。』

这是建表语句,这里有两张表,resources 和 resource_whitelist。resource_whitelist 表有一个外键引用 resources 表的 id 字段。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
CREATE SEQUENCE if not exists seq_resource_id START WITH 1 INCREMENT BY 1;
CREATE TABLE if not exists resources (
    id INTEGER PRIMARY KEY DEFAULT nextval('seq_resource_id'),
    course VARCHAR NOT NULL,
    resource VARCHAR NOT NULL,
    seq INTEGER NOT NULL,
    title VARCHAR NOT NULL,
    UNIQUE (course, rel_path)
);

CREATE SEQUENCE if not exists seq_resource_whitelist_id START WITH 1 INCREMENT BY 1;
CREATE TABLE if not exists resource_whitelist (
    id INTEGER PRIMARY KEY DEFAULT nextval('seq_resource_whitelist_id'),
    resource_id INTEGER NOT NULL,
    course VARCHAR NOT NULL,
    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (customer_id) REFERENCES customers(id),
    FOREIGN KEY (resource_id) REFERENCES resources(id),
    UNIQUE (customer_id, resource_id)
);

Info

这是修改语句:

1
2
3
4
5
UPDATE resources
        SET seq = ?, title = ?, description = ?, 
        publish_date = ?, price = ?
        WHERE id = ?
        RETURNING id

『但这有什么奇怪的呢?』我问道,『外键约束不就是为了防止数据不一致吗?』

『我最初也是这么看』。福尔摩斯轻叹一口气,继续说道,『作为一个阅案无数的高手,我几乎立刻就回答了他:这个错误是因为在更新资源时违反了外键约束。错误信息表明 resource_id: 994 仍然被其他表中的外键引用,也就是是 resource_whitelist 表。我甚至还给出了修改方案。』

『但是,这位开发者并没有满足于我的修改方案,而是对我的答案进行了质疑』。福尔摩斯说道。

『质疑我们的福尔摩斯!』我不由得提高了音量。

『不幸的是,我的朋友』,福尔摩斯皱了皱眉,『这位开发者的质疑是有道理的。我的确应该看到更仔细一些,查出背后真正的元凶,再下结论。你知道,巴斯克维尔的猎犬案之后,我一直有点没恢复过来』。

巴斯克维尔的猎犬案涉及到一个古老的家族诅咒,传说中有一只巨大的恶魔猎犬在巴斯克维尔家族的领地上出没,专门袭击家族成员。这种超自然的元素使得案件一开始就笼罩在神秘和恐怖的氛围中,让调查变得异常困难,并且一度影响到了福尔摩斯的声誉。对此我完全赞同。

『问题的关键在于,华生』,福尔摩斯轻敲桌面,『这位开发者并未尝试更改任何主键,也没有删除任何记录。他仅仅是更新了一些无关紧要的字段,比如标题或描述。』

『那么,为什么会触发外键约束错误呢?』

『正是这一点引起了我的兴趣!』福尔摩斯站起身来,开始在房间里踱步,『我们面对的是 DuckDB,一个年轻而有趣的数据库系统。错误信息中提到了'foreign key limitations',这暗示着某种不寻常的行为。』

『你有什么理论吗,福尔摩斯?』

『我设计了一个实验,华生。』福尔摩斯拿出一张写满代码的纸,三个简单的测试案例,足以揭示真相。』

我凑近看了看那些 Python 代码,『看起来很复杂。』

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import duckdb
import logging

# 设置日志
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# 创建测试数据库
conn = duckdb.connect(':memory:')

# 创建测试表结构
conn.execute(『『『
CREATE TABLE parent (
    id INTEGER PRIMARY KEY,
    name VARCHAR
);

CREATE TABLE child (
    id INTEGER PRIMARY KEY,
    parent_id INTEGER,
    data VARCHAR,
    FOREIGN KEY (parent_id) REFERENCES parent(id)
);
『『『)

# 插入测试数据
conn.execute(INSERT INTO parent VALUES (1, 'Parent 1'), (2, 'Parent 2'))
conn.execute(INSERT INTO child VALUES (101, 1, 'Child 1'), (102, 2, 'Child 2'))

# 测试 1: 正常更新非键字段
try:
    logger.info(测试 1: 更新 parent 表的非键字段)
    conn.execute(UPDATE parent SET name = 'Updated Parent 1' WHERE id = 1)
    logger.info(测试 1 成功可以更新非键字段)
except Exception as e:
    logger.error(f测试 1 失败{e})

# 测试 2: 使用 RETURNING 子句更新
try:
    logger.info(测试 2: 使用 RETURNING 子句更新)
    result = conn.execute(UPDATE parent SET name = 'Updated Again' WHERE id = 1 RETURNING id).fetchall()
    logger.info(f测试 2 成功RETURNING 子句返回{result})
except Exception as e:
    logger.error(f测试 2 失败{e})

# 测试 3: 尝试更新被引用的主键
try:
    logger.info(测试 4: 尝试更新被引用的主键)
    conn.execute(UPDATE parent SET id = 3 WHERE id = 1)
    logger.info(测试 4 成功可以更新被引用的主键)
except Exception as e:
    logger.error(f测试 4 失败{e})
    if foreign key in str(e).lower():
        logger.info(确认更新被引用的主键会触发外键约束错误)

# 显示最终数据
parent_data = conn.execute(SELECT * FROM parent).fetchall()
child_data = conn.execute(SELECT * FROM child).fetchall()
logger.info(f最终 parent 表数据{parent_data})
logger.info(f最终 child 表数据{child_data})

『表面上看是如此。但真相往往隐藏在细节之中。』福尔摩斯微笑道,『第一个测试是普通的 UPDATE 操作,没有任何特殊子句。第二个测试添加了一个 RETURNING 子句。第三个测试则直接尝试更新被引用的主键。』

『结果如何?』

『啊,华生,结果令人着迷!『福尔摩斯的眼睛闪烁着兴奋的光芒,『第一个测试完美通过,证明普通的 UPDATE 操作可以正常工作。第三个测试如预期般失败,因为它确实违反了外键约束。』

『那么第二个测试呢?』

『第二个测试失败了,『福尔摩斯停顿了一下,『但第二个测试,华生,第二个测试揭示了真相!』

『怎么说?』

『带有 RETURNING 子句的 UPDATE 操作触发了外键约束错误,尽管它只是更新了非键字段!『福尔摩斯高声宣布,『这证明 DuckDB 在处理带 RETURNING 子句的 UPDATE 操作时,采用了不同的执行路径。它很可能在内部将 UPDATE 实现为'先 DELETE 再 INSERT'的组合操作!』

『太不可思议了,福尔摩斯!』

『初看之下,这似乎是个 bug。但从更深层次看,这是 DuckDB 实现细节的一个特性。『福尔摩斯重新坐下,『当使用 RETURNING 子句时,DuckDB 需要返回受影响行的信息。为了实现这一点,它可能选择了一种不同的执行策略,这种策略会触发完整的外键约束检查。』

『等一下!』我小声地喊起来,『这里还因为 resource 的主键是自增的!所以,当删除原记录,再新增记录时,尽管记录的语义没有改变,但它们的 id 字段却意外更新了』。

『你说得很对!』福尔摩斯微笑着说。

『那么解决方案是什么?』

『简单明了,华生。『福尔摩斯微笑着说,『要么避免在有外键引用的表上使用 RETURNING 子句,要么采用两步操作:先查询,再更新。』

『福尔摩斯,你总是能找到最简单的解决方案。』

『在数据库的世界里,华生,表面上看似简单的操作背后,往往隐藏着复杂的实现细节。』福尔摩斯拿起小提琴,拉出一段欢快的旋律,正如福尔摩斯常说:排除所有不可能的情况后,剩下的,无论多么不可思议,一定就是真相。

『那么这个案例可以称为什么呢?』我问道,准备为新的笔记命名。

『就叫它'RETURNING 子句之谜'吧,华生。』福尔摩斯微笑着回答,『一个小小的 SQL 子句,揭示了数据库引擎深处的秘密。』

窗外,伦敦的雾气渐渐散去,又一个数据库之谜被成功破解。

21天驯化AI打工仔 - 如何存储10亿个Symbol?

现在,我们需要设计一种通用的数据交换格式(Standard Quotes Exchange Protocol, SQEP)。这种格式的工作原理是:由数据生产者(因为只有生产者才了解原始数据的具体格式)将数据转换为这种标准格式,然后再将其推送到Redis中供消费者使用。


前言

第一天,我们讨论了如何从Tushare获取OHLC(开盘价、最高价、最低价、收盘价)数据和调整因子(adj_factor)。当时我们存储的数据结构如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
{
    "timestamp": "时间戳",
    "ts_code": "股票代码",
    "ohlc": {
        "ts_code": "股票代码",
        "open": "开盘价",
        "high": "最高价",
        "low": "最低价",
        "close": "收盘价",
        "vol": "成交量"
    }, 
    "adj_factor": {
        "ts_code": "股票代码",
        "trade_date": "交易日期",
        "adj_factor": "复权因子"
    }
}

现在,我们需要设计一种通用的数据交换格式(Standard Quotes Exchange Protocol, SQEP)。这种格式的工作原理是:由数据生产者(因为只有生产者才了解原始数据的具体格式)将数据转换为这种标准格式,然后再将其推送到Redis中供消费者使用。

1. SQEP-BAR-DAY 日线场景下的数据交换格式

SQEP-BAR-DAY 是标准行情交换协议(Standard Quotes Exchange Protocol)中用于日线数据的格式规范。该格式设计用于在不同系统组件间高效传输和处理股票日线数据,确保数据的一致性和互操作性。

1.1. 字段定义

SQEP-BAR-DAY 包含以下标准字段:

字段名 数据类型 说明
symbol str/int 股票代码。推荐使用整型编码以提高性能
frame datetime.date 交易日期
open float64 开盘价
high float64 最高价
low float64 最低价
close float64 收盘价
vol float64 成交量
amount float64 成交额
adjust float64 复权因子
st bool 是否为ST股票(可选扩展字段)
buy_limit float64 涨停价(可选扩展字段)
sell_limit float64 跌停价(可选扩展字段)

1.2. 编码约定

  1. 字段命名:使用frame而非datetimestamp,因为后两者在某些数据库中不适合作为列名。

  2. 股票代码编码:为提高查询性能,推荐将字符串格式的股票代码转换为整型:

  3. 上海证券交易所:000001.SH → 1000001
  4. 深圳证券交易所:000001.SZ → 2000001

这种编码方式最多可支持9个不同交易所(数字1-9,0不能用作前缀)。

1.3. 使用场景

SQEP-BAR-DAY 主要应用于:

  1. 数据生产者(如Tushare、QMT等数据源)将原始数据转换为标准格式
  2. 通过Redis等中间件在系统组件间传输
  3. 数据消费者(如分析引擎、回测系统)处理标准格式数据
  4. 存储到ClickHouse等时序数据库中进行长期保存

1.4. 007 的代码实现

既然规定好了日线场景下的数据交换格式,就可以让 007 设计代码实现了。

007 为我们提供了两个代码文件(sqep_bar_day_producer.pysqep_bar_day_consumer.py),简单修改后可以正常运行。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import redis
import tushare as ts
import json
from datetime import datetime
from typing import List, Dict, Tuple, Any, Union

# Tushare和Redis配置
TUSHARE_TOKEN = "YOUR TOKEN"
REDIS_HOST = "Your Redis Host"
REDIS_PORT = 6379
REDIS_PASSWORD = "Redis Password"  # 添加Redis密码
REDIS_QUEUE_NAME = "sqep_bar_day_queue"

# 初始化连接
pro = ts.pro_api(TUSHARE_TOKEN)
redis_client = redis.StrictRedis(
    host=REDIS_HOST, 
    port=REDIS_PORT, 
    password=REDIS_PASSWORD,  # 使用密码进行身份验证
    decode_responses=True
)

def encode_symbol(symbol: str) -> int:
    """将字符串格式的股票代码转换为整型编码

    Args:
        symbol: 股票代码,如 '000001.SZ' 或 '600519.SH'

    Returns:
        整型编码的股票代码,如 2000001 或 1600519
    """
    code, exchange = symbol.split('.')
    code = code.lstrip('0')  # 移除前导零,但保留至少一位数字
    if not code:
        code = '0'

    if exchange.upper() == 'SH':
        prefix = '1'
    elif exchange.upper() == 'SZ':
        prefix = '2'
    else:
        raise ValueError(f"不支持的交易所: {exchange}")

    return int(prefix + code)

def fetch_daily_data(ts_code: str, start_date: str, end_date: str) -> List[Dict[str, Any]]:
    """获取日线数据并转换为SQEP-BAR-DAY格式

    Args:
        ts_code: 股票代码
        start_date: 开始日期,格式为YYYYMMDD
        end_date: 结束日期,格式为YYYYMMDD

    Returns:
        SQEP-BAR-DAY格式的数据列表
    """
    try:
        # 获取OHLC数据
        df_daily = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)

        # 获取复权因子
        df_adj = pro.adj_factor(ts_code=ts_code, start_date=start_date, end_date=end_date)
        adj_dict = {row['trade_date']: row['adj_factor'] for _, row in df_adj.iterrows()}

        # 获取涨跌停价格(如果有高级API权限)
        try:
            df_limit = pro.limit_list(ts_code=ts_code, start_date=start_date, end_date=end_date)
            limit_dict = {row['trade_date']: (row['up_limit'], row['down_limit']) 
                         for _, row in df_limit.iterrows()}
        except:
            limit_dict = {}

        # 获取ST状态(如果有高级API权限)
        try:
            df_namechange = pro.namechange(ts_code=ts_code, start_date=start_date, end_date=end_date)
            st_dict = {row['start_date']: '*' in row['name'] or 'ST' in row['name'] 
                      for _, row in df_namechange.iterrows()}
        except:
            st_dict = {}

        # 转换为SQEP-BAR-DAY格式
        sqep_data = []
        for _, row in df_daily.iterrows():
            trade_date = row['trade_date']

            # 转换日期格式
            frame = datetime.strptime(trade_date, '%Y%m%d').date().isoformat()

            # 转换股票代码
            symbol = encode_symbol(ts_code)

            # 创建基本SQEP记录
            sqep_record = {
                'symbol': symbol,
                'frame': frame,
                'open': float(row['open']),
                'high': float(row['high']),
                'low': float(row['low']),
                'close': float(row['close']),
                'vol': float(row['vol']),
                'amount': float(row.get('amount', 0)),
                'adjust': float(adj_dict.get(trade_date, 1.0))
            }

            # 添加可选字段(如果存在)
            if trade_date in limit_dict:
                sqep_record['buy_limit'] = float(limit_dict[trade_date][0])
                sqep_record['sell_limit'] = float(limit_dict[trade_date][1])

            if trade_date in st_dict:
                sqep_record['st'] = st_dict[trade_date]

            sqep_data.append(sqep_record)

        return sqep_data

    except Exception as e:
        print(f"获取日线数据失败: {str(e)}")
        return []

def produce_sqep_data(ts_code_list: List[str], date_range: Tuple[str, str]):
    """生产SQEP-BAR-DAY数据并推送到Redis

    Args:
        ts_code_list: 股票代码列表
        date_range: 日期范围元组 (start_date, end_date)
    """
    start_date, end_date = date_range

    for ts_code in ts_code_list:
        # 获取并转换数据
        sqep_data = fetch_daily_data(ts_code, start_date, end_date)

        if not sqep_data:
            print(f"未获取到 {ts_code} 的数据")
            continue

        # 创建数据包
        data_package = {
            "timestamp": datetime.now().isoformat(),
            "source": "tushare",
            "data_type": "SQEP-BAR-DAY",
            "records": sqep_data
        }

        # 推送到Redis
        redis_client.lpush(REDIS_QUEUE_NAME, json.dumps(data_package))
        print(f"已推送SQEP-BAR-DAY数据: {ts_code} - {start_date}{end_date} ({len(sqep_data)}条)")

if __name__ == "__main__":
    # 示例参数
    STOCK_CODES = ["000001.SZ", "600519.SH"]
    DATE_RANGE = ("20230101", "20231231")

    produce_sqep_data(STOCK_CODES, DATE_RANGE)

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import redis
import json
from clickhouse_driver import Client
from datetime import datetime
from typing import Dict, List, Any

# 配置参数
REDIS_HOST = "8.217.201.221"
REDIS_PORT = 16379
REDIS_PASSWORD = "quantide666"  # 添加Redis密码
REDIS_QUEUE_NAME = "sqep_bar_day_queue"

CLICKHOUSE_HOST = "localhost"
CLICKHOUSE_PORT = 9000
CLICKHOUSE_DB = "default"

# 初始化 Redis 和 ClickHouse 客户端
redis_client = redis.StrictRedis(
    host=REDIS_HOST, 
    port=REDIS_PORT, 
    password=REDIS_PASSWORD,  # 使用密码进行身份验证
    decode_responses=True
)
clickhouse_client = Client(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT, database=CLICKHOUSE_DB)

def create_sqep_table_if_not_exists():
    """创建SQEP-BAR-DAY表(如果不存在)"""
    query = """
    CREATE TABLE IF NOT EXISTS sqep_bar_day (
        symbol Int32,
        frame Date,
        open Float64,
        high Float64,
        low Float64,
        close Float64,
        vol Float64,
        amount Float64,
        adjust Float64,
        st UInt8 DEFAULT 0,
        buy_limit Float64 DEFAULT 0,
        sell_limit Float64 DEFAULT 0
    ) ENGINE = MergeTree()
    PARTITION BY toYYYYMM(frame)
    ORDER BY (symbol, frame);
    """
    clickhouse_client.execute(query)
    print("已确保SQEP-BAR-DAY表存在")

def decode_symbol(encoded_symbol: int) -> str:
    """将整型编码的股票代码转换回字符串格式

    Args:
        encoded_symbol: 整型编码的股票代码,如 2000001

    Returns:
        字符串格式的股票代码,如 '000001.SZ'
    """
    encoded_str = str(encoded_symbol)
    prefix = encoded_str[0]
    code = encoded_str[1:]

    # 补齐6位数字
    code = code.zfill(6)

    if prefix == '1':
        exchange = 'SH'
    elif prefix == '2':
        exchange = 'SZ'
    else:
        raise ValueError(f"不支持的交易所前缀: {prefix}")

    return f"{code}.{exchange}"

def insert_to_clickhouse(data_package: Dict[str, Any]):
    """将SQEP-BAR-DAY数据插入到ClickHouse

    Args:
        data_package: 包含SQEP-BAR-DAY记录的数据包
    """
    records = data_package["records"]
    if not records:
        return 0

    # 准备插入数据
    values = []
    for record in records:
        # 准备基本字段
        row = (
            record["symbol"],
            datetime.fromisoformat(record["frame"]).date(),
            record["open"],
            record["high"],
            record["low"],
            record["close"],
            record["vol"],
            record["amount"],
            record["adjust"],
            int(record.get("st", False)),
            record.get("buy_limit", 0.0),
            record.get("sell_limit", 0.0)
        )
        values.append(row)

    # 执行插入
    query = """
    INSERT INTO sqep_bar_day (
        symbol, frame, open, high, low, close, vol, amount, adjust, st, buy_limit, sell_limit
    ) VALUES
    """

    clickhouse_client.execute(query, values)
    return len(values)

def consume_sqep_data():
    """消费SQEP-BAR-DAY数据"""
    # 确保表存在
    create_sqep_table_if_not_exists()

    print("启动SQEP-BAR-DAY数据消费者,等待队列数据...")
    while True:
        try:
            # 阻塞式获取队列数据
            result = redis_client.brpop(REDIS_QUEUE_NAME, timeout=1)
            if result is None:
                # 如果没有获取到数据,说明队列为空,退出循环
                print("Redis队列为空,停止消费数据。")
                break

            _, json_data = result
            data_package = json.loads(json_data)

            # 检查数据类型
            if data_package.get("data_type") != "SQEP-BAR-DAY":
                print(f"跳过非SQEP-BAR-DAY数据: {data_package.get('data_type')}")
                continue

            # 插入数据
            inserted_count = insert_to_clickhouse(data_package)

            # 获取第一条记录的股票代码用于显示
            if data_package["records"]:
                first_symbol = data_package["records"][0]["symbol"]
                symbol_str = decode_symbol(first_symbol)
                print(f"成功插入SQEP-BAR-DAY数据: {symbol_str} ({inserted_count}条)")
            else:
                print("数据包中没有记录")

        except Exception as e:
            print(f"数据处理异常: {str(e)}")
            continue

def query_sqep_data(symbol: str, start_date: str, end_date: str):
    """查询SQEP-BAR-DAY数据

    Args:
        symbol: 股票代码,如 '000001.SZ'
        start_date: 开始日期,格式为YYYY-MM-DD
        end_date: 结束日期,格式为YYYY-MM-DD

    Returns:
        查询结果列表
    """
    # 编码股票代码
    code, exchange = symbol.split('.')
    code = code.lstrip('0')
    if not code:
        code = '0'

    if exchange.upper() == 'SH':
        prefix = '1'
    elif exchange.upper() == 'SZ':
        prefix = '2'
    else:
        raise ValueError(f"不支持的交易所: {exchange}")

    encoded_symbol = int(prefix + code)

    # 执行查询
    query = f"""
    SELECT 
        symbol, frame, open, high, low, close, vol, amount, adjust, 
        st, buy_limit, sell_limit
    FROM sqep_bar_day
    WHERE symbol = {encoded_symbol} AND frame BETWEEN '{start_date}' AND '{end_date}'
    ORDER BY frame
    """

    result = clickhouse_client.execute(query)

    # 转换结果
    columns = [
        'symbol', 'frame', 'open', 'high', 'low', 'close', 'vol', 
        'amount', 'adjust', 'st', 'buy_limit', 'sell_limit'
    ]

    return [dict(zip(columns, row)) for row in result]

if __name__ == "__main__":
    consume_sqep_data()

1.5. 股票代码编码方式对查询的性能测试

接下来,我们将设计一个实验来测试股票代码编码方式对查询性能的影响。这个实验将比较字符串格式和整型编码格式在不同数据量下的查询性能差异。007 很 nice 地帮助我设计了一个实验方案:

1.5.1. 测试方案
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
import time
import random
import pandas as pd
import numpy as np
import sqlite3
import matplotlib.pyplot as plt
from typing import List, Tuple
import os

from matplotlib import font_manager
font_path = 'SimHei.ttf'  # 替换为SimHei.ttf的实际路径
font_manager.fontManager.addfont(font_path)
plt.rcParams['font.family'] = 'SimHei'

class SymbolEncodingBenchmark:
    """股票代码编码方式性能测试"""

    def __init__(self, db_path="symbol_benchmark.db"):
        """初始化基准测试

        Args:
            db_path: SQLite数据库文件路径
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

        # 创建测试表
        self._create_tables()

    def _create_tables(self):
        """创建测试表"""
        # 字符串格式表
        self.cursor.execute("""
        CREATE TABLE IF NOT EXISTS bar_day_str (
            symbol TEXT,
            frame TEXT,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            vol REAL,
            amount REAL,
            adjust REAL,
            PRIMARY KEY (symbol, frame)
        )
        """)

        # 整型编码表
        self.cursor.execute("""
        CREATE TABLE IF NOT EXISTS bar_day_int (
            symbol INTEGER,
            frame TEXT,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            vol REAL,
            amount REAL,
            adjust REAL,
            PRIMARY KEY (symbol, frame)
        )
        """)

        # 创建索引
        self.cursor.execute("CREATE INDEX IF NOT EXISTS idx_str_symbol ON bar_day_str (symbol)")
        self.cursor.execute("CREATE INDEX IF NOT EXISTS idx_int_symbol ON bar_day_int (symbol)")

        self.conn.commit()

    @staticmethod
    def encode_symbol(symbol: str) -> int:
        """将字符串格式的股票代码转换为整型编码"""
        code, exchange = symbol.split('.')
        code = code.lstrip('0')  # 移除前导零
        if not code:
            code = '0'

        if exchange.upper() == 'SH':
            prefix = '1'
        elif exchange.upper() == 'SZ':
            prefix = '2'
        else:
            raise ValueError(f"不支持的交易所: {exchange}")

        return int(prefix + code)

    def generate_test_data(self, num_symbols: int, days_per_symbol: int) -> pd.DataFrame:
        """生成测试数据

        Args:
            num_symbols: 股票数量
            days_per_symbol: 每只股票的交易日数量

        Returns:
            包含测试数据的DataFrame
        """
        # 生成股票代码
        sh_symbols = [f"{str(i).zfill(6)}.SH" for i in range(num_symbols // 2)]
        sz_symbols = [f"{str(i).zfill(6)}.SZ" for i in range(num_symbols // 2)]
        symbols = sh_symbols + sz_symbols

        # 生成日期范围
        start_date = pd.Timestamp('2020-01-01')
        dates = [start_date + pd.Timedelta(days=i) for i in range(days_per_symbol)]

        # 生成数据
        data = []
        for symbol in symbols:
            for date in dates:
                open_price = random.uniform(10, 100)
                high = open_price * random.uniform(1, 1.1)
                low = open_price * random.uniform(0.9, 1)
                close = random.uniform(low, high)

                data.append({
                    'symbol': symbol,
                    'frame': date.strftime('%Y-%m-%d'),
                    'open': open_price,
                    'high': high,
                    'low': low,
                    'close': close,
                    'vol': random.uniform(10000, 1000000),
                    'amount': random.uniform(1000000, 100000000),
                    'adjust': random.uniform(0.8, 1.2)
                })

        return pd.DataFrame(data)

    def load_test_data(self, df: pd.DataFrame):
        """加载测试数据到数据库

        Args:
            df: 包含测试数据的DataFrame
        """
        # 清空表
        self.cursor.execute("DELETE FROM bar_day_str")
        self.cursor.execute("DELETE FROM bar_day_int")

        # 插入字符串格式数据
        for _, row in df.iterrows():
            self.cursor.execute(
                "INSERT INTO bar_day_str VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    row['symbol'],
                    row['frame'],
                    row['open'],
                    row['high'],
                    row['low'],
                    row['close'],
                    row['vol'],
                    row['amount'],
                    row['adjust']
                )
            )

        # 插入整型编码数据
        for _, row in df.iterrows():
            encoded_symbol = self.encode_symbol(row['symbol'])
            self.cursor.execute(
                "INSERT INTO bar_day_int VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    encoded_symbol,
                    row['frame'],
                    row['open'],
                    row['high'],
                    row['low'],
                    row['close'],
                    row['vol'],
                    row['amount'],
                    row['adjust']
                )
            )

        self.conn.commit()

    def run_query_benchmark(self, num_queries: int) -> Tuple[List[float], List[float]]:
        """运行查询基准测试

        Args:
            num_queries: 查询次数

        Returns:
            字符串格式和整型编码的查询时间列表
        """
        # 获取所有股票代码
        self.cursor.execute("SELECT DISTINCT symbol FROM bar_day_str")
        str_symbols = [row[0] for row in self.cursor.fetchall()]

        self.cursor.execute("SELECT DISTINCT symbol FROM bar_day_int")
        int_symbols = [row[0] for row in self.cursor.fetchall()]

        # 运行查询测试
        str_times = []
        int_times = []

        for _ in range(num_queries):
            # 随机选择一个股票代码
            str_symbol = random.choice(str_symbols)
            int_symbol = self.encode_symbol(str_symbol)

            # 测试字符串格式查询
            start_time = time.time()
            self.cursor.execute(
                "SELECT * FROM bar_day_str WHERE symbol = ?",
                (str_symbol,)
            )
            results = self.cursor.fetchall()
            str_times.append(time.time() - start_time)

            # 测试整型编码查询
            start_time = time.time()
            self.cursor.execute(
                "SELECT * FROM bar_day_int WHERE symbol = ?",
                (int_symbol,)
            )
            results = self.cursor.fetchall()
            int_times.append(time.time() - start_time)

        return str_times, int_times

    def run_range_query_benchmark(self, num_queries: int) -> Tuple[List[float], List[float]]:
        """运行范围查询基准测试

        Args:
            num_queries: 查询次数

        Returns:
            字符串格式和整型编码的查询时间列表
        """
        # 获取所有交易所
        exchanges = ['SH', 'SZ']

        # 运行查询测试
        str_times = []
        int_times = []

        for _ in range(num_queries):
            # 随机选择一个交易所
            exchange = random.choice(exchanges)

            # 测试字符串格式查询
            start_time = time.time()
            self.cursor.execute(
                "SELECT * FROM bar_day_str WHERE symbol LIKE ?",
                (f"%.{exchange}",)
            )
            results = self.cursor.fetchall()
            str_times.append(time.time() - start_time)

            # 测试整型编码查询
            prefix = 1 if exchange == 'SH' else 2
            start_time = time.time()
            self.cursor.execute(
                "SELECT * FROM bar_day_int WHERE symbol >= ? AND symbol < ?",
                (prefix * 1000000, (prefix + 1) * 1000000)
            )
            results = self.cursor.fetchall()
            int_times.append(time.time() - start_time)

        return str_times, int_times

    def run_full_benchmark(self, data_sizes: List[int], days_per_symbol: int = 252, num_queries: int = 100):
        """运行完整基准测试

        Args:
            data_sizes: 测试的股票数量列表
            days_per_symbol: 每只股票的交易日数量
            num_queries: 每次测试的查询次数
        """
        results = {
            'data_size': [],
            'str_query_avg': [],
            'int_query_avg': [],
            'str_range_avg': [],
            'int_range_avg': []
        }

        for size in data_sizes:
            print(f"测试数据量: {size}只股票 × {days_per_symbol}天 = {size * days_per_symbol}条记录")

            # 生成并加载测试数据
            df = self.generate_test_data(size, days_per_symbol)
            self.load_test_data(df)

            # 运行查询测试
            str_times, int_times = self.run_query_benchmark(num_queries)
            str_range_times, int_range_times = self.run_range_query_benchmark(num_queries)

            # 记录结果
            results['data_size'].append(size * days_per_symbol)
            results['str_query_avg'].append(np.mean(str_times) * 1000)  # 转换为毫秒
            results['int_query_avg'].append(np.mean(int_times) * 1000)
            results['str_range_avg'].append(np.mean(str_range_times) * 1000)
            results['int_range_avg'].append(np.mean(int_range_times) * 1000)

            print(f"  单条查询 - 字符串格式: {results['str_query_avg'][-1]:.2f}ms, 整型编码: {results['int_query_avg'][-1]:.2f}ms")
            print(f"  范围查询 - 字符串格式: {results['str_range_avg'][-1]:.2f}ms, 整型编码: {results['int_range_avg'][-1]:.2f}ms")
            print(f"  性能提升 - 单条查询: {(results['str_query_avg'][-1] / results['int_query_avg'][-1]):.2f}x, 范围查询: {(results['str_range_avg'][-1] / results['int_range_avg'][-1]):.2f}x")
            print()

        # 绘制结果图表
        self._plot_results(results)

        return results

    def _plot_results(self, results: dict):
        """绘制测试结果图表

        Args:
            results: 测试结果字典
        """
        plt.figure(figsize=(15, 10))

        # 单条查询性能对比
        plt.subplot(2, 2, 1)
        plt.plot(results['data_size'], results['str_query_avg'], 'o-', label='字符串格式')
        plt.plot(results['data_size'], results['int_query_avg'], 'o-', label='整型编码')
        plt.title('单条查询性能对比')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('平均查询时间 (毫秒)')
        plt.legend()
        plt.grid(True)

        # 范围查询性能对比
        plt.subplot(2, 2, 2)
        plt.plot(results['data_size'], results['str_range_avg'], 'o-', label='字符串格式')
        plt.plot(results['data_size'], results['int_range_avg'], 'o-', label='整型编码')
        plt.title('范围查询性能对比')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('平均查询时间 (毫秒)')
        plt.legend()
        plt.grid(True)

        # 性能提升比例
        plt.subplot(2, 2, 3)
        speedup_query = [s / i for s, i in zip(results['str_query_avg'], results['int_query_avg'])]
        speedup_range = [s / i for s, i in zip(results['str_range_avg'], results['int_range_avg'])]
        plt.plot(results['data_size'], speedup_query, 'o-', label='单条查询')
        plt.plot(results['data_size'], speedup_range, 'o-', label='范围查询')
        plt.title('整型编码性能提升比例')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('性能提升 (倍)')
        plt.legend()
        plt.grid(True)

        # 查询时间与数据量的关系
        plt.subplot(2, 2, 4)
        plt.loglog(results['data_size'], results['str_query_avg'], 'o-', label='字符串-单条')
        plt.loglog(results['data_size'], results['int_query_avg'], 'o-', label='整型-单条')
        plt.loglog(results['data_size'], results['str_range_avg'], 'o-', label='字符串-范围')
        plt.loglog(results['data_size'], results['int_range_avg'], 'o-', label='整型-范围')
        plt.title('查询时间与数据量关系 (对数坐标)')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('平均查询时间 (毫秒)')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('symbol_encoding_benchmark.png')
        plt.close()

    def cleanup(self):
        """清理测试资源"""
        self.conn.close()
        if os.path.exists(self.db_path):
            os.remove(self.db_path)


if __name__ == "__main__":
    # 运行基准测试
    benchmark = SymbolEncodingBenchmark()

    # 测试不同数据量
    data_sizes = [100, 500, 1000, 2000, 5000]
    results = benchmark.run_full_benchmark(data_sizes)

    # 输出总结
    print("测试总结:")
    print(f"数据量范围: {min(results['data_size'])} - {max(results['data_size'])}条记录")
    print(f"单条查询平均性能提升: {np.mean([s / i for s, i in zip(results['str_query_avg'], results['int_query_avg'])]):.2f}倍")
    print(f"范围查询平均性能提升: {np.mean([s / i for s, i in zip(results['str_range_avg'], results['int_range_avg'])]):.2f}倍")

    # 清理资源
    benchmark.cleanup()

    print("\n测试完成,结果已保存到 symbol_encoding_benchmark.png")
  1. 测试环境
    • 使用SQLite数据库作为测试平台(易于部署且无需额外配置)
    • 创建两个表:bar_day_str(字符串格式)和 bar_day_int(整型编码)
    • 为两个表的 symbol 字段创建索引,确保公平比较
  2. 测试数据
    • 生成不同数量的股票代码(上交所和深交所各半)
    • 为每只股票生成多天的交易数据
    • 数据量从小到大逐步增加(100到5000只股票)
  3. 测试场景
    • 单条查询:根据特定股票代码查询所有交易记录
    • 范围查询:查询特定交易所的所有股票记录
  4. 性能指标
    • 查询响应时间(毫秒)
    • 性能提升比例(字符串格式时间/整型编码时间)
  5. 结果分析
    • 绘制四个图表展示测试结果:
    • 单条查询性能对比
    • 范围查询性能对比
    • 整型编码性能提升比例
    • 查询时间与数据量关系(对数坐标)
1.5.2. 测试结果

运行 007 的测试方案,可以得到如下的股票代码编码方式性能测试结果。

1.5.2.1. 测试数据概览
数据量 股票数量 每只股票天数 总记录数
100 252 25,200
中小 500 252 126,000
1,000 252 252,000
中大 2,000 252 504,000
5,000 252 1,260,000
1.5.2.2. 查询性能对比

数据量(记录数) 单条查询时间(ms) 性能提升 范围查询时间(ms) 性能提升
字符串格式 整型编码 倍数 字符串格式 整型编码 倍数
25,200 0.14 0.13 1.04x 7.68 0.42 18.34x
126,000 0.14 0.13 1.05x 37.34 2.30 16.24x
252,000 0.14 0.14 1.03x 76.20 4.95 15.39x
504,000 0.16 0.19 0.83x 148.13 10.05 14.73x
1,260,000 0.19 0.28 0.70x 377.34 24.71 15.27x
平均 - - 0.93x - - 15.99x

1.5.2.3. 结果分析
  1. 单条查询性能
  2. 在小到中等数据量(25,200-252,000条)下,整型编码略优于字符串格式(1.03x-1.05x)
  3. 在较大数据量(504,000-1,260,000条)下,字符串格式反而略优于整型编码(0.70x-0.83x)
  4. 整体来看,单条查询的性能差异不显著,平均提升为0.93倍

  5. 范围查询性能

  6. 在所有数据量级下,整型编码都显著优于字符串格式
  7. 性能提升倍数在 14.73x-18.34x 之间
  8. 平均性能提升达到 15.99

  9. 随数据量增长的趋势

  10. 单条查询:随着数据量增加,整型编码的相对优势逐渐减弱
  11. 范围查询:整型编码的巨大优势在各数据量级下保持稳定
1.5.2.4. 结论
  1. 对于单条查询:两种编码方式性能相近,在实际应用中差异不明显
  2. 对于范围查询:整型编码提供了显著的性能优势,平均快约16倍
  3. 推荐使用场景
  4. 如果系统中范围查询较为频繁(如按交易所筛选股票),强烈推荐使用整型编码
  5. 如果系统主要进行单条查询,编码方式的选择影响不大,可以根据其他因素决定

  6. 其他考虑因素

  7. 整型编码节省存储空间
  8. 整型编码便于进行数值运算和比较
  9. 字符串格式更直观,调试时更容易理解

2. SQEP-BAR-MINITE 分钟线场景下的数据交换格式

同上,但没有复权因子。这样,无论将来我们从哪个数据源获得的数据,消费者一端的代码都不需要更改。这里,我和 007 将设计一个性能测试方案,比较JSON(带key)和CSV(不带key)两种数据交换格式的性能差异。

2.1. 测试方案

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
import time
import json
import csv
import io
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Any, Tuple
import os

from matplotlib import font_manager
font_path = 'SimHei.ttf'  # 替换为SimHei.ttf的实际路径
font_manager.fontManager.addfont(font_path)
plt.rcParams['font.family'] = 'SimHei'

class DataFormatBenchmark:
    """SQEP数据格式性能测试:JSON vs CSV"""

    def __init__(self):
        """初始化基准测试"""
        # 定义SQEP-BAR-DAY字段顺序(CSV格式需要)
        self.field_order = [
            'symbol', 'frame', 'open', 'high', 'low', 
            'close', 'vol', 'amount', 'adjust'
        ]

    def generate_test_data(self, num_records: int) -> List[Dict[str, Any]]:
        """生成测试数据

        Args:
            num_records: 记录数量

        Returns:
            包含测试数据的记录列表
        """
        data = []

        # 生成股票代码 - 确保至少有1只股票
        num_symbols = max(1, min(num_records // 252, 5000))
        symbols = []
        for i in range(num_symbols):
            exchange = 'SH' if i % 2 == 0 else 'SZ'
            symbols.append(f"{str(i).zfill(6)}.{exchange}")

        # 生成日期范围 - 确保至少有1天
        days_needed = max(1, num_records // len(symbols))
        start_date = pd.Timestamp('2020-01-01')
        dates = [start_date + pd.Timedelta(days=i) for i in range(min(days_needed, 365))]

        # 生成数据
        for symbol in symbols:
            for date in dates:
                if len(data) >= num_records:
                    break

                open_price = random.uniform(10, 100)
                high = open_price * random.uniform(1, 1.1)
                low = open_price * random.uniform(0.9, 1)
                close = random.uniform(low, high)

                data.append({
                    'symbol': symbol,
                    'frame': date.strftime('%Y-%m-%d'),
                    'open': round(open_price, 2),
                    'high': round(high, 2),
                    'low': round(low, 2),
                    'close': round(close, 2),
                    'vol': round(random.uniform(10000, 1000000), 0),
                    'amount': round(random.uniform(1000000, 100000000), 0),
                    'adjust': round(random.uniform(0.8, 1.2), 4)
                })

        return data[:num_records]

    def encode_json(self, data: List[Dict[str, Any]]) -> str:
        """将数据编码为JSON格式

        Args:
            data: 记录列表

        Returns:
            JSON字符串
        """
        return json.dumps({
            "timestamp": pd.Timestamp.now().isoformat(),
            "source": "benchmark",
            "data_type": "SQEP-BAR-DAY",
            "records": data
        })

    def decode_json(self, json_str: str) -> List[Dict[str, Any]]:
        """将JSON字符串解码为数据

        Args:
            json_str: JSON字符串

        Returns:
            记录列表
        """
        data = json.loads(json_str)
        return data["records"]

    def encode_csv(self, data: List[Dict[str, Any]]) -> str:
        """将数据编码为CSV格式

        Args:
            data: 记录列表

        Returns:
            CSV字符串
        """
        output = io.StringIO()
        writer = csv.writer(output)

        # 写入元数据行
        writer.writerow([
            pd.Timestamp.now().isoformat(),
            "benchmark",
            "SQEP-BAR-DAY",
            len(data)
        ])

        # 写入数据行
        for record in data:
            row = [record[field] for field in self.field_order]
            writer.writerow(row)

        return output.getvalue()

    def decode_csv(self, csv_str: str) -> List[Dict[str, Any]]:
        """将CSV字符串解码为数据

        Args:
            csv_str: CSV字符串

        Returns:
            记录列表
        """
        input_file = io.StringIO(csv_str)
        reader = csv.reader(input_file)

        # 读取元数据行
        metadata = next(reader)
        timestamp, source, data_type, num_records = metadata

        # 读取数据行
        records = []
        for row in reader:
            record = {field: value for field, value in zip(self.field_order, row)}

            # 转换数据类型
            record['open'] = float(record['open'])
            record['high'] = float(record['high'])
            record['low'] = float(record['low'])
            record['close'] = float(record['close'])
            record['vol'] = float(record['vol'])
            record['amount'] = float(record['amount'])
            record['adjust'] = float(record['adjust'])

            records.append(record)

        return records

    def run_encoding_benchmark(self, data: List[Dict[str, Any]], num_iterations: int = 100) -> Tuple[float, float]:
        """运行编码基准测试

        Args:
            data: 测试数据
            num_iterations: 迭代次数

        Returns:
            JSON和CSV的平均编码时间(毫秒)
        """
        json_times = []
        csv_times = []

        for _ in range(num_iterations):
            # 测试JSON编码
            start_time = time.time()
            json_str = self.encode_json(data)
            json_times.append(time.time() - start_time)

            # 测试CSV编码
            start_time = time.time()
            csv_str = self.encode_csv(data)
            csv_times.append(time.time() - start_time)

        # 计算平均时间(毫秒)
        json_avg = np.mean(json_times) * 1000
        csv_avg = np.mean(csv_times) * 1000

        return json_avg, csv_avg

    def run_decoding_benchmark(self, data: List[Dict[str, Any]], num_iterations: int = 100) -> Tuple[float, float]:
        """运行解码基准测试

        Args:
            data: 测试数据
            num_iterations: 迭代次数

        Returns:
            JSON和CSV的平均解码时间(毫秒)
        """
        # 先编码数据
        json_str = self.encode_json(data)
        csv_str = self.encode_csv(data)

        json_times = []
        csv_times = []

        for _ in range(num_iterations):
            # 测试JSON解码
            start_time = time.time()
            self.decode_json(json_str)
            json_times.append(time.time() - start_time)

            # 测试CSV解码
            start_time = time.time()
            self.decode_csv(csv_str)
            csv_times.append(time.time() - start_time)

        # 计算平均时间(毫秒)
        json_avg = np.mean(json_times) * 1000
        csv_avg = np.mean(csv_times) * 1000

        return json_avg, csv_avg

    def measure_size(self, data: List[Dict[str, Any]]) -> Tuple[int, int]:
        """测量编码后的数据大小

        Args:
            data: 测试数据

        Returns:
            JSON和CSV的字节大小
        """
        json_str = self.encode_json(data)
        csv_str = self.encode_csv(data)

        return len(json_str.encode('utf-8')), len(csv_str.encode('utf-8'))

    def run_full_benchmark(self, data_sizes: List[int], num_iterations: int = 100):
        """运行完整基准测试

        Args:
            data_sizes: 测试的记录数量列表
            num_iterations: 每次测试的迭代次数
        """
        results = {
            'data_size': [],
            'json_encode_time': [],
            'csv_encode_time': [],
            'json_decode_time': [],
            'csv_decode_time': [],
            'json_size': [],
            'csv_size': []
        }

        for size in data_sizes:
            print(f"测试数据量: {size}条记录")

            # 生成测试数据
            data = self.generate_test_data(size)

            # 运行编码测试
            json_encode_time, csv_encode_time = self.run_encoding_benchmark(data, num_iterations)

            # 运行解码测试
            json_decode_time, csv_decode_time = self.run_decoding_benchmark(data, num_iterations)

            # 测量数据大小
            json_size, csv_size = self.measure_size(data)

            # 记录结果
            results['data_size'].append(size)
            results['json_encode_time'].append(json_encode_time)
            results['csv_encode_time'].append(csv_encode_time)
            results['json_decode_time'].append(json_decode_time)
            results['csv_decode_time'].append(csv_decode_time)
            results['json_size'].append(json_size)
            results['csv_size'].append(csv_size)

            print(f"  编码时间 - JSON: {json_encode_time:.2f}ms, CSV: {csv_encode_time:.2f}ms")
            print(f"  解码时间 - JSON: {json_decode_time:.2f}ms, CSV: {csv_decode_time:.2f}ms")
            print(f"  数据大小 - JSON: {json_size/1024:.2f}KB, CSV: {csv_size/1024:.2f}KB")
            print(f"  性能比较 - 编码: JSON/CSV = {json_encode_time/csv_encode_time:.2f}x, 解码: JSON/CSV = {json_decode_time/csv_decode_time:.2f}x")
            print(f"  大小比较 - JSON/CSV = {json_size/csv_size:.2f}x")
            print()

        # 绘制结果图表
        self._plot_results(results)

        return results

    def _plot_results(self, results: dict):
        """绘制测试结果图表

        Args:
            results: 测试结果字典
        """
        plt.figure(figsize=(15, 12))

        # 编码时间对比
        plt.subplot(3, 2, 1)
        plt.plot(results['data_size'], results['json_encode_time'], 'o-', label='JSON')
        plt.plot(results['data_size'], results['csv_encode_time'], 'o-', label='CSV')
        plt.title('编码时间对比')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('平均编码时间 (毫秒)')
        plt.legend()
        plt.grid(True)

        # 解码时间对比
        plt.subplot(3, 2, 2)
        plt.plot(results['data_size'], results['json_decode_time'], 'o-', label='JSON')
        plt.plot(results['data_size'], results['csv_decode_time'], 'o-', label='CSV')
        plt.title('解码时间对比')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('平均解码时间 (毫秒)')
        plt.legend()
        plt.grid(True)

        # 数据大小对比
        plt.subplot(3, 2, 3)
        plt.plot(results['data_size'], [s/1024 for s in results['json_size']], 'o-', label='JSON')
        plt.plot(results['data_size'], [s/1024 for s in results['csv_size']], 'o-', label='CSV')
        plt.title('数据大小对比')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('数据大小 (KB)')
        plt.legend()
        plt.grid(True)

        # 性能比率
        plt.subplot(3, 2, 4)
        encode_ratio = [j/c for j, c in zip(results['json_encode_time'], results['csv_encode_time'])]
        decode_ratio = [j/c for j, c in zip(results['json_decode_time'], results['csv_decode_time'])]
        size_ratio = [j/c for j, c in zip(results['json_size'], results['csv_size'])]

        plt.plot(results['data_size'], encode_ratio, 'o-', label='编码时间比 (JSON/CSV)')
        plt.plot(results['data_size'], decode_ratio, 'o-', label='解码时间比 (JSON/CSV)')
        plt.plot(results['data_size'], size_ratio, 'o-', label='大小比 (JSON/CSV)')
        plt.axhline(y=1, color='r', linestyle='--')
        plt.title('JSON/CSV 性能比率')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('比率 (JSON/CSV)')
        plt.legend()
        plt.grid(True)

        # 编码+解码总时间
        plt.subplot(3, 2, 5)
        json_total = [e + d for e, d in zip(results['json_encode_time'], results['json_decode_time'])]
        csv_total = [e + d for e, d in zip(results['csv_encode_time'], results['csv_decode_time'])]
        plt.plot(results['data_size'], json_total, 'o-', label='JSON')
        plt.plot(results['data_size'], csv_total, 'o-', label='CSV')
        plt.title('总处理时间 (编码+解码)')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('总时间 (毫秒)')
        plt.legend()
        plt.grid(True)

        # 对数坐标下的性能
        plt.subplot(3, 2, 6)
        plt.loglog(results['data_size'], results['json_encode_time'], 'o-', label='JSON编码')
        plt.loglog(results['data_size'], results['csv_encode_time'], 'o-', label='CSV编码')
        plt.loglog(results['data_size'], results['json_decode_time'], 'o-', label='JSON解码')
        plt.loglog(results['data_size'], results['csv_decode_time'], 'o-', label='CSV解码')
        plt.title('性能随数据量变化 (对数坐标)')
        plt.xlabel('数据量 (记录数)')
        plt.ylabel('时间 (毫秒)')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('data_format_benchmark.png')
        plt.close()


if __name__ == "__main__":
    # 运行基准测试
    benchmark = DataFormatBenchmark()

    # 测试不同数据量
    data_sizes = [100, 500, 1000, 5000, 10000, 50000]
    results = benchmark.run_full_benchmark(data_sizes)

    # 输出总结
    print("测试总结:")
    print(f"数据量范围: {min(results['data_size'])} - {max(results['data_size'])}条记录")

    # 计算平均比率
    avg_encode_ratio = np.mean([j/c for j, c in zip(results['json_encode_time'], results['csv_encode_time'])])
    avg_decode_ratio = np.mean([j/c for j, c in zip(results['json_decode_time'], results['csv_decode_time'])])
    avg_size_ratio = np.mean([j/c for j, c in zip(results['json_size'], results['csv_size'])])

    print(f"编码时间比率 (JSON/CSV): {avg_encode_ratio:.2f}x")
    print(f"解码时间比率 (JSON/CSV): {avg_decode_ratio:.2f}x")
    print(f"数据大小比率 (JSON/CSV): {avg_size_ratio:.2f}x")

    print("\n测试完成,结果已保存到 data_format_benchmark.png")
这个性能测试方案全面比较了JSON(带key)和CSV(不带key)两种数据交换格式在不同数据量下的性能差异。主要测试内容包括: 1. 测试内容 测试方案包含三个主要方面的比较: - 编码性能:将数据结构转换为字符串的速度 - 解码性能:将字符串解析回数据结构的速度 - 数据大小:编码后的数据占用空间 2. 测试数据 - 生成不同数量的SQEP-BAR-DAY记录(从100到50,000条) - 每条记录包含完整的股票日线数据字段 - 数据内容模拟真实交易数据 3. 数据格式实现 1) JSON格式: - 包含完整的字段名(键) - 使用标准JSON结构,包含元数据和记录数组 - 示例:
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
{
  "timestamp": "2023-05-01T12:00:00",
  "source": "benchmark",
  "data_type": "SQEP-BAR-DAY",
  "records": [
      {
      "symbol": "000001.SZ",
      "frame": "2023-05-01",
      "open": 10.5,
      "high": 11.2,
      "low": 10.3,
      "close": 11.0,
      "vol": 123456,
      "amount": 1234567,
      "adjust": 1.0
      },
      ...
  ]
  }
2) CSV格式: - 不包含字段名,依赖固定的字段顺序 - 第一行包含元数据(时间戳、来源、数据类型、记录数) - 后续每行是一条记录 - 示例:
1
2
3
2023-05-01T12:00:00,benchmark,SQEP-BAR-DAY,1000
000001.SZ,2023-05-01,10.5,11.2,10.3,11.0,123456,1234567,1.0
...
4. 测试指标 - 编码时间:将数据结构转换为字符串的平均时间(毫秒) - 解码时间:将字符串解析回数据结构的平均时间(毫秒) - 数据大小:编码后的字符串大小(字节/KB) - 性能比率:JSON/CSV的比值,表示相对性能差异

2.2. 测试结果

数据量(记录数) JSON编码时间(ms) CSV编码时间(ms) JSON解码时间(ms) CSV解码时间(ms) JSON大小(KB) CSV大小(KB) 编码比率(JSON/CSV) 解码比率(JSON/CSV) 大小比率(JSON/CSV)
100 0.13 0.14 0.07 0.09 16.16 7.12 0.96x 0.71x 2.27x
500 0.47 0.51 0.24 0.34 58.67 25.82 0.93x 0.69x 2.27x
1,000 1.42 1.55 0.75 0.94 160.47 70.67 0.91x 0.81x 2.27x
5,000 6.86 7.06 3.35 4.80 802.26 353.26 0.97x 0.70x 2.27x
10,000 13.23 13.94 6.69 9.67 1,602.63 705.58 0.95x 0.69x 2.27x
50,000 66.37 70.23 34.55 48.12 8,009.18 3,526.28 0.95x 0.72x 2.27x
平均 - - - - - - 0.95x 0.72x 2.27x

2.2.3. 结果分析
  1. 编码性能 JSON编码在所有测试数据量下都略快于CSV编码,平均快约5%(0.95x)
  2. 解码性能 JSON解码明显快于CSV解码,平均快约28%(0.72x)
  3. 数据大小 JSON格式的数据大小始终是CSV格式的2.27倍

总结

在这一章的探索中,我和我的AI助手007一起深入研究了量化交易系统中数据交换格式的性能问题。这次"数据格式大PK"不仅让我们获得了宝贵的技术数据,更展示了人机协作的无限可能!

007不愧是"码力全开"的得力助手,它不仅设计了全面的测试方案,还在遇到除零错误时迅速提供了解决方案。通过我们的共同努力,成功对比了JSON和CSV两种格式在不同数据量下的表现:JSON在处理速度上略胜一筹,而CSV在存储效率上更具优势。

这次测试不仅是技术上的突破,更是我们21天驯化AI打工仔挑战的又一个里程碑!正如测试数据一样,我们的合作也在不断扩展规模,从100条记录到50,000条记录,效率始终保持稳定,这正是我们合作的真实写照!

正如007所说:"数据是一切开始的基础",而我们的合作则是创新的源泉。期待在接下来的SQEP扩展格式探索中,继续与007携手并进,为量化交易系统注入更多智慧的火花!

下一步:SQEP扩展格式探索之旅 接下来,我和007将继续我们的冒险,探索SQEP的两个重要扩展格式: 1. SQEP-ST:专为特殊处理(ST)股票设计的数据格式 - 这些特殊股票信息虽然稀疏,但对投资决策至关重要 - 我们将巧妙地将ST信息整合到现有的SQEP-BAR-DAY表中 - 通过布尔型st字段,让系统能够快速识别特殊股票 2. 涨跌停信息:交易限制的关键指标 - 添加buy_limit和sell_limit字段,为回测系统提供精确的交易约束 - 这些信息将帮助我们模拟真实市场中的交易规则 - 确保回测结果更加贴近实际交易环境

有了这些扩展,我们的量化交易系统将更加完善,能够应对更复杂的市场情况。正如007所展示的那样,只要思路清晰、方法得当,即使是复杂的数据处理问题也能迎刃而解!让我们继续这场激动人心的21天挑战,用数据和智慧创造更多可能!

21天驯化AI打工仔 - 数据库的优化

五一小长假之前,我在搭档 007 的帮助下已经成功实现了从 Tushare 获取 OHLC 数据,并通过 Redis 消息队列将数据存储到 ClickHouse 数据库。为了进一步完善量化交易系统的数据支持,今天我们将聚焦于数据库的优化设计,主要涉及获取日线复权因子、获取分钟线数据以及计算分钟线数据(例如 15 分钟)三个方面。五一结束之后,我们计划通过这些优化,能够为量化交易策略提供更丰富、更精细的数据,从而提升系统的性能和决策能力。


1. 获取日线的复权因子

我:007!帮我在 data_producer.pydata_consumer.py 中添加获取日线因子的代码。

007:🫡收到请求,正在处理...

为了在 data_consumer.pydata_producer.py 中补充因子的数据,我们需要对原代码进行一些修改。主要思路是在生产者代码中获取因子数据并封装到数据结构中,然后在消费者代码中处理这些因子数据。因为后续我们要收集分钟级的数据,所以我新建了一张表,命名为 daily_data,用于存储日线因子数据。

然后,我需要针对这个数据表,对 data_producer.pydata_consumer.py 进行修改。

007 回答我的完整代码如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""data_producer.py"""
import redis
import tushare as ts
import json
from datetime import datetime

# Tushare和Redis配置
TUSHARE_TOKEN = "YOUR TOKEN"
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_QUEUE_NAME = "tushare_data_queue"

# 初始化连接
pro = ts.pro_api(TUSHARE_TOKEN)
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)

def fetch_ohlc_daily_data(ts_code: str, start_date: str, end_date: str) -> list:
    try:
        df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
        return df[['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'vol']].to_dict('records')
    except Exception as e:
        print(f"OHLC数据获取失败: {str(e)}")
        return []

def fetch_adj_factor(ts_code: str, start_date: str, end_date: str) -> list:
    try:
        df = pro.adj_factor(ts_code=ts_code, start_date=start_date, end_date=end_date)
        return df[['ts_code', 'trade_date', 'adj_factor']].to_dict('records')
    except Exception as e:
        print(f"复权因子获取失败: {str(e)}")
        return []
def produce_data(ts_code_list: list, date_range: tuple):
    start_date, end_date = date_range
    for ts_code in ts_code_list:
        data_package = {
            "timestamp": datetime.now().isoformat(),
            "ts_code": ts_code,
            "ohlc_data": fetch_ohlc_daily_data(ts_code, start_date, end_date),
            "adj_factor": fetch_adj_factor(ts_code, start_date, end_date)
        }
        redis_client.lpush(REDIS_QUEUE_NAME, json.dumps(data_package))
        print(f"已推送数据: {ts_code} - {start_date}{end_date}")

if __name__ == "__main__":
    STOCK_CODES = ["000001.SZ", "600519.SH"]
    DATE_RANGE = ("20230101", "20231231")
    produce_data(STOCK_CODES, DATE_RANGE)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""data_consumer.py"""
import redis
import json
from clickhouse_driver import Client
from datetime import datetime

# 配置参数
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_QUEUE_NAME = "tushare_data_queue"

CLICKHOUSE_HOST = "localhost"
CLICKHOUSE_PORT = 9000
CLICKHOUSE_DB = "default"

# 初始化 Redis 和 ClickHouse 客户端
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
clickhouse_client = Client(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT, database=CLICKHOUSE_DB)

def insert_to_clickhouse(data):
    """将数据插入到 ClickHouse"""
    query = """
    INSERT INTO daily_data (ts_code, trade_date, open, high, low, close, vol, adj_factor)
    VALUES
    """
    values = []
    for record in data["ohlc_data"]:
        adj_factor_record = next((adj for adj in data["adj_factor"] if adj["trade_date"] == record["trade_date"]), None)
        adj_factor = adj_factor_record["adj_factor"] if adj_factor_record else None
        values.append((
            record["ts_code"],
            datetime.strptime(record["trade_date"], "%Y%m%d").date(),
            record["open"],
            record["high"],
            record["low"],
            record["close"],
            record["vol"],
            adj_factor
        ))
    clickhouse_client.execute(query, values)

def consume_data():
    """数据消费主函数"""
    print("启动数据消费者,等待队列数据...")
    while True:
        try:
            # 阻塞式获取队列数据
            result = redis_client.brpop(REDIS_QUEUE_NAME, timeout=1)
            if result is None:
                # 如果没有获取到数据,说明队列为空,退出循环
                print("Redis 队列为空,停止消费数据。")
                break
            _, json_data = result
            data_package = json.loads(json_data)
            insert_to_clickhouse(data_package)
            print(f"成功插入数据: {len(data_package['ohlc_data'])} 条")
        except Exception as e:
            print(f"数据处理异常: {str(e)}")
            continue

if __name__ == "__main__":
    consume_data()

2. 获取分钟线数据

同上述步骤,我需要获取分钟线数据,并添加到数据表中。

007 给我提供了以下建议: 1. 首先创建分钟线数据表; 2. 修改生产者代码,添加分钟线数据获取功能; 3. 创建对应的消费者代码。

2.1. 创建分钟线数据表

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
CREATE TABLE IF NOT EXISTS minute_data (
    ts_code String,
    trade_time DateTime,
    open Float32,
    high Float32,
    low Float32,
    close Float32,
    vol Float32,
    amount Float32
) ENGINE = MergeTree()
ORDER BY (ts_code, trade_time);

2.2. 修改生产者代码,添加分钟线数据获取功能

007 还是采用 tushare 来获取分钟线数据,并添加到数据表中。但是,这里获取分钟线数据我并不打算用 tushare 作为数据源,而是打算采用 qmt 提供的 API 接口来获取分钟级的数据。

007 添加了一个新的生产者代码文件 minute_producer.py,在 data_producer.py 的基础上主要修改内容: - 替换了 Tushare 为 QMT 的数据接口 - 添加了数据下载功能 - 使用 QMT 的 get_local_data 获取分钟线数据 - 增加了交易日期的处理

但是我在此处遇到了一个问题,qmt 目前只支持 windows 系统,而我的系统是 macos ,所以无法使用 qmt。

007 提供了一个解决方案,使用 redis 作为中间件,把 windows 的数据传递给 macos 的程序,并最终存储在 clickhouse 中。

依据 007 的好建议,我完成了以下代码。

2.2.3. windows 的数据生产者代码

007 原先的代码如下:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
import redis
import json
from datetime import datetime
from xtquant.xtdata import (
    init,
    download_history_data,
    get_local_data,
    get_trading_dates,
    close
)

# Redis配置 - 使用Mac的IP地址
REDIS_HOST = "替换为Mac的IP地址"
REDIS_PORT = 6379
REDIS_QUEUE_NAME = "qmt_minute_queue"
REDIS_PASSWORD = None  # 如果有密码,请设置

def setup_redis_client():
    """初始化Redis客户端"""
    return redis.StrictRedis(
        host=REDIS_HOST, 
        port=REDIS_PORT, 
        password=REDIS_PASSWORD,
        decode_responses=True
    )

def fetch_minute_data(stock_code, date_str):
    """获取指定日期的分钟线数据"""
    try:
        # 获取分钟线数据
        df = get_local_data(stock_code, 'min1', date_str, date_str)
        if df is None or len(df) == 0:
            return []

        # 转换数据格式
        records = []
        for time, row in df.iterrows():
            records.append({
                "ts_code": stock_code,
                "trade_time": time.strftime("%Y-%m-%d %H:%M:%S"),
                "open": float(row['open']),
                "high": float(row['high']),
                "low": float(row['low']),
                "close": float(row['close']),
                "vol": float(row['volume']),
                "amount": float(row['amount'])
            })
        return records
    except Exception as e:
        print(f"获取分钟线数据失败 {stock_code} {date_str}: {str(e)}")
        return []

def main():
    """主函数"""
    # 初始化QMT接口
    init()

    # 初始化Redis客户端
    redis_client = setup_redis_client()

    try:
        # 配置参数
        stock_list = ["000001.SZ", "600519.SH"]
        start_date = "20230101"
        end_date = "20230131"

        # 下载历史数据
        print(f"开始下载历史数据: {start_date}{end_date}")
        download_history_data(stock_list, 'min1', start_date, end_date)
        print("历史数据下载完成")

        # 获取交易日列表
        trading_dates = get_trading_dates(start_date, end_date)

        # 按日期和股票代码获取分钟线数据并发送到Redis
        for trade_date in trading_dates:
            date_str = trade_date.strftime("%Y%m%d")
            print(f"处理日期: {date_str}")

            for stock_code in stock_list:
                minute_data = fetch_minute_data(stock_code, date_str)

                if minute_data:
                    # 封装数据
                    data_package = {
                        "timestamp": datetime.now().isoformat(),
                        "ts_code": stock_code,
                        "trade_date": date_str,
                        "minute_data": minute_data
                    }

                    # 发送到Redis
                    redis_client.lpush(REDIS_QUEUE_NAME, json.dump(data_package))
                    print(f"已推送分钟线数据: {stock_code} - {date_str} ({len(minute_data)}条)")

    except Exception as e:
        print(f"程序执行异常: {str(e)}")

    finally:
        # 关闭QMT接口
        close()
        print("程序执行完毕")

if __name__ == "__main__":
    main()

但是这个代码是不可以运行的,因为 qmt 库的版本可能发生了改变,有一些模块是被移除或修改的。而且我们想从 Windows 机器连接到 Mac 上运行的 Redis 服务器。这涉及到网络连接、防火墙设置、Redis 配置以及可能的权限问题。

针对 Redis 的有关问题,我打算继续用 007 的“贴心代码”进行 Redis 连接的测试。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Redis配置 - 使用Mac的IP地址
REDIS_HOST = "替换为Mac的IP地址"
REDIS_PORT = 6379
REDIS_QUEUE_NAME = "qmt_minute_queue"
REDIS_PASSWORD = None  # 如果有密码,请设置


# 测试Redis连接
import redis
import time

try:
    # 创建Redis客户端
    redis_client = redis.Redis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        password=REDIS_PASSWORD,
        socket_timeout=5,
        decode_responses=True
    )

    # 测试连接 - PING命令
    response = redis_client.ping()
    print(f"Redis连接测试 (PING): {'成功' if response else '失败'}")

    # 测试基本操作 - 写入和读取
    test_key = "test_connection_key"
    test_value = f"test_value_{time.time()}"

    # 写入测试
    redis_client.set(test_key, test_value)
    print(f"Redis写入测试: 成功写入键 '{test_key}'")

    # 读取测试
    read_value = redis_client.get(test_key)
    print(f"Redis读取测试: {'成功' if read_value == test_value else '失败'}")
    print(f"写入值: {test_value}")
    print(f"读取值: {read_value}")

    # 测试队列操作
    redis_client.lpush(REDIS_QUEUE_NAME, "测试消息")
    queue_length = redis_client.llen(REDIS_QUEUE_NAME)
    print(f"Redis队列测试: 成功写入队列 '{REDIS_QUEUE_NAME}',当前队列长度: {queue_length}")

    # 清理测试数据
    redis_client.delete(test_key)

    print("Redis连接和基本操作测试完成,连接正常")

except redis.exceptions.ConnectionError as e:
    print(f"Redis连接错误: {str(e)}")
    print("请检查以下问题:")
    print("1. Redis服务器是否在运行")
    print("2. IP地址是否正确")
    print("3. 端口是否正确")
    print("4. 防火墙是否允许连接")
    print("5. Redis是否配置为允许远程连接")

except Exception as e:
    print(f"Redis测试过程中出现其他错误: {str(e)}")

有了测试代码,我们还需要考虑 Redis 的默认配置。默认情况下,Redis 只允许本地连接,绑定到 127.0.0.1。因此,要允许远程连接,必须修改 Redis 的配置文件,将绑定地址改为 0.0.0.0 或者 Mac 的局域网 IP 地址。这一步可能需要用户编辑 Redis 的配置文件,比如 redis.conf,找到 bind 参数并修改。

  1. 修改 Redis 配置文件 默认情况下,Redis 仅监听本地 IP(127.0.0.1),需调整为允许远程连接:

    • 打开配置文件
      1
      sudo nano /usr/local/etc/redis.conf
      
    • 修改以下参数:
    • 绑定 IP:将 bind 127.0.0.1 改为 bind 0.0.0.0(允许所有 IP 访问)或替换为 Mac 的局域网 IP(如 192.168.1.100)。

    • 关闭保护模式:将 protected-mode yes 改为 protected-mode no

    • 设置密码(可选但推荐):取消注释 requirepass 并设置密码:
      1
      requirepass your_password
      
    • 保存并退出:按 Ctrl+O 保存,Ctrl+X 退出。
  2. 重启 Redis 服务(二选一即可)

    1
    2
    brew services restart redis  # 适用于Homebrew安装
    redis-server /usr/local/etc/redis.conf  # 手动重启
    

  3. 开放 Mac 防火墙端口

    • 图形化操作:
    • 进入 系统偏好设置 → 安全性与隐私 → 防火墙。
    • 点击锁图标解锁,选择 防火墙选项。
    • 点击 + 添加 redis-server 到允许列表。
    • 命令行操作(需管理员权限):
      1
      2
      sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add /usr/local/bin/redis-server
      sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblockapp /usr/local/bin/redis-server
      

完成上述步骤之后,我们对 Windows 端进行配置(下载 Redis 客户端) 1. 安装 Redis 客户端 - 下载 Windows 版 Redis:从 Redis 官网 下载稳定版,解压到任意目录(如 C:\redis)。 - 添加到系统路径:将 C:\redis\bin 加入环境变量 PATH,以便在命令行直接使用 redis-cli。

  1. 连接 Redis 服务器
    • 命令格式:
      1
      redis-cli -h <Mac的局域网IP> -p 6379 -a <密码>
      
    • 示例(假设 Mac IP 为 192.168.1.100,密码为 your_redis_password):
      1
      redis-cli -h 192.168.1.100 -p 6379 -a your_redis_password
      
    • 验证连接:
      1
      2
      192.168.1.100:6379> PING
      PONG  # 连接成功
      

上面的步骤依旧不能确保 Redis 服务器在 Windows 上可以接受来自 Mac 的连接。(这里我们先留一个坑如果朋友们有想法也可以在评论区留下你们的解决方案)

2.2.4. macos 的数据消费者代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
 import redis
 import json
 from clickhouse_driver import Client
 from datetime import datetime
 import time

 # Redis配置
 REDIS_HOST = "localhost"  # 本地Redis或Windows的IP
 REDIS_PORT = 6379
 REDIS_QUEUE_NAME = "qmt_minute_queue"
 REDIS_PASSWORD = None  # 如果有密码,请设置

 # ClickHouse配置
 CLICKHOUSE_HOST = "localhost"
 CLICKHOUSE_PORT = 9000
 CLICKHOUSE_DB = "default"
 CLICKHOUSE_USER = "default"
 CLICKHOUSE_PASSWORD = ""

 def setup_redis_client():
     """初始化Redis客户端"""
     return redis.StrictRedis(
         host=REDIS_HOST, 
         port=REDIS_PORT, 
         password=REDIS_PASSWORD,
         decode_responses=True
     )

 def setup_clickhouse_client():
     """初始化ClickHouse客户端"""
     return Client(
         host=CLICKHOUSE_HOST,
         port=CLICKHOUSE_PORT,
         database=CLICKHOUSE_DB,
         user=CLICKHOUSE_USER,
         password=CLICKHOUSE_PASSWORD
     )

 def insert_to_clickhouse(client, data):
     """将分钟线数据插入到ClickHouse"""
     if not data["minute_data"]:
         print("没有数据需要插入")
         return 0

     query = """
     INSERT INTO minute_data 
     (ts_code, trade_time, open, high, low, close, vol, amount)
     VALUES
     """

     values = []
     for record in data["minute_data"]:
         values.append((
             record["ts_code"],
             datetime.strptime(record["trade_time"], "%Y-%m-%d %H:%M:%S"),
             record["open"],
             record["high"],
             record["low"],
             record["close"],
             record["vol"],
             record["amount"]
         ))

     if values:
         client.execute(query, values)
         return len(values)
     return 0

 def main():
     """主函数"""
     # 初始化客户端
     redis_client = setup_redis_client()
     clickhouse_client = setup_clickhouse_client()

     print("启动分钟线数据消费者,等待队列数据...")

     try:
         while True:
             # 尝试从Redis获取数据
             result = redis_client.brpop(REDIS_QUEUE_NAME, timeout=1)

             if result is None:
                 print("Redis队列为空,等待新数据...")
                 time.sleep(5)  # 等待5秒再次尝试
                 continue

             # 解析数据
             _, json_data = result
             data_package = json.loads(json_data)

             # 插入ClickHouse
             inserted_count = insert_to_clickhouse(clickhouse_client, data_package)
             print(f"成功插入分钟线数据: {data_package['ts_code']} - {data_package['trade_date']} ({inserted_count}条)")

     except KeyboardInterrupt:
         print("程序被手动中断")

     except Exception as e:
         print(f"程序执行异常: {str(e)}")

     finally:
         print("程序执行完毕")

 if __name__ == "__main__":
     main()

由于windows和mac之间的redis连接问题还尚未解决,我打算先把从 qmt 去到的分钟线数据存入 000001.SH_data.csv300750.SZ_data.csv,将其中的数据存入Clickhouse数据库中,并使用Python进行数据处理。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from xtquant import xtdata
import os
import pandas as pd

code_list = ['000001.SH', '300750.SZ']
period = '1h'
start_time = '20250101093000'
end_time = '20250201093000'

def on_data(datas):
    if datas:
        print(datas)
    else:
        print("数据下载失败或为空")

xtdata.download_history_data2(code_list, period, start_time, end_time, on_data)

# 创建目录(如果不存在)
save_dir = 'C:\\wbq'
os.makedirs(save_dir, exist_ok=True)

for code in code_list:
    data = xtdata.get_market_data_ex([], [code], period, start_time, end_time)
    if code in data and not data[code].empty:
        # 为每个股票创建单独的文件
        file_path = os.path.join(save_dir, f'{code}_data.csv')
        # 确保数据是DataFrame格式
        df = data[code]
        # 保存数据,添加错误处理
        try:
            df.to_csv(file_path)
            print(f'{code}数据保存到本地: {file_path}')
        except Exception as e:
            print(f'保存{code}数据时出错: {str(e)}')
    else:
        print(f'{code}没有获取到数据')

print("运行结束")

21天驯化AI打工仔 - 开发量化交易系统

今天是第二天,我计划实现如下任务: 1. 安装 ClickHouse 和 DBeaver 2. 创建 ClickHouse 数据表 3. 修改 Redis 消息队列代码以支持 ClickHouse 存储

我唤醒了 007,它今天是要陪我一起战斗代码的。


1. 安装 ClickHouse 和 DBeaver

我使用的是 mac 系统,电脑里还没有安装 ClickHouse 和 DBeaver,所以我首先需要安装它们。

1.1. 安装 ClickHouse

我:007!给我一个ClickHouse mac 系统的安装教程。

007:收到!以下是您需要的 ClickHouse 的安装教程!

ClickHouse 是一个用于联机分析(OLAP)的列式数据库管理系统(DBMS),最早由俄罗斯 IT 公司 Yandex 为 Yandex.Metrica 网络分析服务开发,于 2016 年基于宽松式的 Apache2.0 许可证发布。

于是,我使唤 007 给我一个 Mac 系统安装 ClickHouse 的教程,它对此表示非常容易,于是,我按照下面的步骤进行安装。

  1. 下载 ClickHouse:前往 ClickHouse 官方网站 下载适用于 Mac 系统的安装包。也可以使用 curl 实现快速安装: curl https://clickhouse.com/ | sh
  2. 可以通过 Homebrew 安装 ClickHouse: brew install clickhouse
  3. 启动服务器:打开终端,导航到解压后的目录,执行 ./clickhouse - server 命令启动 ClickHouse 服务器。若要在后台启动服务器,可使用 ./clickhouse - server --daemon 命令。服务器启动后,会监听默认端口 9000。
  4. 启动客户端:打开终端,导航到解压后的目录,执行 ./clickhouse - client 命令启动 ClickHouse 客户端。若要在后台启动客户端,可使用 ./clickhouse - client --daemon 命令。客户端启动后,你可以使用 SQL 语句与 ClickHouse 进行交互。

1.2. 安装 DBeaver

为了方便管理 ClickHouse 数据库,我计划安装 DBeaver。

我:007!再给我一个DBeaver mac 系统的安装教程。

007:收到!以下是您需要的 DBeaver 的安装教程!

DBeaver 是一个通用数据库管理工具和 SQL 客户端,支持多种数据库,其中包括 ClickHouse。通过 DBeaver,用户可以使用 JDBC 驱动与 ClickHouse 建立连接,从而实现对 ClickHouse 数据库的访问和操作。用户只需在 DBeaver 中配置 ClickHouse 的连接参数,如主机地址、端口、用户名和密码等,就可以方便地连接到 ClickHouse 数据库,进行各种数据库操作。

  1. 安装 DBeaver 并连接 ClickHouse
  2. 下载并安装 DBeaver。
  3. 打开 DBeaver,点击 Database > New Database Connection
  4. 搜索并选择 ClickHouse,填写连接信息(默认端口为 9000)。
  5. 测试连接,确保成功。

2. 创建 ClickHouse 数据表

我:007!给我一个 ClickHouse 表,用于存储 OHLC 数据。

007:收到!在 DBeaver 中,执行以下 SQL 创建存储 OHLC 数据的表。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
CREATE TABLE ohlc_data (
    ts_code String,
    trade_date Date,
    open Float32,
    high Float32,
    low Float32,
    close Float32,
    vol Float32
) ENGINE = MergeTree()
ORDER BY (ts_code, trade_date);

按照 007 的指示,我创建了一个名为 ohlc_data 的表,用于存储 OHLC 数据。

3. 修改 Redis 消息队列代码以支持 ClickHouse 存储

我现在安装好了所需要的数据库软件,在第一天到基础上,我需要更新 data_consumer.py 代码,将数据存储到 ClickHouse 数据库。007 听令!

007 说这个任务它承包了,于是给了我第二版的data_consumer.py 代码。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import redis
import json
from clickhouse_driver import Client

# 配置参数
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_QUEUE_NAME = "tushare_data_queue"

CLICKHOUSE_HOST = "localhost"
CLICKHOUSE_PORT = 9000
CLICKHOUSE_DB = "default"

# 初始化 Redis 和 ClickHouse 客户端
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
clickhouse_client = Client(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT, database=CLICKHOUSE_DB)

def insert_to_clickhouse(data):
    """将数据插入到 ClickHouse"""
    query = """
    INSERT INTO ohlc_data (ts_code, trade_date, open, high, low, close, vol)
    VALUES
    """
    values = [
        (
            record["ts_code"],
            record["trade_date"],
            record["open"],
            record["high"],
            record["low"],
            record["close"],
            record["vol"]
        )
        for record in data
    ]
    clickhouse_client.execute(query, values)

def consume_data():
    """数据消费主函数"""
    print("启动数据消费者,等待队列数据...")
    while True:
        try:
            # 阻塞式获取队列数据
            _, json_data = redis_client.brpop(REDIS_QUEUE_NAME, timeout=30)
            if json_data:
                data_package = json.loads(json_data)
                ohlc_data = data_package["ohlc_data"]
                insert_to_clickhouse(ohlc_data)
                print(f"成功插入数据: {len(ohlc_data)} 条")
        except Exception as e:
            print(f"数据处理异常: {str(e)}")
            continue

if __name__ == "__main__":
    consume_data()

我尝试运行 data_consumer.py,但程序报错如下:

1
2
3
4
(course) quantide@Mac note % python data_consumer.py 
启动数据消费者,等待队列数据... 
数据处理异常: 'str' object has no attribute 'year'
数据处理异常: 'str' object has no attribute 'year'

007(委屈但诚恳):根据错误信息 数据处理异常: 'str' object has no attribute 'year',问题可能出在 trade_date 字段的处理上。trade_date 是一个字符串(例如 "20230101"),而 ClickHouse 需要 Date 类型的数据。在插入数据时,ClickHouse 的 Date 类型需要一个有效的日期格式(如 YYYY-MM-DD)。

我们需要将 trade_date 从字符串格式转换为 YYYY-MM-DD 格式:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from datetime import datetime

# ...  源代码见上一章节。

def insert_to_clickhouse(data):
    """将数据插入到 ClickHouse"""
    query = """
    INSERT INTO ohlc_data (ts_code, trade_date, open, high, low, close, vol)
    VALUES
    """
    values = [
        (
            record["ts_code"],
            # 将 trade_date 字符串转换为日期类型
            datetime.strptime(record["trade_date"], "%Y%m%d").date(),
            record["open"],
            record["high"],
            record["low"],
            record["close"],
            record["vol"]
        )
        for record in data
    ]
    clickhouse_client.execute(query, values)

在此基础上,我希望让 data_consumer.py 在将 Redis 里的数据传送到 ClickHouse 之后停止运行,可在 Redis 队列没有数据时,终止循环。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
# ...  源代码见上一章节。

def consume_data():
    """数据消费主函数"""
    print("启动数据消费者,等待队列数据...")
    while True:
        try:
            # 阻塞式获取队列数据
            result = redis_client.brpop(REDIS_QUEUE_NAME, timeout=1)
            if result is None:
                # 如果没有获取到数据,说明队列为空,退出循环
                print("Redis 队列为空,停止消费数据。")
                break
            _, json_data = result
            data_package = json.loads(json_data)
            ohlc_data = data_package["ohlc_data"]
            insert_to_clickhouse(ohlc_data)
            print(f"成功插入数据: {len(ohlc_data)} 条")
        except Exception as e:
            print(f"数据处理异常: {str(e)}")
            continue

007 在修改后,还“贴心”地给出了代码解释: - brpop 函数的 timeout 参数:把 brpop 函数的 timeout 参数设定为 1 秒,这意味着如果在 1 秒内没有从队列中获取到数据,brpop 函数就会返回 None。 - 队列为空的判断:在 while 循环里,检查 brpop 函数的返回值。若返回值为 None,就表明队列是空的,此时使用 break 语句退出循环,从而停止程序的运行。

这样,当 Redis 队列中的数据全部被处理完之后,程序就会自动停止。

经过上述修改,data_consumer.py 现在可以成功将数据从 Redis 队列传送到 ClickHouse 数据库。

21天驯化AI打工仔 - 我如何获取量化数据

IDEA:本人和本人的 AI黑奴 的相互协作,能不能在短短 21 天内开发出一套量化交易系统?

这么有意思的挑战,不如就从今天开始吧!“数据是一切开始的基础”,我打算先安排 AI黑奴 从数据获取开始做起。(感觉叫 AI黑奴 不太好听,那就给它取个名字叫:007号打码机,希望007号“牛码”可以“码力全开”)

好!下面我们正式准备开发工作!


Redis 启动!

我希望开发一个量化交易框架,并且使用 Redis 构建消息队列,从 Tushare 获取 OHLC 数据和复权因子。

首先,考虑代码的结构。需要分为两个主要部分:数据获取生产者和消费者处理数据。生产者从 Tushare 获取数据,然后将数据放入 Redis 队列;消费者从队列中取出数据进行处理。需要处理 Tushare 的 API 调用,包括获取 OHLC 和复权因子,这里可能需要使用 tushare 的 pro_api,所以需要用户提供 token。

接下来,Redis 部分,使用 redis-py 库,生产者使用 lpush 将数据推入队列,消费者使用 brpop 阻塞式获取数据。需要注意异常处理,比如 Tushare 的 API 调用可能失败,Redis 连接问题等,所以需要添加 try-except 块。

然后,考虑代码的独立性和可运行性。需要导入必要的库,如 redis、tushare、json 等。生产者函数中,获取数据后转换为 JSON 格式再推入队列。消费者函数中,取出数据后解析并打印,这里可以扩展为实际的处理逻辑,比如存储到数据库或进一步计算。

还要注意配置部分,将 Redis 的连接参数和 Tushare 的 token 作为配置项,方便用户修改。代码中使用注释说明需要用户替换的部分,比如 tushare_token 和 redis 的 host、port 等。

最后,确保 doubaocanvas 的格式正确,identifier 使用 “quant-data-redis-queue”,type 为 general_code,title 描述清楚。两个代码部分分别作为生产者和消费者的脚本,path 分别为 producer.py 和 consumer.py,description 准确。检查是否有未定义的变量,确保代码完整,可以独立运行,用户只需替换配置即可使用。

我的007表示这很简单,瞬间根据我的要求,给了两份代码文件:data_producer.pydata_consumer.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""data_consumer.py"""
import redis
import json
from typing import Dict

# 配置参数(需与生产者保持一致)
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_QUEUE_NAME = "tushare_data_queue"

redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)

def process_data(data_package: Dict):
    """数据处理逻辑(可扩展为实际业务处理)"""
    print(f"接收到数据时间: {data_package['timestamp']}")
    print(f"处理股票代码: {data_package['ts_code']}")
    print(f"OHLC数据条目数: {len(data_package['ohlc_data'])}")
    print(f"复权因子条目数: {len(data_package['adj_factor'])}")
    # 此处可添加数据存储/计算/分发等实际业务逻辑

def consume_data():
    """数据消费主函数(阻塞式获取)"""
    print("启动数据消费者,等待队列数据...")
    while True:
        try:
            # 阻塞式获取队列数据(超时时间30秒)
            _, json_data = redis_client.brpop(REDIS_QUEUE_NAME, timeout=30)
            data_package = json.loads(json_data)
            process_data(data_package)

        except Exception as e:
            print(f"数据处理异常: {str(e)}")
            continue

if __name__ == "__main__":
    consume_data()
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""data_producer.py"""
import redis
import tushare as ts
import json
from datetime import datetime

# 配置参数
TUSHARE_TOKEN = "87a5a442931d15e749b4453b4f802c3b87794c3a5968b7b40a153edd"  # 替换为你的Tushare token
REDIS_HOST = "localhost"
REDIS_PORT = 6379
REDIS_QUEUE_NAME = "tushare_data_queue"

# 初始化连接
pro = ts.pro_api(TUSHARE_TOKEN)
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)

def fetch_ohlc_data(ts_code: str, start_date: str, end_date: str) -> list:
    """获取OHLC数据"""
    try:
        df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
        return df[['ts_code', 'trade_date', 'open', 'high', 'low', 'close', 'vol']].to_dict('records')
    except Exception as e:
        print(f"OHLC数据获取失败: {str(e)}")
        return []

def fetch_adj_factor(ts_code: str, start_date: str, end_date: str) -> list:
    """获取复权因子"""
    try:
        df = pro.adj_factor(ts_code=ts_code, start_date=start_date, end_date=end_date)
        return df[['ts_code', 'trade_date', 'adj_factor']].to_dict('records')
    except Exception as e:
        print(f"复权因子获取失败: {str(e)}")
        return []

def produce_data(ts_code_list: list, date_range: tuple):
    """数据生产主函数"""
    start_date, end_date = date_range
    for ts_code in ts_code_list:
        # 封装数据结构
        data_package = {
            "timestamp": datetime.now().isoformat(),
            "ts_code": ts_code,
            "ohlc_data": fetch_ohlc_data(ts_code, start_date, end_date),
            "adj_factor": fetch_adj_factor(ts_code, start_date, end_date)
        }

        # 推入Redis队列
        redis_client.lpush(REDIS_QUEUE_NAME, json.dumps(data_package))
        print(f"已推送数据: {ts_code} - {start_date}{end_date}")

if __name__ == "__main__":
    # 示例参数
    STOCK_CODES = ["000001.SZ", "600519.SH"]  # 沪深300成分股示例
    DATE_RANGE = ("20230101", "20231231")   # 年度数据范围

    produce_data(STOCK_CODES, DATE_RANGE)

运行了两个文件没有反应,原来是这个电脑还没有下载并启动 Redis 。然后,我快速手动下载并启动,并夸赞了 007 贴心给我写了一段检查 Redis 是否正常启动的代码。

此时,运行 007 的“贴心代码”,就显示成功连接到 Redis:

为了直观地查看和管理 Redis 中的数据,我安装了 RedisInsight

Tushare 获取数据

完成了 Redis 的连接,直接来运行 007 提供的 data_producer.py,成功在 RedisInsight 中看到了数据:

再运行 data_consumer.py,成功从消息队列中取出数据。007 还为我讲解了一下 producer 和 consumer 之间的关系,今天又是依赖 AI 打工仔的一天呢~

试过 Cursor 和 Trae 之后,我如何用 Augment 完成了一个复杂项目

常常有人问,真有人用 AI 完成过一个复杂的项目吗?

我!

在这个过程中,我感受到 Augment (也许不只是 Augment,而是 AI 辅助编程)强大的力量。它帮我省下很多个小时。如果你是一位秀发飘逸的美女程序员,你更是应该用它 -- 它指定能保住你的头发 -- 不过这一点对我来说已经无关紧要了。


为了完成匡醍量化课程的学员注册问题,最近我尝试了Cursor和 Trae,最终使用 Augment 来完成了这个项目。当然,我不想在这里争论类似于 PHP 是不是最好的编程语言的问题;所以,如果你对 Augment 不感兴趣,那也请多尝试其它 AI 编程工具!

今天这个令人疲倦、但最终峰回路转的下午,是促使我再加班一晚上写这篇文章的原因。Augment 的 Agent Auto 模式成为 Man of the Moment, 最终它把我拉出了泥潭,使得我可以赶上项目部署的 deadline,并且还可能有时间水一篇文章。

在我最近招聘员工的时候,我会抛出这样一张图,让候选人回答,你能从这张图中,得到哪些信息?

如果你应聘职位是 HR,那么很显然你会优先联系当前在线、新牛人这位,同时也要考虑职位匹配程度。但这是一道考察观察和归纳能力的试题:这些小卡片提供了哪些类别的信息?

在工作中的无数场合,我们都需要这样的能力。比如,如果你在做小红书运营,你常常需要快速从两个相似但浏览量迥异的笔记中,找出流量密码,就需要这样的眼力。

其实这样的能力程序员也非常需要。今天我自己就被将了一军。最终问题定位出来,表现为下图的差异:

这是我们两门课程的课程目录。它是用 JupyterLab 构建的一个课程环境,当用户点击其中一堂课程时,根据他的权限和购买时长,链接可能打得开也可能打不开。现在的问题是,左边的该打得开的,全都能打开;右边的则全部打不开。

当我们找到最终问题后,把两个目录并排放在一起的时候,可能差异会一目了然。但在编程环境下,不能访问的链接,是以这样的方式出现的:

不能访问

明明文件就在那儿,没有任何错误,却不能访问。而能访问的链接,还会出一些奇怪的错误来干扰你。

要了解这个问题为什么复杂,我们得先介绍下系统的架构和技术栈。这也是为什么说我们用 Augment 构建了一个复杂项目的原因。通过对比我们的项目,你就会知道,在现在的 AI 能力下,你能构建的项目,至少能复杂到什么程度,这对你决策多大程度投入到AI中会有一定帮助。

系统架构

这是一个简化了的系统架构图。限于 AI 生成 mermaid 图的能力,它跟我们实际的架构有一些区别:

  1. nginx 与 course container 在同一个 docker 网络。nginx 暴露了主机端口。
  2. nginx 要能访问主机上的 provision server. Provisioin server 只部署在主机中,因为它需要为新学员创建容器。
  3. 系统中有很多 course container。Nginx 收到浏览器传来的类似 courseware/01.ipynb 这样的请求之后,需要动态地分发到每个学员对应的 course container。该 container 运行在 JupyterLab 的服务。
  4. 使用了前后端分离式设计,并且有两个前端 SPA,一个由 nginx host,另一个由 provision server 来 host。这会存在两个 front 的目录,而且要和后端 (python) 放在同一个项目中,这种架构并不常见,也为项目开发增加了许多难度。
  5. 实际部署中还涉及更复杂的云上网络,细节就不透露了。

为了限制未登录用户的访问,每一个请求,都要由 nginx 发给 provision server 来认证。所以,有时候出现前面的不能访问的图,会是正常行为。

这个系统的技术栈和技术要求如下:

  1. 使用 vue3 和 vite 来构建前端。前端尽可能做到响应式。
  2. 由于创建容器的时间可能较长,前端的 admin 界面需要与后端有 web socket 通信,把后端容器的构建状态实时传递给前端。这里还用到了多线程。
  3. 数据库使用文件数据库就好。一般应该选择 sqlite,不过我更熟悉和喜欢 postgreSQL 的语法,因此使用 duckdb。
  4. Provision 服务器使用了 Blacksheep 来构建,它性能卓越,同时接口很人性化。但是,社区没有 flask 等成熟。最终这个选择让我付出一定的代价。大概有一天多时间用在解决 blacksheep 自身面临的如何 host SPA 程序的问题。这里还有一个有趣的小故事。
  5. Blacksheep 不能独立运行,它必须借助 uvicorn 来启动。uvicorn 也很不成熟,它贡献了一个小 bug,AI 也不清楚如何解决,查了 github 才找到解决方案。这额外耽误了一天时间。
  6. 使用了 nginx 的 auth_request 模块,来认证用户。最初还尝试了 openlightspeed,最后撤退到社区更成熟的 nginx 上来。
  7. 我用的是 mac 开发,最终部署到 ubuntu 上。在 mac 下,我使用 orbstack 来运行容器。它和 nginx 官方容器一起,贡献了一个日志目录映射的 bug。

你们会觉得这个系统复杂吗?我觉得可以算是。因为它用了两种编程语言,还涉及到 docker 和 nginx 一些改写规则,所以,实际上涉及到三名工程师:前端工程师 (Vuejs)、后端工程师 (Python) 和运维工程师的活。

反正,我花 299 报名学的 AI 编程是这样的:

完成这样的项目,只要一个前端就可以了。而且它只有两种固定的界面版式,也没有考虑到PC端。

令人疲惫的下午

今天下午,我正在为新的课程系统上线,做最后的测试。这个系统本来应该像这样工作,这是添加客户:

朴素的后台

在增加用户之后,就可以为他注册课程,生成专属环境:

创建容器

然后,学员就可以在浏览器中登录:

登录

登录之后,就可以看到他所报名的课程:

用户点击粉红色的课程链接,就应该看到该课的内容。这个课程的链接是:

1
http://*/course/l24/quantide/lab/tree/courseware/01.ipynb

点击之后,本来应该能打开 01.ipynb 这个 notebook,但实际上返回了不能访问的错误。此时会进入到 jupyter lab 的 home 界面:

home

此时清楚地看到所有的 notebook 都已经列出来了,但是当你点击它们时,仍然是不能访问。

由于这个问题只在《量化 24 课》中存在,在《因子分析与机器学习策略》课程中不存在,所以,我肯定这只是一个部署和配置问题,于是,就只开启了 Augment 的 chat 模式来帮我。

它指导了我一下午。我们查过两个容器的 entrypoint 脚本是否一样,目录映射语法是否一样,jupyter_lab_config 是否一样,容器的环境变量是否一样,Augment 还多次兴奋地大喊,我找到问题了!

在无数次修改、重启和回滚之后,我想明白一件事,作为出题人,如果我自己来看这张图:

我就能一眼看出上下两个卡片的所有区别吗?我会不会进入到某种常识性的盲区?比如,124 和 l24 用肉眼是很难看出来的,天知道有多少这种细微的差别!

于是,我决定启用 Augment 的 Agent (Auto) 模式。实际上这些天的大多数时间,我是在用 Agent Auto 来帮我 coding 的。

Tip

Augment 有三种模式,chat, agent 和 agent auto 模式。Chat 模式一样能读全部代码,写代码(和修改),但它不能调用工具;Agent 模式可以做 Chat 能做的事,并且可以调用工具,但在调用工具之前,需要得到你的批准;Agent Auto 模式则非常高效,它会自动调用工具,直到最后告诉你,很好!我已经完成了这个任务。

但是,我不想让 Agent Auto 在最后关头给我任何惊喜,所以,我告诉他,

Attention

在我们正在工作的项目中,存在这样一个问题。当我点击 academy > 我的课程面板 > 24 课的链接时出现无法访问;但 fa 课程的链接就可以访问。两者都是经过同一个 nginx 容器转发;nginx 后面的容器不一样,一个是 course_l24_quantide,一个是 course_fa_quantide.

我启用 agent 模式,不是让你进行改动,而是让你可以直接运行一些命令来查看文件和(包括容器)状态。

现在,请着手调查这个问题。

Aha 时刻

Agent 以常规的姿势开始,先是响应我的问题。但我的问题说得比较笼统,所以,它的响应也很简单:

1
I'll help you investigate why the 24 课 links aren't working while the FA course links are working. Let's start by gathering information about the configuration and status of the relevant components.

老实说,有点没太指望了。

接下来,它阅读了 containers.yaml -- 这是用来创建容器的配置文件。

然后它调用 docker ps 检查哪些容器正在运行。它了解到 nginx 容器很重要,于是,把 nginx 看了个底掉:

1
2
3
4
docker exec nginx ls -la /etc/nginx/
docker exec nginx cat /etc/nginx/auth.conf
docker exec nginx cat /var/log/nginx/error.log | tail -n 50
docker exec nginx cat /var/log/nginx/access.log | tail -n 20

在调查阶段,它没有出任何结论,而是接着看另外两个容器:

1
2
docker exec course_l24_quantide ls -la /home/quantide/notebooks/courseware
docker exec course_fa_quantide ls -la /home/quantide/notebooks/courseware

它还调查了这几个容器是否都连接在同一个 course 网络:

1
docker network inspect course

然后再检查这两个容器中,jupyter 是否在运行(这个动作有点没有必要,如果是人工来做,这样做会累死驴)

1
2
docker exec course_l24_quantide ps aux | grep jupyter
docker exec course_fa_quantide ps aux | grep jupyter

然后它检查了 notebook 文件名:

1
2
docker exec course_l24_quantide find /home/quantide/notebooks/courseware -name "*.ipynb" | head -5
docker exec course_fa_quantide find /home/quantide/notebooks/courseware -name "*.ipynb" | head -5

这些都平平无奇。接下来就是比较智能的地方了,它开始怀疑 url 与文件名不匹配!

1
curl -s http://*/api/academy/home/1 | python -m json.tool

运行的这个命令我是有点看不懂了,果然是编程奇才。

然后又看了一遍文件名:

1
docker exec course_l24_quantide find /home/quantide/notebooks/courseware -name "01.ipynb"

然后又回过头去,查看我们提供给用户的链接:

1
curl -s "http://*/api/academy/home/1" | grep -o "http://*/course/l24/quantide/lab/tree/courseware/.*ipy

把这么多信息匹配起来,显然是很费脑的事情,但 Augment 正是在人困马乏之际,发现了问题所在:

Tip

我们提供给用户的链接是 /courseware/01.ipynb,但 01.ipynb 在 24 课中并不存在;在 24 课中,存在的是 lesson01.ipynb。这是一个典型的部署问题。在因子课中,所有的文件名都只有序号,但在 24 课中,源文件画蛇添足地带上了 lesson 的前缀。

作为一个人类,我确实不容易发现这个问题,作为课程的开发者,对我来说,01.ipynb 和 lesson01.ipynb都是第一课。这也是 AI 给我上的又一课。

结论

我听到过很多关于AI编程能力很弱的问题。但用过Augment之后,我觉得事实并非如此,懂得如何运用它更为重要。

实际上,我是从2023年7月的Github Copilot用起的。当时的copilot可以很好地完成文字润色和编造单元测试数据的任务,它记得很多经典算法的精巧代码结构。后来较长时间使用通义灵码,短暂使用过 Trae (国际版)和Cursor。同样是使用 Claude 3.7模型,但我感觉Augment的能力是最强的(Augment团队认为他们用的是Claude 3.7+O1的一个集成训练版,并不是简单的Claude 3.7)。另外,则于在推广期,所以使用人数少也是一大优势 -- 它的响应速度比较快。

今天我们揭示了它的一个用法,即限制住Augment的Agent,不让它生成代码,而是让它像一个资深专家一样,深入到各个子系统,在主机与dockder容器网络之间来回穿梭,去排查一些很微不足道,但如果你不掌握系统的全部运行原理,就无法排查的问题。

如果你认为我们的provision系统有一定复杂度和实用性,也苦恼过Augment或者其它AI工具无法构建复杂应用,可以给我留言,如果有较多读者认为这些经验值得分享,我就再来写一篇。

Pandas应用案例[2]

“通过将字符串列转换为 category 类型,内存占用可减少 90% 以上;使用 itertuples 替代 iterrows,遍历速度提升 6 倍;结合 Numba 的 JIT 编译,数值计算性能可媲美 C 语言。”


1. Pandas 性能

1.1. 内存优化

使用category类型可以将字符串转换为分类变量,用整数索引代替原始值,这样可以节省内存。例如:把性别这样的重复字符串转成category,内存占用大幅减少。同时,分类类型还能提高某些操作的性能,比如排序和分组,因为内部用的是整数处理,所以可以达到优化的效果。

除此之外,也可以进行数据类型优化,比如将int64转换为更小的类型如int8或者uint8。这里需要强调检查每列的数据范围,选择合适的子类型,比如:如果数值在0到255之间就用uint8。显式指定dtype是重要的,特别是在读取数据时指定类型,避免自动推断导致内存浪费。

1.1.1. Category 类型:分类数据的终极优化方案

​核心原理 - ​内存压缩:将重复的字符串(如性别、地区、产品类别)转换为整数索引,并建立映射字典。例如,将“男/女”存储为 0/1,内存占用减少 ​90%​ 以上。 - ​性能提升:分类数据在分组(groupby)、排序(sort_values)等操作中比字符串快 ​10-100 倍,因为底层使用整数运算。

使用场景 - ​低基数数据:列的唯一值数量远小于总行数(如性别仅有 2 种,但数据量百万级)。


  • ​有序分类:如评分等级(“高/中/低”)或时间段(“早/中/晚”),可指定顺序提升分析效率。

操作方法

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pandas as pd
import numpy as np

# 模拟金融数据:10万条交易记录
dates = pd.date_range('2025-01-01', periods=100000, freq='T')  # 分钟级交易

df = pd.DataFrame({
    'trade_type': np.random.choice(['buy', 'sell', 'cancel'], 
    size=100000),  # 交易类型
    'symbol': np.random.choice(['AAPL', 'MSFT', 'GOOGL', 'TSLA'], 
    size=100000),  # 股票代码
    'client_type': np.random.choice(['retail', 'institution', 'vip'], 
    size=100000),  # 客户类型
    'amount': np.random.uniform(1000, 1e6, size=100000)#交易金额
}, index=dates)

# 优化前内存占用
print("优化前内存:", df.memory_usage(deep=True).sum() / 1024**2, "MB")

# 转换为Category类型
cat_cols = ['trade_type', 'symbol', 'client_type']
df[cat_cols] = df[cat_cols].astype('category')

# 优化后内存对比
print("优化后内存:", df.memory_usage(deep=True).sum() / 1024**2, "MB")

优化前内存: 19.291857719421387 MB

优化后内存: 1.8129425048828125 MB (减少了90.6%)


Tip

定期检查内存使用情况,比如用 memory_usage 方法,来评估优化效果。

​金融场景适用字段: - ​交易类型:如 buy/sell(证券买卖)order_type(限价单/市价单) - ​资产类别:如 stockbondETF - ​客户等级:如 VIP普通机构 - 地域分类:如 CNUSHK(交易市场归属)

当列的唯一值较少且重复较多时,使用category效果最好。例如性别、地区代码等。如果分类变量的类别数量远小于总行数,转换后的内存节省会更明显。注意category类型不适合频繁变更类别的情况,这可能增加计算开销。另外,使用pd.Categorical或者cut函数创建分类数据需要注意处理缺失值的问题,因为category类型不支持NaN,所以在转换前需要处理缺失值。

1.1.2. 紧凑数据类型:精准狙击内存浪费

​数值类型优化 - ​整数类型:根据数值范围选择最小子类型:

1
2
# 检查范围后转换
df['age'] = df['age'].astype('uint8')  # 0-255 范围
- 浮点类型:优先使用 float32(精度足够时),内存减少 ​50%​ .


​布尔类型优化 将仅有 True/False 的列转换为 bool 类型:

1
    df['is_active'] = df['is_active'].astype('bool')

​时间类型优化 使用 datetime64[ns] 而非 object 存储日期,内存减少 ​75%​ 且支持时间序列运算。

金融数据常包含以下高优化价值字段: - ​离散型分类字段:交易类型(buy/sell)、证券代码(AAPL/TSLA)、客户等级(VIP/普通) - ​数值型字段:交易金额(float64)、持仓量(int64)、时间戳(object) - ​状态标识字段:是否盘后交易(True/False)、风险标记(high/medium/low)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
import pandas as pd
import numpy as np

# 生成10万条模拟交易数据
dates = pd.date_range('2025-01-01', periods=100000, freq='T')  # 分钟级时间戳
df = pd.DataFrame({
    'trade_type': np.random.choice(['buy', 'sell', 'cancel'], size=100000),
    'symbol': np.random.choice(['AAPL', 'MSFT', 'GOOGL', 'TSLA'], size=100000),
    'client_level': np.random.choice(['VIP', '普通', '机构'], size=100000),
    'amount': np.random.uniform(1000, 1e6, size=100000),
    'position': np.random.randint(1, 10000, size=100000)
}, index=dates)

print("优化前内存:", df.memory_usage(deep=True).sum() / 1024**2, "MB")

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
# 转换分类类型
cat_cols = ['trade_type', 'symbol', 'client_level']
df[cat_cols] = df[cat_cols].astype('category')

# 查看内存优化效果
print("优化后内存:", df.memory_usage(deep=True).sum() / 1024**2, "MB")

# 压缩数值类型
df['amount'] = df['amount'].astype('float32')  # 金额压缩为32位浮点
df['position'] = df['position'].astype('int16')  # 持仓量压缩为16位整数

# 时间戳优化(假设原始数据为字符串)
df['trade_time'] = pd.to_datetime(df.index)  # 转为datetime64[ns]

# 最终内存对比
print("最终内存:", df.memory_usage(deep=True).sum() / 1024**2, "MB")

优化前内存: 21.358366012573242 MB

优化后内存: 2.575934410095215 MB

最终内存: 2.385199546813965 MB

1.1.3. 高频交易场景综合优化
  1. ​分块读取+类型预定义
    1
    2
    3
    4
    5
    # 读取1GB级交易日志时预定义类型
    dtypes = {
        'symbol': 'category',
        'amount': 'float32',
        'position': 'int16',
    

1
2
3
4
5
6
'trade_type': 'category'
}
chunks = pd.read_csv('trade_log.csv', chunksize=100000, dtype=dtypes)
processed_chunks = [chunk.groupby('symbol')['amount'].sum() 
for chunk in chunks]
final_result = pd.concat(processed_chunks)
  1. ​分组统计加速
    1
    2
    3
    # 按证券代码统计交易量(提速5倍)
    df['symbol'] = df['symbol'].cat.add_categories(['UNKNOWN'])  # 处理新增代码
    trade_volume = df.groupby('symbol', observed=True)['position'].sum()
    
1.1.4. 进阶技巧
  1. ​有序分类(风险等级分析)​

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    from pandas.api.types import CategoricalDtype
    
    # 定义有序风险等级[5](@ref)
    risk_order = CategoricalDtype(
        categories=['low', 'medium', 'high'], 
        ordered=True
    )
    df['risk_level'] = df['risk_level'].astype(risk_order)
    
    # 筛选高风险交易(提速10倍)
       high_risk_trades = df[df['risk_level'] > 'medium']
    

  2. 布尔类型压缩(盘后交易标记)​


1
2
3
4
# 生成盘后交易标记(内存减少87%)[4](@ref)
df['is_after_hours'] = df['trade_time'].apply(
    lambda x: x.hour < 9 or x.hour > 16
).astype('bool')

Warning

  • ​动态类别管理:新增证券代码时需调用 df['symbol'].cat.add_categories(['NVDA'])
  • 数值溢出风险:持仓量若超过 int16 范围(-32768~32767),需改用 int32
  • 时间序列分析:datetime 类型支持高效时间窗口计算(如 .rolling('30T'))

通过上述方法,可在高频交易分析、客户行为画像等场景中实现 ​内存减少80%+分组操作提速5-10倍 的显著优化效果。对于超大规模数据集(如10亿级交易记录),建议结合 Dask 或 Modin 实现分布式计算。

1.2. 优化迭代

使用 itertuples 而不是 iterrows, 使用 apply 来优化迭代,先筛选再计算。itertuples 比 iterrows 快很多,因为 itertuples 返回的是命名元组,而 iterrows 返回的是 Series 对象,这会慢很多。有案例表示使用 iterrows 处理 600 万行数据需要 335 秒,而 itertuples 只需要 41 秒,快了近 6 倍。

1.2.1. 迭代方式性能对比与优化原理
  1. ​itertuples 与 iterrows 性能差异

方法 数据结构 百万行耗时 适用场景 核心优势
​iterrows 生成 (index, Series) 对 85.938s 需要行索引的简单遍历 直观易用
​ itertuples 生成命名元组 7.656s 大规模数据遍历 内存占用减少50%,速度提升6倍
​ apply 向量化函数应用 0.03s 条件逻辑较复杂的行级计算 语法简洁,自动类型优化

Notes

  • iterrows 每次迭代生成 Series 对象,触发内存分配和类型检查(面向对象开销)
  • itertuples 返回轻量级 namedtuple,直接通过属性访问数据(C语言层级优化)
  1. ​apply 函数的优化机制
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    # 示例:计算股票交易费用(佣金率分档)
    def calc_fee(row):
        if row['volume'] > 10000:
            return row['amount'] * 0.0002
        elif row['volume'] > 5000:
            return row['amount'] * 0.0003
        else:
            return row['amount'] * 0.0005
    
    # 优化点:使用 axis=1 按行应用
    df['fee'] = df.apply(calc_fee, axis=1)  # 比循环快3倍
    
1.2.2. 金融数据综合优化案例
  1. 生成模拟高频交易数据
    1
    2
    3
    # 生成100万条股票交易记录(含时间戳、代码、价格、成交量)
    dates = pd.date_range('2025-03-28 09:30', periods=1_000_000, freq='S')
    symbols = ['AAPL', 'MSFT', 'GOOG', 'AMZN', 'TSLA']
    

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
df = pd.DataFrame({
    'symbol': np.random.choice(symbols, 1_000_000),
    'price': np.random.uniform(50, 500, 1_000_000).round(2),
    'volume': np.random.randint(100, 50_000, 1_000_000),
    'trade_type': np.random.choice(['buy', 'sell'], 1_000_000)
}, index=dates)

print("优化前内存:", df.memory_usage(deep=True).sum() / 1024**2, "MB")

# 内存优化:分类列转换
df['symbol'] = df['symbol'].astype('category')  # 内存减少85%
df['trade_type'] = df['trade_type'].astype('category')

print("优化后内存:", df.memory_usage(deep=True).sum() / 1024**2, "MB")

优化前内存: 138.75994682312012 MB

优化后内存: 24.796205520629883 MB

  1. itertuples 实战:计算交易金额
    1
    2
    3
    4
    5
    6
    7
    8
    # 传统 iterrows 写法(避免使用!)
    import time
    t1 = time.time()
    total_amount = 0
    for idx, row in df.iterrows():  # 预估耗时85秒
        total_amount += row['price'] * row['volume']
    t2 = time.time()
    print("传统 iterrows 写法:",t2-t1,"s")
    

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# 优化后 itertuples 写法
total_amount = 0
for row in df.itertuples():  # 耗时约7秒
    total_amount += row.price * row.volume
t3 = time.time()
print("优化后 itertuples 写法:",t3-t2,"s")

# 终极优化:向量化计算(推荐!)
df['amount'] = df['price'] * df['volume']  # 耗时0.03秒
t4 = time.time()
print("终极优化:向量化计算:",t4-t3,"s")

传统 iterrows 写法: 85.93825674057007 s

优化后 itertuples 写法: 7.655602216720581 s

终极优化:向量化计算: 0.032360076904296875 s

  1. ​apply 实战:计算波动率因子
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    def volatility_factor(row):  # 定义波动率计算函数
        if row['volume'] > 20000:
            return row['price'] * 0.015
        elif (row['volume'] > 10000) & (row['trade_type'] == 'buy'):
            return row['price'] * 0.010
        else:
            return row['price'] * 0.005
    # 应用优化
    t5 = time.time()
    df['vol_factor'] = df.apply(volatility_factor, axis=1)  # 耗时约3秒
    t6 = time.time()
    print("定义波动率计算函数:",t6-t5,"s")
    

定义波动率计算函数: 24.482948064804077 s


  1. 先筛选再计算策略
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    # 非交易时段数据过滤(先筛选)
    market_hours = df.between_time('09:30', '16:00')  # 减少30%数据量
    
    # 仅处理大额交易(金额>100万)
    large_trades = market_hours[market_hours['amount'] > 1_000_000]
    
    # 分块处理(内存优化)
    t7 = time.time()
    chunks = (large_trades.groupby('symbol')
                        .apply(lambda x: x['amount'].mean())
                        .reset_index(name='avg_large_trade'))
    t8 = time.time()
    
    print("先筛选再计算策略:",t8-t7,"s")
    

先筛选再计算策略: 0.044037818908691406 s

apply可以利用内部优化,比循环更快,但不如矢量化操作。

1.2.3. 性能对比与最佳实践

Tip

最佳实践优先级: ​1. 向量化运算 > 2. ​itertuples > 3. ​apply > 4. ​iterrows - 优先使用 df['col'] = df['col1'] * df['col2'] 形式 - 复杂逻辑用 np.where() 或 pd.cut() 替代循环

1.2.4. 注意事项

  1. 数据预处理

    • 将时间戳设为索引 df.set_index('timestamp', inplace=True)
    • 数值列转换为最小类型: df['volume'] = df['volume'].astype('int32')
  2. 避免链式索引

1
2
3
4
5
# 错误写法(触发警告)
df[df['symbol'] == 'AAPL']['price'] = 200  

# 正确写法
df.loc[df['symbol'] == 'AAPL', 'price'] = 200  # 效率提升30%
  1. ​内存管理
    • 分块读取: pd.read_csv('trades.csv', chunksize=100000)
    • 及时删除中间变量: del temp_df 释放内存

完整代码示例可通过 Jupyter Notebook 运行测试,建议使用金融高频交易数据集(如TAQ数据)验证优化效果。对于超大规模数据(>1亿行),推荐结合 Dask 或 Modin 实现分布式计算。

1.3. 使用numpy和numba

1.3.1. Numba核心原理与优势

Numba 是 Python 的即时(JIT)编译器,通过将 Python 函数编译为机器码,显著提升计算效率,尤其适合数值计算和 Numpy 数组操作。


  • ​即时编译:通过 @jit 装饰器自动优化函数,消除 Python 解释器开销。
  • ​并行加速:使用 parallel=True 和 prange 实现多线程并行计算。
  • GPU支持:通过 @cuda.jit 将计算任务卸载到 GPU,适用于超大规模数据处理。
1.3.2. 金融数据处理优化案例
  1. ​计算股票收益率波动率(Numba加速)​
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    import numpy as np
    from numba import jit
    
    # 生成金融数据:100万条股票价格序列
    np.random.seed(42)
    prices = np.random.normal(100, 5, 1_000_000).cumsum()
    
    # 传统Python实现
    def calc_volatility(prices):
        returns = np.zeros(len(prices)-1)
        for i in range(len(prices)-1):
            returns[i] = (prices[i+1] - prices[i]) / prices[i]
        return np.std(returns) * np.sqrt(252)
    
    # Numba优化实现
    @jit(nopython=True)
    def calc_volatility_numba(prices):
        returns = np.zeros(len(prices)-1)
        for i in range(len(prices)-1):
            returns[i] = (prices[i+1] - prices[i]) / prices[i]
        return np.std(returns) * np.sqrt(252)
    
    # 性能对比
    %timeit calc_volatility(prices)    # 约 920 ms
    %timeit calc_volatility_numba(prices)  # 约 7.3 ms
    

921 ms ± 87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

7.27 ms ± 183 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

  1. 蒙特卡洛期权定价(并行计算)​
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    from numba import njit, prange
    
    @njit(parallel=True)
    def monte_carlo_pricing(S0, K, r, sigma, T, n_simulations):
        payoffs = np.zeros(n_simulations)
        for i in prange(n_simulations):
            ST = S0 * np.exp((r - 0.5*sigma**2)*T + sigma*np.sqrt(T)*np.random.normal())
            payoffs[i] = max(ST - K, 0)
        return np.exp(-r*T) * np.mean(payoffs)
    
    # 参数设置
    params = (100, 105, 0.05, 0.2, 1, 1_000_000)
    result = monte_carlo_pricing(*params)  # 约 320 ms(比纯Python快35倍)
    
1.3.3. 关键优化策略
  1. 数据类型特化 强制指定输入类型避免动态检查:
    1
    2
    3
    @jit(nopython=True, fastmath=True)
    def vec_dot(a: np.ndarray, b: np.ndarray) -> float:
        return np.dot(a, b)
    

  1. 内存预分配

    1
    2
    3
    4
    5
    6
    @jit(nopython=True)
    def moving_average(data, window):
        ma = np.empty(len(data) - window + 1)
        for i in range(len(ma)):
            ma[i] = np.mean(data[i:i+window])
        return ma
    

  2. ​避免Python对象 在 Numba 函数中禁用 Python 对象(nopython=True),确保全程机器码执行。

Note

最佳实践 - 优先使用 @njit(等价于 @jit(nopython=True)) - 对大循环使用 prange 替代 range 实现并行 - 对 np.ufunc 函数进行二次加速(如 np.sqrt、np.exp) - 避免在 JIT 函数中混合使用 Python 原生类型与 Numpy 类型

1.3.4. 扩展应用
  1. 与Pandas结合
    1
    2
    3
    @jit
    def pandas_apply_optimized(df: pd.DataFrame):
        return df['price'].values * df['volume'].values  # 直接操作Numpy数组
    

  1. GPU加速(CUDA)​
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    from numba import cuda
    
    @cuda.jit
    def cuda_matmul(A, B, C):
        i, j = cuda.grid(2)
        if i < C.shape[0] and j < C.shape[1]:
            tmp = 0.0
            for k in range(A.shape[1]):
                tmp += A[i, k] * B[k, j]
            C[i, j] = tmp
    

Tip

注意事项: - ​编译开销:首次运行 JIT 函数会有编译耗时,后续调用直接使用缓存 - 调试限制:Numba 函数不支持 pdb 断点调试,需通过 print 输出中间值 - 兼容性:部分 Numpy 高级功能(如 np.linalg.svd)在 Numba 中受限

通过合理运用 Numpy 的向量化操作与 Numba 的 JIT 编译,可在金融量化分析、高频交易等场景实现 ​C 语言级性能,同时保持 Python 的开发效率。建议结合 %%timeit 和 Numba 的 cache=True 参数持续优化热点代码。


Pandas应用案例[3]

“Modin 通过多核并行加速 Pandas 操作,读取 10GB CSV 文件比 Pandas 快 4-8 倍;Polars 基于 Rust 架构,内存占用仅为 Pandas 的 1/3;Dask 则支持分布式计算,轻松处理 TB 级数据。”


1.4. 使用eval或者query

关于 query 方法,我它类似于SQL的where子句,允许用字符串表达式,这样代码更简洁。比如df.query('Q1 > Q2 > 90'),还支持用@符号引入外部变量。比如计算平均分后筛选高于平均分的数据。同时,eval方法类似,但返回布尔索引,需要配合df[]使用,例如:df[df.eval("Q1 > 90 > Q3 >10")]

isin 方法,用于筛选某列的值是否在指定列表中。例如,用 b1["类别"].isin(["能源","电器"]) 来筛选类别列中的值。此外,还可以结合多个条件,例如:df[df['ucity'].isin(['广州市','深圳'])]

1.4.1. query() 函数:SQL风格的条件筛选
  1. 核心语法

    1
    df.query('表达式')  # 表达式需用引号包裹,支持逻辑运算符和列名直接引用
    

  2. 金融场景示例

    1
    2
    3
    4
    5
    6
    7
    """案例1:筛选特定股票代码的高额交易"""
    # 筛选AAPL或TSLA股票,且金额超过100万的交易
    df.query("symbol in ['AAPL', 'TSLA'] and amount > 1e6")
    
    """案例2:动态引用外部变量"""
    avg_amount = df['amount'].mean()  # 计算平均交易金额
    df.query("amount > @avg_amount * 2")  # 筛选金额超过平均2倍的交易[3,5](@ref)
    


1
2
3
"""案例3:多条件组合"""
# 筛选2025年Q1买入且成交价高于开盘价的交易
df.query("trade_type == 'buy' and trade_date >= '2025-01-01' and price > open_price")
  1. 性能优势
  2. ​表达式优化:底层通过 numexpr 库加速计算,比传统布尔索引快30%以上
  3. ​列名处理:列名含空格或特殊字符时需用反引号包裹(如 收盘价 > 100)
1.4.2. eval() 函数:表达式生成布尔索引
  1. 核心语法

    1
    2
    mask = df.eval("表达式")  # 返回布尔数组
    df[mask]  # 用布尔索引筛选数据
    

  2. ​金融场景示例

    1
    2
    3
    4
    5
    6
    7
    8
    """案例1:计算复杂交易条件"""
    # 筛选波动率超过阈值且交易量增长的股票
    df[df.eval("(high - low)/close > 0.05 and volume > volume.shift(1)")]
    
    """案例2:动态公式计算"""
    # 筛选夏普比率高于行业平均的基金
    industry_avg = 1.2
    df[df.eval("(returns - risk_free_rate)/std_dev > @industry_avg")]
    

  3. ​与query()的区别

  4. eval() 返回布尔数组,需配合 df[] 使用;query() 直接返回筛选后的DataFrame
  5. 两者共享相同表达式引擎,性能差异可忽略,按代码简洁性选择即可

1.4.3. isin() 函数:多值匹配筛选
  1. 核心语法

    1
    df[df['列名'].isin(值列表)]  # 筛选列值存在于列表中的行
    

  2. ​金融场景示例

     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    """案例1:筛选特定股票池"""
    blue_chips = ['600519.SH', '000858.SZ', '601318.SH']  # 上证50成分股
    df[df['symbol'].isin(blue_chips)]
    
    """案例2:排除ST/ST风险股"""
    risk_stocks = ['*ST长生', 'ST康美']  
    df[~df['stock_name'].isin(risk_stocks)]  # 使用~取反[2](@ref)
    
    """案例3:联合多列筛选"""
    # 筛选沪深300且行业为科技或金融的股票
    target_industries = ['Technology', 'Financials']
    df[df['index'].isin(['000300.SH']) & df['industry'].isin(target_industries)]
    

  3. 进阶用法

  4. ​字典筛选:多列联合匹配(如 df[df.isin({'symbol':'AAPL', 'exchange':'NASDAQ'})])
  5. 性能优化:对大列表(>1万元素)建议先转换为集合(set())提升速度
1.4.4. 综合性能优化策略
  1. 先筛选再计算

1
2
3
4
5
6
7
# 错误:先计算全量再筛选
df['return'] = df['close'].pct_change()  
df_filtered = df[df['volume'] > 1e6]

# 正确:先筛选减少计算量
df_filtered = df[df['volume'] > 1e6].copy()  
df_filtered['return'] = df_filtered['close'].pct_change()[6](@ref)
  1. ​避免链式操作

    1
    2
    3
    4
    # 错误:两次索引降低性能
    df[df['symbol'] == 'AAPL']['close']  
    # 正确:单次loc操作
    df.loc[df['symbol'] == 'AAPL', 'close'][3](@ref)
    

  2. ​类型优化

    1
    2
    # 将字符串列转为category提升isin速度
    df['symbol'] = df['symbol'].astype('category')[8](@ref)
    

1.4.5. 方法对比与适用场景
方法 适用场景 性能优势
query() 复杂多条件组合,需动态变量引用 表达式优化加速
eval() 生成中间布尔索引,用于后续处理 与query性能接近
isin() 快速匹配离散值列表(如股票代码) 集合加速+类型优化

实践建议: - ​高频筛选:优先用 query() 保持代码简洁


  • ​超大列表:用 isin() + 集合类型提升速度
  • ​动态计算:eval() 适合嵌入数学公式或跨列运算

1.5. Pandas 的其它替代方案

1.5.1. modin:单机多核并行加速器

一行代码,实现pandas替代,并拥有多核、不受内存限制的计算能力。

  1. ​核心原理
  2. ​并行化改造:将 Pandas 的 DataFrame 拆分为多个分区,利用多核 CPU 并行处理,底层支持 Ray 或 Dask 引擎。
  3. ​语法兼容性:仅需修改导入语句(import modin.pandas as pd),即可无缝替代原生 Pandas,支持 90% 以上常用 API。

  4. ​性能优势

  5. 读取加速:读取 10GB CSV 文件时,比 Pandas 快 4-8 倍。
  6. ​计算优化:groupby 等聚合操作在 4 核机器上提速 3-5 倍,内存占用减少 30%。
  7. ​适用场景:单机环境下处理 100MB~50GB 数据集,适合金融高频交易日志分析、用户行为数据清洗等。

  8. ​使用案例

    1
    2
    3
    # 读取大规模交易数据(并行加速)
    import modin.pandas as pd
    df = pd.read_csv("trades.csv", parse_dates=["timestamp"])
    


1
2
# 实时计算每分钟交易量
volume_by_minute = df.groupby(pd.Grouper(key="timestamp", freq="T"))["amount"].sum().compute()
  1. ​注意事项
  2. ​小数据集劣势:处理 <100MB 数据时可能比 Pandas 更慢(启动开销)。
  3. ​内存消耗:需预留 2-3 倍数据大小的内存,避免 OOM。
1.5.2. polars:Rust 驱动的极速引擎

最快的tableu解决方案

  1. ​核心原理
  2. ​Rust + Arrow 架构:基于 Rust 语言和 Apache Arrow 内存格式,支持零拷贝数据处理与 SIMD 指令优化。
  3. ​多线程与惰性执行:自动并行化计算,通过 lazy() 延迟执行并优化查询计划。

  4. ​性能优势

  5. 速度对比:同等操作比 Pandas 快 5-10 倍,1 亿行 groupby 计算仅需 11 秒(Pandas 需 187 秒)。
  6. ​内存效率:内存占用仅为 Pandas 的 1/3,支持处理内存不足时的核外计算。

  7. ​适用场景

  8. ​高频金融数据:如实时波动率计算、订单簿快照分析。
  9. ​复杂聚合:多条件统计、时间窗口滚动计算(如 VWAP)。

  10. 代码示例


1
2
3
4
5
6
7
8
9
import polars as pl
# 惰性执行优化查询
df = pl.scan_csv("market_data.csv")
result = (
   df.filter(pl.col("price") > 100)
   .groupby("symbol")
   .agg([pl.col("volume").sum(), pl.col("price").mean()])
   .collect()  # 触发计算
)

Tip

​注意事项 - ​语法差异:部分 Pandas 方法需改写(如 df[df.col > 0] → df.filter(pl.col("col") > 0))。 - ​可视化兼容性:需转换为 Pandas 或 NumPy 才能使用 Matplotlib/seaborn。

1.5.3. dask:分布式计算的瑞士军刀

分布式tableu,可运行在数千结点上

  1. ​核心原理
  2. ​分布式任务调度:将任务拆分为 DAG(有向无环图),支持单机多核或集群分布式执行。
  3. ​核外计算:通过分区处理超出内存的数据集(如 TB 级日志)。

  1. ​性能优势
  2. ​横向扩展:在 16 核机器上处理 50GB 数据比 Pandas 快 10 倍,支持扩展到千节点集群。
  3. ​兼容生态:无缝对接 XGBoost、Dask-ML 等库,支持分布式模型训练。

  4. ​适用场景

  5. ​超大规模数据:如全市场历史行情分析、社交网络图谱计算。
  6. ​ETL 流水线:多步骤数据清洗与特征工程(需依赖管理)。

  7. 实战技巧

    1
    2
    3
    4
    5
    import dask.dataframe as dd
    # 分块读取与处理
    ddf = dd.read_csv("s3://bucket/large_file_*.csv", blocksize="256MB")
    # 并行计算每支股票的年化波动率
    volatility = ddf.groupby("symbol")["return"].std().compute()
    

Tip

​注意事项 - ​调试复杂性:需用 Dask Dashboard 监控任务状态,定位数据倾斜问题。 - ​配置优化:合理设置分区大小(建议 100MB~1GB),避免调度开销。


1.5.4. 选型决策树
场景 ​推荐工具 ​理由
单机中数据(<50GB) Modin 零代码修改,快速提升现有 Pandas 脚本性能
高频计算/内存受限 Polars 极致速度与低内存消耗,适合量化交易场景
分布式/超大数据(>1TB) Dask 支持集群扩展,生态完善

​注:实际测试显示,Polars 在单机性能上全面领先,而 Dask 在分布式场景下更具优势。建议结合数据规模与硬件资源综合选择。