Elements from groups are filtered if they do not satisfy the boolean criterion specified by func. Returns a copy of the DataFrame excluding the filtered elements. Functions that mutate the passed object can produce unexpected behavior or errors and are not supported; see Mutating with User Defined Function (UDF) methods for more details. Each subframe is endowed with the attribute 'name' in case you need to know which group you are working on.
>>> df = pd.DataFrame({
...     'A': ['foo', 'bar', 'foo', 'bar', 'foo', 'bar'],
...     'B': [1, 2, 3, 4, 5, 6],
...     'C': [2.0, 5.0, 8.0, 1.0, 2.0, 9.0]
... })
>>> grouped = df.groupby('A')
>>> grouped.filter(lambda x: x['B'].mean() > 3.)
A B C
1 bar 2 5.0
3 bar 4 1.0
5 bar 6 9.0
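filter also accepts a dropna parameter. With dropna=False, rows belonging to groups that fail the criterion are kept but filled with NaN instead of being dropped, which preserves the original shape. A minimal sketch using the same data as above:

```python
import pandas as pd

df = pd.DataFrame({
    'A': ['foo', 'bar', 'foo', 'bar', 'foo', 'bar'],
    'B': [1, 2, 3, 4, 5, 6],
    'C': [2.0, 5.0, 8.0, 1.0, 2.0, 9.0],
})

# With dropna=False, rows from failing groups ('foo': mean of B is 3.0,
# which is not > 3) are kept but NaN-filled rather than removed.
result = df.groupby('A').filter(lambda x: x['B'].mean() > 3.0, dropna=False)
print(result)
```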
You just need to use apply on the groupby object. I modified your example data to make this a little clearer:
import pandas
from io import StringIO

csv = StringIO("""index,A,B
0,1,0.0
1,1,3.0
2,1,6.0
3,2,0.0
4,2,5.0
5,2,7.0
""")
df = pandas.read_csv(csv, index_col='index')
groups = df.groupby(by=['A'])
print(groups.apply(lambda g: g[g['B'] == g['B'].max()]))
Which prints:
         A    B
A index
1 2      1  6.0
2 5      2  7.0
EDIT: I just learned a much neater way to do this using the .transform groupby method:
def get_max_rows(df):
    # Broadcast each group's max of B back onto every row of df.
    B_maxes = df.groupby('A').B.transform('max')
    return df[df.B == B_maxes]
B_maxes is a Series, identically indexed to the original df, containing the maximum value of B for each A group. You can pass many functions to the transform method, as long as they return either a scalar or a vector of the same length. You can even pass strings naming common functions, like 'median'.
This is slightly different from Paul H's method in that 'A' won't be an index in the result, but you can easily set it afterwards.
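Both points above can be checked together: transform accepts a string naming a common aggregation, and the group column can be restored as the index afterwards. A small self-contained sketch (toy data, keeping each row at or above its group's median):

```python
import pandas as pd

df = pd.DataFrame({'A': [1, 1, 1, 2, 2, 2],
                   'B': [0.0, 3.0, 6.0, 0.0, 5.0, 7.0]})

# transform accepts a string naming a common aggregation, e.g. 'median'.
B_medians = df.groupby('A').B.transform('median')

# Keep rows at or above their group's median, then restore 'A' as the index.
result = df[df.B >= B_medians].set_index('A')
print(result)
```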
import numpy as np
import pandas as pd
df_lots_groups = pd.DataFrame(np.random.rand(30000, 3), columns=list('BCD'))
df_lots_groups['A'] = np.random.choice(range(10000), 30000)
%timeit get_max_rows(df_lots_groups)
100 loops, best of 3: 2.86 ms per loop

%timeit df_lots_groups.groupby('A').apply(lambda df: df[df.B == df.B.max()])
1 loops, best of 3: 5.83 s per loop
Here's an abstraction that lets you select rows from groups using any valid comparison operator and any valid groupby method:
def get_group_rows(df, group_col, condition_col, func='max', comparison='=='):
    g = df.groupby(group_col)[condition_col]
    # Broadcast the group statistic back onto every row, then compare.
    condition_limit = g.transform(func)
    return df.query('{} {} @condition_limit'.format(condition_col, comparison))
A few examples:

%timeit get_group_rows(df_lots_small_groups, 'A', 'B', 'max', '==')
100 loops, best of 3: 2.84 ms per loop

%timeit get_group_rows(df_lots_small_groups, 'A', 'B', 'mean', '!=')
100 loops, best of 3: 2.97 ms per loop
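The transform-then-compare pattern behind that helper can be verified inline on a tiny frame (toy data; column and group names here are made up):

```python
import pandas as pd

df = pd.DataFrame({'A': [1, 1, 2, 2], 'B': [1.0, 4.0, 2.0, 2.0]})

# Broadcast each group's max of B back onto every row...
limit = df.groupby('A')['B'].transform('max')

# ...then compare row values against their group's statistic.
# Note this keeps *all* rows tied at the max (both rows of group 2 here).
rows_at_max = df.query('B == @limit')
print(rows_at_max)
```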
Here's another example: filtering the rows with the maximum value after a groupby operation, using idxmax() and .loc:
In [465]: import pandas as pd

In [466]: df = pd.DataFrame({
     ...:     'sp': ['MM1', 'MM1', 'MM1', 'MM2', 'MM2', 'MM2'],
     ...:     'mt': ['S1', 'S1', 'S3', 'S3', 'S4', 'S4'],
     ...:     'value': [3, 2, 5, 8, 10, 1]
     ...: })

In [467]: df
Out[467]:
   mt   sp  value
0  S1  MM1      3
1  S1  MM1      2
2  S3  MM1      5
3  S3  MM2      8
4  S4  MM2     10
5  S4  MM2      1

# Here, idxmax() finds the indices of the rows with the max value within
# each group, and .loc filters the rows using those indices:
In [468]: df.loc[df.groupby(["mt"])["value"].idxmax()]
Out[468]:
   mt   sp  value
0  S1  MM1      3
3  S3  MM2      8
4  S4  MM2     10
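One caveat worth knowing: idxmax() returns only the first index of the maximum, so if a group has tied maxima, only one row per group survives. Comparing against a transformed max keeps all tied rows. A small sketch of the difference (toy data with a deliberate tie):

```python
import pandas as pd

df = pd.DataFrame({'mt': ['S1', 'S1', 'S3', 'S3'],
                   'value': [5, 5, 8, 3]})

# idxmax() keeps only the *first* row of a tie per group:
first_only = df.loc[df.groupby('mt')['value'].idxmax()]

# Comparing against a transformed max keeps *all* tied rows:
all_ties = df[df['value'] == df.groupby('mt')['value'].transform('max')]

print(first_only)
print(all_ties)
```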
All of these answers are good, but I wanted the following:

(DataFrameGroupBy object) --> filter some rows out --> (DataFrameGroupBy object)

It turns out that is harder and more interesting than I expected. This one-liner accomplishes what I wanted, though it's probably not the most efficient way :)
gdf.apply(lambda g: g[g['team'] == 'A']).reset_index(drop = True).groupby(gdf.grouper.names)
Working Code Example:
import pandas as pd
def print_groups(gdf):
for name, g in gdf:
print('\n' + name)
print(g)
df = pd.DataFrame({
'name': ['sue', 'jim', 'ted', 'moe'],
'team': ['A', 'A', 'B', 'B'],
'fav_food': ['tacos', 'steak', 'tacos', 'steak']
})
gdf = df.groupby('fav_food')
print_groups(gdf)
steak
name team fav_food
1 jim A steak
3 moe B steak
tacos
name team fav_food
0 sue A tacos
2 ted B tacos
fgdf = gdf.apply(lambda g: g[g['team'] == 'A']).reset_index(drop = True).groupby(gdf.grouper.names)
print_groups(fgdf)
steak
name team fav_food
0 jim A steak
tacos
name team fav_food
1 sue A tacos
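If you don't need to round-trip through the original DataFrameGroupBy object, the same grouped result is usually simpler to get by filtering the rows first and grouping afterwards. A sketch using the same toy data:

```python
import pandas as pd

df = pd.DataFrame({
    'name': ['sue', 'jim', 'ted', 'moe'],
    'team': ['A', 'A', 'B', 'B'],
    'fav_food': ['tacos', 'steak', 'tacos', 'steak'],
})

# Filter rows first, then group: no apply/reset_index round trip needed.
fgdf = df[df['team'] == 'A'].groupby('fav_food')
for name, g in fgdf:
    print(name)
    print(g)
```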
Feb 11, 2021 • Martin
import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv")
pd.options.display.max_rows = 10
df
df.groupby('day').aggregate('mean')
df.groupby('day')['total_bill'].aggregate('mean')
day
Fri     17.151579
Sat     20.441379
Sun     21.410000
Thur    17.682742
Name: total_bill, dtype: float64
df.groupby('day')['total_bill'].mean()
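aggregate (alias agg) also accepts a list of function names, producing one column per statistic. A minimal sketch with toy data in the same shape as the tips dataset, so it runs without downloading anything:

```python
import pandas as pd

df = pd.DataFrame({'day': ['Fri', 'Fri', 'Sat', 'Sat'],
                   'total_bill': [10.0, 20.0, 30.0, 50.0]})

# One result column per aggregation name.
stats = df.groupby('day')['total_bill'].agg(['mean', 'sum', 'count'])
print(stats)
```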
df.groupby('day').filter(lambda x: x['total_bill'].mean() > 20)
Updated: January 7, 2022
import pandas as pd
import numpy as np
df = pd.DataFrame({
'continent': ['Asia', 'NorthAmerica', 'NorthAmerica', 'Europe', 'Europe', 'Europe', 'Asia', 'Europe', 'Asia'],
'country': ['China', 'USA', 'Canada', 'Poland', 'Romania', 'Italy', 'India', 'Germany', 'Russia'],
'GDP(trillion)': np.random.randint(1, 9, 9),
'Member_G20': np.random.choice(['Y', 'N'], 9)
})
df.groupby(['continent']).apply(lambda x: x[x['Member_G20'] == 'Y']['GDP(trillion)'].sum())
continent
Asia            19
Europe           5
NorthAmerica     5
dtype: int64
grp = df.groupby(['continent'])
selected_group = grp.get_group('Europe')
selected_group
selected_group[selected_group['Member_G20'] == 'Y']
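The per-group conditional sum shown with apply above can also be written without apply, by masking the rows first and grouping afterwards. A sketch with the same columns but fixed values (the random data above isn't reproducible):

```python
import pandas as pd

df = pd.DataFrame({
    'continent': ['Asia', 'Europe', 'Europe', 'Asia'],
    'country': ['China', 'Poland', 'Germany', 'India'],
    'GDP(trillion)': [8, 1, 4, 3],
    'Member_G20': ['Y', 'N', 'Y', 'Y'],
})

# Filter the G20 members first, then sum GDP per continent.
g20_gdp = df[df['Member_G20'] == 'Y'].groupby('continent')['GDP(trillion)'].sum()
print(g20_gdp)
```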