参数调优

对于以下双均线策略，我们希望对其进行参数调优。可以通过命令行参数 --extra-vars 或配置项 extra.context_vars 将变量传递到 context 对象中，策略在运行时从 context 上读取 SHORTPERIOD 和 LONGPERIOD 这两个均线周期参数。

  1. from rqalpha.api import *
  2. import talib
  3.  
  4.  
  5. def init(context):
  6. context.s1 = "000001.XSHE"
  7.  
  8. context.SHORTPERIOD = 20
  9. context.LONGPERIOD = 120
  10.  
  11.  
  12. def handle_bar(context, bar_dict):
  13. prices = history_bars(context.s1, context.LONGPERIOD+1, '1d', 'close')
  14.  
  15. short_avg = talib.SMA(prices, context.SHORTPERIOD)
  16. long_avg = talib.SMA(prices, context.LONGPERIOD)
  17.  
  18. cur_position = context.portfolio.positions[context.s1].quantity
  19. shares = context.portfolio.cash / bar_dict[context.s1].close
  20.  
  21. if short_avg[-1] - long_avg[-1] < 0 and short_avg[-2] - long_avg[-2] > 0 and cur_position > 0:
  22. order_target_value(context.s1, 0)
  23.  
  24. if short_avg[-1] - long_avg[-1] > 0 and short_avg[-2] - long_avg[-2] < 0:
  25. order_shares(context.s1, shares)

通过函数调用传递参数

  1. import concurrent.futures
  2. import multiprocessing
  3. from rqalpha import run
  4.  
  5.  
  6. tasks = []
  7. for short_period in range(3, 10, 2):
  8. for long_period in range(30, 90, 5):
  9. config = {
  10. "extra": {
  11. "context_vars": {
  12. "SHORTPERIOD": short_period,
  13. "LONGPERIOD": long_period,
  14. },
  15. "log_level": "error",
  16. },
  17. "base": {
  18. "matching_type": "current_bar",
  19. "start_date": "2015-01-01",
  20. "end_date": "2016-01-01",
  21. "benchmark": "000001.XSHE",
  22. "frequency": "1d",
  23. "strategy_file": "rqalpha/examples/golden_cross.py",
  24. "accounts": {
  25. "stock": 100000
  26. }
  27. },
  28. "mod": {
  29. "sys_progress": {
  30. "enabled": True,
  31. "show": True,
  32. },
  33. "sys_analyser": {
  34. "enabled": True,
  35. "output_file": "results/out-{short_period}-{long_period}.pkl".format(
  36. short_period=short_period,
  37. long_period=long_period,
  38. )
  39. },
  40. },
  41. }
  42.  
  43. tasks.append(config)
  44.  
  45.  
  46. def run_bt(config):
  47. run(config)
  48.  
  49.  
  50. with concurrent.futures.ProcessPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
  51. for task in tasks:
  52. executor.submit(run_bt, task)

通过命令行传递参数

  1. import os
  2. import json
  3. import concurrent.futures
  4. import multiprocessing
  5.  
  6.  
  7. tasks = []
  8. for short_period in range(3, 10, 2):
  9. for long_period in range(30, 90, 5):
  10. extra_vars = {
  11. "SHORTPERIOD": short_period,
  12. "LONGPERIOD": long_period,
  13. }
  14. vars_params = json.dumps(extra_vars).encode("utf-8").decode("utf-8")
  15.  
  16. cmd = ("rqalpha run -fq 1d -f rqalpha/examples/golden_cross.py --start-date 2015-01-01 --end-date 2016-01-01 "
  17. "-o results/out-{short_period}-{long_period}.pkl --account stock 100000 --progress -bm 000001.XSHE --extra-vars '{params}' ").format(
  18. short_period=short_period,
  19. long_period=long_period,
  20. params=vars_params)
  21.  
  22. tasks.append(cmd)
  23.  
  24.  
  25. def run_bt(cmd):
  26. print(cmd)
  27. os.system(cmd)
  28.  
  29.  
  30. with concurrent.futures.ProcessPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
  31. for task in tasks:
  32. executor.submit(run_bt, task)

分析批量回测结果

  1. import glob
  2. import pandas as pd
  3.  
  4.  
  5. results = []
  6.  
  7. for name in glob.glob("results/*.pkl"):
  8. result_dict = pd.read_pickle(name)
  9. summary = result_dict["summary"]
  10. results.append({
  11. "name": name,
  12. "annualized_returns": summary["annualized_returns"],
  13. "sharpe": summary["sharpe"],
  14. "max_drawdown": summary["max_drawdown"],
  15. })
  16.  
  17. results_df = pd.DataFrame(results)
  18.  
  19. print("-" * 50)
  20. print("Sort by sharpe")
  21. print(results_df.sort_values("sharpe", ascending=False)[:10])
  22.  
  23. print("-" * 50)
  24. print("Sort by annualized_returns")
  25. print(results_df.sort_values("annualized_returns", ascending=False)[:10])