In case it saves someone time…

```
def qc_strat_sampling(df, target_col, strat_col, size=10, n_bins=5, iterations=100):
    """Plot how well proportional stratified sampling estimates a column mean.

    The rows of ``df`` are split into strata using ``strat_col``. In each
    iteration a sample is drawn from every non-empty stratum, sized in
    proportion to the stratum's share of the rows, and the mean of
    ``target_col`` over the combined sample is recorded. The recorded means
    are scatter-plotted against the iteration number, together with a
    horizontal line at the mean of ``target_col`` over the whole of ``df``.

    Args:
        df: pandas DataFrame holding both columns.
        target_col: name of the column whose mean is being estimated.
        strat_col: name of the column used to build the strata. A numeric
            column is cut into ``n_bins`` equal-width bins with ``pd.cut``;
            for any other dtype each unique value is its own stratum.
            Rows whose ``strat_col`` value is missing belong to no stratum.
        size: target total sample size. Each stratum contributes
            ``max(1, round(size * stratum_rows / len(df)))`` rows, so e.g.
            ``size=10`` with strata of 20 and 80 rows takes 2 and 8 rows.
            Because of the rounding and the minimum of one row per stratum,
            the actual total can differ slightly from ``size``.
        n_bins: number of bins for a numeric ``strat_col``; ignored otherwise.
        iterations: how many times the sampling experiment is repeated (one
            point on the scatter plot per repetition).

    Returns:
        None. The function only draws and shows a matplotlib figure.

    Raises:
        ValueError: if ``size`` is larger than the number of rows in ``df``.
    """
    n_rows = len(df)
    if size > n_rows:
        raise ValueError('size must not exceed the number of rows in df')

    if pd.api.types.is_numeric_dtype(df[strat_col]):
        # labels=False yields the integer bin index of every row, so no
        # helper column (and no copy of df) is needed to select a bin.
        bin_index = pd.cut(x=df[strat_col], bins=n_bins, labels=False)
        strata = [df.loc[bin_index == j, target_col] for j in range(n_bins)]
    else:
        strata = [
            df.loc[df[strat_col] == value, target_col]
            for value in df[strat_col].unique()
        ]

    # Equal-width bins can be empty, and a NaN "category" never matches `==`.
    # Sampling from such an empty stratum would raise, so leave them out.
    strata = [stratum for stratum in strata if len(stratum) > 0]

    # The per-stratum sample sizes do not depend on the iteration.
    sample_sizes = [
        max(1, round(size * len(stratum) / n_rows)) for stratum in strata
    ]

    means = []
    for i in range(iterations):
        # The iteration number is the seed, so every run of the experiment is
        # reproducible. NOTE: the same seed is reused for each stratum.
        sample = pd.concat(
            [
                stratum.sample(n=n, random_state=i)
                for stratum, n in zip(strata, sample_sizes)
            ]
        )
        means.append(sample.mean())

    plt.scatter(range(1, iterations + 1), means)
    plt.axhline(df[target_col].mean())
    plt.show()
```

WNBA 2017 dataset from Thomas de Jonghe

```
# Load the WNBA statistics and run the check with the 'MIN' column as strata.
wnba = pd.read_csv('WNBA Stats.csv')
qc_strat_sampling(df=wnba, target_col='PTS', strat_col='MIN')
```

```
# Player positions serve as the strata here.
qc_strat_sampling(df=wnba, target_col='PTS', strat_col='Pos')
```

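Compare this with the manual approach: first use `value_counts` on the strata column to work out each stratum's share of the rows, and then hard-code the strata and their sample sizes: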

```
# Hand-built strata on 'Games Played': fewer than 13, 13 to 22, more than 22.
games_played = wnba['Games Played']
ser1 = wnba.loc[games_played < 13, 'PTS']
ser2 = wnba.loc[(games_played < 23) & (games_played > 12), 'PTS']
ser3 = wnba.loc[games_played > 22, 'PTS']

# Fixed number of rows drawn from each stratum (1 + 2 + 7 = 10 per sample).
strata_and_sizes = [(ser1, 1), (ser2, 2), (ser3, 7)]

# One sample mean per seed; the seed doubles as the experiment number.
means = [
    pd.concat(
        [stratum.sample(n=n_rows, random_state=seed) for stratum, n_rows in strata_and_sizes]
    ).mean()
    for seed in range(100)
]

# Sample means per experiment, with the full-data mean as a reference line.
plt.scatter(range(1, 101), means)
plt.axhline(wnba['PTS'].mean())
```