5. 用 *args 和 **kwargs 自定义聚合函数

  1. # 用inspect模块查看groupby对象的agg方法的签名
  2. In[31]: college = pd.read_csv('data/college.csv')
  3. grouped = college.groupby(['STABBR', 'RELAFFIL'])
  4. In[32]: import inspect
  5. inspect.signature(grouped.agg)
  6. Out[32]: <Signature (arg, *args, **kwargs)>

如何做

  1. # 自定义一个函数,返回本科生人数在1000和3000之间的学校所占的比例
  2. In[33]: def pct_between_1_3k(s):
  3.     return s.between(1000, 3000).mean()
  4. # 用州和宗教分组,再聚合
  5. In[34]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between_1_3k).head(9)
  6. Out[34]:
  7. STABBR RELAFFIL
  8. AK 0 0.142857
  9. 1 0.000000
  10. AL 0 0.236111
  11. 1 0.333333
  12. AR 0 0.279412
  13. 1 0.111111
  14. AS 0 1.000000
  15. AZ 0 0.096774
  16. 1 0.000000
  17. Name: UGDS, dtype: float64
  1. # 但是这个函数不能让用户自定义上下限,再新写一个函数
  2. In[35]: def pct_between(s, low, high):
  3.     return s.between(low, high).mean()
  4. # 使用这个自定义聚合函数,并传入最大和最小值
  5. In[36]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between, 1000, 10000).head(9)
  6. Out[36]:
  7. STABBR RELAFFIL
  8. AK 0 0.428571
  9. 1 0.000000
  10. AL 0 0.458333
  11. 1 0.375000
  12. AR 0 0.397059
  13. 1 0.166667
  14. AS 0 1.000000
  15. AZ 0 0.233871
  16. 1 0.111111
  17. Name: UGDS, dtype: float64

原理

  1. # 用关键字参数显式指定最大和最小值
  2. In[37]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between, high=10000, low=1000).head(9)
  3. Out[37]:
  4. STABBR RELAFFIL
  5. AK 0 0.428571
  6. 1 0.000000
  7. AL 0 0.458333
  8. 1 0.375000
  9. AR 0 0.397059
  10. 1 0.166667
  11. AS 0 1.000000
  12. AZ 0 0.233871
  13. 1 0.111111
  14. Name: UGDS, dtype: float64
  1. # 也可以关键字参数和非关键字参数混合使用,只要关键字参数在后面
  2. In[38]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(pct_between, 1000, high=10000).head(9)
  3. Out[38]:
  4. STABBR RELAFFIL
  5. AK 0 0.428571
  6. 1 0.000000
  7. AL 0 0.458333
  8. 1 0.375000
  9. AR 0 0.397059
  10. 1 0.166667
  11. AS 0 1.000000
  12. AZ 0 0.233871
  13. 1 0.111111
  14. Name: UGDS, dtype: float64

更多

  1. # 同时使用多个聚合函数时,Pandas不支持向聚合函数传入额外参数
  2. In[39]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(['mean', pct_between], low=100, high=1000)
  3. ---------------------------------------------------------------------------
  4. TypeError Traceback (most recent call last)
  5. <ipython-input-39-3e3e18919cf9> in <module>()
  6. ----> 1 college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(['mean', pct_between], low=100, high=1000)
  7. /Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in aggregate(self, func_or_funcs, *args, **kwargs)
  8. 2871 if hasattr(func_or_funcs, '__iter__'):
  9. 2872 ret = self._aggregate_multiple_funcs(func_or_funcs,
  10. -> 2873 (_level or 0) + 1)
  11. 2874 else:
  12. 2875 cyfunc = self._is_cython_func(func_or_funcs)
  13. /Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in _aggregate_multiple_funcs(self, arg, _level)
  14. 2944 obj._reset_cache()
  15. 2945 obj._selection = name
  16. -> 2946 results[name] = obj.aggregate(func)
  17. 2947
  18. 2948 if isinstance(list(compat.itervalues(results))[0],
  19. /Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in aggregate(self, func_or_funcs, *args, **kwargs)
  20. 2878
  21. 2879 if self.grouper.nkeys > 1:
  22. -> 2880 return self._python_agg_general(func_or_funcs, *args, **kwargs)
  23. 2881
  24. 2882 try:
  25. /Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in _python_agg_general(self, func, *args, **kwargs)
  26. 852
  27. 853 if len(output) == 0:
  28. --> 854 return self._python_apply_general(f)
  29. 855
  30. 856 if self.grouper._filter_empty_groups:
  31. /Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in _python_apply_general(self, f)
  32. 718 def _python_apply_general(self, f):
  33. 719 keys, values, mutated = self.grouper.apply(f, self._selected_obj,
  34. --> 720 self.axis)
  35. 721
  36. 722 return self._wrap_applied_output(
  37. /Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in apply(self, f, data, axis)
  38. 1800 # group might be modified
  39. 1801 group_axes = _get_axes(group)
  40. -> 1802 res = f(group)
  41. 1803 if not _is_indexed_like(res, group_axes):
  42. 1804 mutated = True
  43. /Users/Ted/anaconda/lib/python3.6/site-packages/pandas/core/groupby.py in <lambda>(x)
  44. 840 def _python_agg_general(self, func, *args, **kwargs):
  45. 841 func = self._is_builtin_func(func)
  46. --> 842 f = lambda x: func(x, *args, **kwargs)
  47. 843
  48. 844 # iterate through "columns" ex exclusions to populate output dict
  49. TypeError: pct_between() missing 2 required positional arguments: 'low' and 'high'
  1. # 用闭包自定义聚合函数
  2. In[40]: def make_agg_func(func, name, *args, **kwargs):
  3.     def wrapper(x):
  4.         return func(x, *args, **kwargs)
  5.     wrapper.__name__ = name
  6.     return wrapper
  7. my_agg1 = make_agg_func(pct_between, 'pct_1_3k', low=1000, high=3000)
  8. my_agg2 = make_agg_func(pct_between, 'pct_10_30k', 10000, 30000)
  9. In[41]: college.groupby(['STABBR', 'RELAFFIL'])['UGDS'].agg(['mean', my_agg1, my_agg2]).head()
  10. Out[41]:

5. 用 *args 和 **kwargs 自定义聚合函数 - 图1