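Backtesting Research

backtesting.py defines the backtesting engine. This section describes its main functions and ends with a usage example.

Loading the Strategy

The add_strategy function loads the CTA strategy logic, the corresponding contract symbol, and the parameter settings (which can be changed outside the strategy file) into the backtesting engine.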

    def add_strategy(self, strategy_class: type, setting: dict):
        """Instantiate the strategy class with this engine, its symbol and settings."""
        self.strategy_class = strategy_class
        self.strategy = strategy_class(
            self, strategy_class.__name__, self.vt_symbol, setting
        )
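
The pattern above passes the strategy class itself, not an instance, and lets the engine build the instance. The following is a minimal sketch of that idea; ToyStrategy and ToyEngine are hypothetical stand-ins, and the way the setting dict overrides defaults is an assumption for illustration, not vn.py's own implementation.

```python
class ToyStrategy:
    """Hypothetical strategy with two tunable parameters."""

    fast_window = 10
    slow_window = 20

    def __init__(self, engine, strategy_name, vt_symbol, setting):
        self.engine = engine
        self.strategy_name = strategy_name
        self.vt_symbol = vt_symbol
        # Assumption: values in the setting dict override class-level defaults.
        for name, value in setting.items():
            if hasattr(self, name):
                setattr(self, name, value)


class ToyEngine:
    """Hypothetical engine that mirrors the add_strategy pattern."""

    def __init__(self, vt_symbol):
        self.vt_symbol = vt_symbol
        self.strategy_class = None
        self.strategy = None

    def add_strategy(self, strategy_class, setting):
        # Keep the class for later reuse and build one instance from it.
        self.strategy_class = strategy_class
        self.strategy = strategy_class(
            self, strategy_class.__name__, self.vt_symbol, setting
        )


engine = ToyEngine("IF88.CFFEX")
engine.add_strategy(ToyStrategy, {"fast_window": 5})
print(
    engine.strategy.strategy_name,
    engine.strategy.fast_window,
    engine.strategy.slow_window,
)
```

Because the override happens on the instance, the defaults in the strategy file stay untouched, which is what makes it possible to change parameters from outside that file.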

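Loading Historical Data

The load_data function loads the historical data for the chosen symbol in roughly four steps:

  • depending on the data type, loading runs in either bar (K-line) mode or tick mode;
  • select().where() queries the database conditionally, filtering by vt_symbol, backtest start date, backtest end date, and bar interval (bar mode only);
  • order_by(DbBarData.datetime) ensures the data is loaded in chronological order;
  • the query result is iterated over, and the converted records are stored in self.history_data.
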
    def load_data(self):
        """Load bar or tick history for self.vt_symbol into self.history_data."""
        self.output("Start loading historical data")

        if self.mode == BacktestingMode.BAR:
            # Bar mode: filter by symbol, interval and date range.
            s = (
                DbBarData.select()
                .where(
                    (DbBarData.vt_symbol == self.vt_symbol)
                    & (DbBarData.interval == self.interval)
                    & (DbBarData.datetime >= self.start)
                    & (DbBarData.datetime <= self.end)
                )
                .order_by(DbBarData.datetime)
            )
            self.history_data = [db_bar.to_bar() for db_bar in s]
        else:
            # Tick mode: ticks have no interval, so filter by symbol and dates only.
            s = (
                DbTickData.select()
                .where(
                    (DbTickData.vt_symbol == self.vt_symbol)
                    & (DbTickData.datetime >= self.start)
                    & (DbTickData.datetime <= self.end)
                )
                .order_by(DbTickData.datetime)
            )
            self.history_data = [db_tick.to_tick() for db_tick in s]

        self.output(f"Historical data loaded, record count: {len(self.history_data)}")
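
The same filter-then-sort logic can be tried without a database. The sketch below is only an illustration: the Bar dataclass and select_bars helper are hypothetical names, and an in-memory list stands in for the database table.

```python
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Bar:
    """Hypothetical stand-in for a stored bar record."""

    vt_symbol: str
    interval: str
    datetime: datetime
    close_price: float


def select_bars(rows, vt_symbol, interval, start, end):
    """Filter rows by symbol, interval and date range, then sort by time."""
    matched = [
        row
        for row in rows
        if row.vt_symbol == vt_symbol
        and row.interval == interval
        and start <= row.datetime <= end
    ]
    return sorted(matched, key=lambda row: row.datetime)


rows = [
    Bar("IF88.CFFEX", "1m", datetime(2018, 1, 2, 9, 32), 4050.0),
    Bar("IF88.CFFEX", "1m", datetime(2018, 1, 2, 9, 31), 4048.0),
    Bar("IF88.CFFEX", "1h", datetime(2018, 1, 2, 10, 0), 4060.0),  # wrong interval
    Bar("IC88.CFFEX", "1m", datetime(2018, 1, 2, 9, 31), 6300.0),  # wrong symbol
    Bar("IF88.CFFEX", "1m", datetime(2019, 6, 1, 9, 31), 3700.0),  # out of range
]

history_data = select_bars(
    rows, "IF88.CFFEX", "1m", datetime(2018, 1, 1), datetime(2019, 1, 1)
)
print([bar.close_price for bar in history_data])
```

Only the two one-minute IF88.CFFEX bars inside the date range survive, and they come back in chronological order even though they were stored out of order.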

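Order Matching

Once the CTA strategy and its historical data are loaded, the strategy computes its indicators from the latest data. When its conditions are met, it generates a trading signal and sends a specific order (buy/sell/short/cover), which is matched against the next bar.

Depending on the order type, the backtesting engine provides two matching mechanisms that approximate live trading as closely as possible:

  • Limit order matching (buy side as the example): first decide whether the order is filled, which requires order price >= the next bar's low; then set the fill price to the lower of the order price and the next bar's open.
  • Stop order matching (buy side as the example): first decide whether the order is triggered, which requires order price <= the next bar's high; then set the fill price to the higher of the order price and the next bar's open.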
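
The two buy-side rules above can be written as small pure functions. This is a sketch for working through examples by hand; match_limit_buy and match_stop_buy are hypothetical helper names, not engine methods.

```python
def match_limit_buy(order_price, bar_low, bar_open):
    """Return the fill price of a buy limit order, or None if it is not filled."""
    if order_price >= bar_low:
        # A gap down below the order price fills at the (better) open.
        return min(order_price, bar_open)
    return None


def match_stop_buy(order_price, bar_high, bar_open):
    """Return the fill price of a buy stop order, or None if it is not triggered."""
    if order_price <= bar_high:
        # A gap up above the stop price fills at the (worse) open.
        return max(order_price, bar_open)
    return None


# Next bar used in every case: open 100, high 105, low 98.
limit_fill = match_limit_buy(order_price=99, bar_low=98, bar_open=100)
limit_gap = match_limit_buy(order_price=102, bar_low=98, bar_open=100)
limit_miss = match_limit_buy(order_price=97, bar_low=98, bar_open=100)
stop_fill = match_stop_buy(order_price=103, bar_high=105, bar_open=100)
stop_gap = match_stop_buy(order_price=99, bar_high=105, bar_open=100)
stop_miss = match_stop_buy(order_price=106, bar_high=105, bar_open=100)

print(limit_fill, limit_gap, limit_miss)  # 99 100 None
print(stop_fill, stop_gap, stop_miss)     # 103 100 None
```

Taking the minimum for limit buys and the maximum for stop buys reflects the fact that a limit order can be filled at a better price than requested, while a stop order can be filled at a worse one.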

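The limit order matching flow in the engine is as follows:

  • determine the prices at which orders can be matched;
  • iterate over all limit orders in the limit order dictionary, pushing a "not traded" status update for orders that have just entered the pending queue;
  • check whether each order is filled and, if so, push the order and trade data;
  • remove filled limit orders from the dictionary.
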
    def cross_limit_order(self):
        """
        Cross limit order with last bar/tick data.
        """
        if self.mode == BacktestingMode.BAR:
            long_cross_price = self.bar.low_price
            short_cross_price = self.bar.high_price
            long_best_price = self.bar.open_price
            short_best_price = self.bar.open_price
        else:
            long_cross_price = self.tick.ask_price_1
            short_cross_price = self.tick.bid_price_1
            long_best_price = long_cross_price
            short_best_price = short_cross_price

        for order in list(self.active_limit_orders.values()):
            # Push order update with status "not traded" (pending).
            if order.status == Status.SUBMITTING:
                order.status = Status.NOTTRADED
                self.strategy.on_order(order)

            # Check whether limit orders can be filled.
            long_cross = (
                order.direction == Direction.LONG
                and order.price >= long_cross_price
                and long_cross_price > 0
            )

            short_cross = (
                order.direction == Direction.SHORT
                and order.price <= short_cross_price
                and short_cross_price > 0
            )

            if not long_cross and not short_cross:
                continue

            # Push order update with status "all traded" (filled).
            order.traded = order.volume
            order.status = Status.ALLTRADED
            self.strategy.on_order(order)

            self.active_limit_orders.pop(order.vt_orderid)

            # Push trade update.
            self.trade_count += 1

            if long_cross:
                trade_price = min(order.price, long_best_price)
                pos_change = order.volume
            else:
                trade_price = max(order.price, short_best_price)
                pos_change = -order.volume

            trade = TradeData(
                symbol=order.symbol,
                exchange=order.exchange,
                orderid=order.orderid,
                tradeid=str(self.trade_count),
                direction=order.direction,
                offset=order.offset,
                price=trade_price,
                volume=order.volume,
                time=self.datetime.strftime("%H:%M:%S"),
                gateway_name=self.gateway_name,
            )
            trade.datetime = self.datetime

            self.strategy.pos += pos_change
            self.strategy.on_trade(trade)

            self.trades[trade.vt_tradeid] = trade

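Calculating Strategy P&L

Total P&L and net P&L are computed from the closing price, the day's position, contract size, slippage, and commission rate. The results are output as a DataFrame, which completes the daily mark-to-market P&L statistics.

The calculation proceeds as follows:

  • Holding (floating) P&L = start-of-day position × (today's close - previous close) × contract size
  • Trading (realized) P&L = position change × (today's close - trade price) × contract size
  • Total P&L = holding P&L + trading P&L
  • Net P&L = total P&L - total commission - total slippage
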
    def calculate_pnl(
        self,
        pre_close: float,
        start_pos: float,
        size: int,
        rate: float,
        slippage: float,
    ):
        """Calculate holding, trading, total and net pnl for one trading day."""
        self.pre_close = pre_close

        # Holding pnl is the pnl from holding position at day start
        self.start_pos = start_pos
        self.end_pos = start_pos
        self.holding_pnl = self.start_pos * \
            (self.close_price - self.pre_close) * size

        # Trading pnl is the pnl from new trade during the day
        self.trade_count = len(self.trades)

        for trade in self.trades:
            if trade.direction == Direction.LONG:
                pos_change = trade.volume
            else:
                pos_change = -trade.volume

            turnover = trade.price * trade.volume * size

            self.trading_pnl += pos_change * \
                (self.close_price - trade.price) * size
            self.end_pos += pos_change
            self.turnover += turnover
            self.commission += turnover * rate
            self.slippage += trade.volume * size * slippage

        # Net pnl takes account of commission and slippage cost
        self.total_pnl = self.trading_pnl + self.holding_pnl
        self.net_pnl = self.total_pnl - self.commission - self.slippage
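
To check the formulas with concrete numbers, the day-level calculation can be reproduced as a standalone function. This is a sketch under stated assumptions: daily_pnl is a hypothetical helper, and trades are given as (signed volume, price) tuples, with positive volume meaning a buy, instead of trade objects.

```python
def daily_pnl(pre_close, close, start_pos, trades, size, rate, slippage):
    """Compute one day's mark-to-market P&L from signed (volume, price) trades."""
    # Holding P&L comes from the position carried over from the previous day.
    holding_pnl = start_pos * (close - pre_close) * size

    trading_pnl = 0.0
    commission = 0.0
    slippage_cost = 0.0
    end_pos = start_pos

    for signed_volume, price in trades:
        volume = abs(signed_volume)
        turnover = price * volume * size
        # Trading P&L marks each new trade to the day's close.
        trading_pnl += signed_volume * (close - price) * size
        commission += turnover * rate
        slippage_cost += volume * size * slippage
        end_pos += signed_volume

    total_pnl = holding_pnl + trading_pnl
    net_pnl = total_pnl - commission - slippage_cost
    return {
        "holding_pnl": holding_pnl,
        "trading_pnl": trading_pnl,
        "commission": commission,
        "slippage": slippage_cost,
        "total_pnl": total_pnl,
        "net_pnl": net_pnl,
        "end_pos": end_pos,
    }


# One contract held overnight, one more bought at 4005 during the day.
result = daily_pnl(
    pre_close=4000.0,
    close=4010.0,
    start_pos=1,
    trades=[(1, 4005.0)],
    size=300,
    rate=3.0 / 10000,
    slippage=0.2,
)
print(result)
```

With these inputs the holding P&L is 1 × 10 × 300 = 3000 and the trading P&L is 1 × 5 × 300 = 1500, for a total of 4500. Commission is about 360.45 (turnover of 1,201,500 times the rate) and slippage is about 60, leaving a net P&L of about 4079.55.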

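Calculating Strategy Statistics

The calculate_statistics function derives performance metrics such as maximum drawdown, annualized return, profit/loss ratio, and Sharpe ratio from the daily mark-to-market P&L (in DataFrame format).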

    # Calculate balance-related time series
    df["balance"] = df["net_pnl"].cumsum() + self.capital
    df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0)
    df["highlevel"] = (
        df["balance"].rolling(
            min_periods=1, window=len(df), center=False).max()
    )
    df["drawdown"] = df["balance"] - df["highlevel"]
    df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100

    # Calculate statistics value
    start_date = df.index[0]
    end_date = df.index[-1]

    total_days = len(df)
    profit_days = len(df[df["net_pnl"] > 0])
    loss_days = len(df[df["net_pnl"] < 0])

    end_balance = df["balance"].iloc[-1]
    max_drawdown = df["drawdown"].min()
    max_ddpercent = df["ddpercent"].min()

    total_net_pnl = df["net_pnl"].sum()
    daily_net_pnl = total_net_pnl / total_days

    total_commission = df["commission"].sum()
    daily_commission = total_commission / total_days

    total_slippage = df["slippage"].sum()
    daily_slippage = total_slippage / total_days

    total_turnover = df["turnover"].sum()
    daily_turnover = total_turnover / total_days

    total_trade_count = df["trade_count"].sum()
    daily_trade_count = total_trade_count / total_days

    # Returns are in percent; a year is taken as 240 trading days
    total_return = (end_balance / self.capital - 1) * 100
    annual_return = total_return / total_days * 240
    daily_return = df["return"].mean() * 100
    return_std = df["return"].std() * 100

    if return_std:
        sharpe_ratio = daily_return / return_std * np.sqrt(240)
    else:
        sharpe_ratio = 0

Plotting the Statistics

The show_chart function draws four charts with matplotlib:

  • balance curve
  • drawdown
  • daily P&L
  • daily P&L distribution

    def show_chart(self, df: DataFrame = None):
        """Plot balance, drawdown, daily pnl and daily pnl distribution."""
        # "if not df" would raise ValueError for a DataFrame, so compare with None.
        if df is None:
            df = self.daily_df

        if df is None:
            return

        plt.figure(figsize=(10, 16))

        balance_plot = plt.subplot(4, 1, 1)
        balance_plot.set_title("Balance")
        df["balance"].plot(legend=True)

        drawdown_plot = plt.subplot(4, 1, 2)
        drawdown_plot.set_title("Drawdown")
        drawdown_plot.fill_between(range(len(df)), df["drawdown"].values)

        pnl_plot = plt.subplot(4, 1, 3)
        pnl_plot.set_title("Daily Pnl")
        df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])

        distribution_plot = plt.subplot(4, 1, 4)
        distribution_plot.set_title("Daily Pnl Distribution")
        df["net_pnl"].hist(bins=50)

        plt.show()

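Backtesting Engine Usage Example

  • Import the backtesting engine and the CTA strategy.
  • Set the backtest parameters: symbol, bar interval, start and end dates, commission rate, slippage, contract size, price tick, and starting capital.
  • Load the strategy and data into the engine, then run the backtest.
  • Compute the daily P&L results, compute the statistics, and plot them.
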
    from datetime import datetime

    from vnpy.app.cta_strategy.backtesting import BacktestingEngine
    from vnpy.app.cta_strategy.strategies.boll_channel_strategy import (
        BollChannelStrategy,
    )

    engine = BacktestingEngine()
    engine.set_parameters(
        vt_symbol="IF88.CFFEX",
        interval="1m",
        start=datetime(2018, 1, 1),
        end=datetime(2019, 1, 1),
        rate=3.0/10000,
        slippage=0.2,
        size=300,
        pricetick=0.2,
        capital=1_000_000,
    )

    # Use the strategy class that was imported above, with default parameters.
    engine.add_strategy(BollChannelStrategy, {})
    engine.load_data()
    engine.run_backtesting()
    df = engine.calculate_result()
    engine.calculate_statistics()
    engine.show_chart()