Function to check error in stratified sampling

In case it saves someone time…

import pandas as pd
import matplotlib.pyplot as plt

def qc_strat_sampling( df, target_col, strat_col, size=10, n_bins=5, iterations=100 ) :
    """ dataframe, string, string, int, int, int --> nothing (just plots)"""
    # target_col : the column whose mean you'd like to estimate as accurately as possible;
    #     the plot shows how good this stratified sampling strategy is at doing that
    # strat_col : the column you're going to split into strata (bins)
    # size : the total sample size. If you have size=10 and two bins, one with 20 rows and the other with 80, then
    #     you'll take 2 samples from the first and 8 from the other
    # n_bins : number of equal-width bins; ignored for a non-numeric strat_col (each unique value is a category)
    # iterations : how many times you're going to repeat the estimation experiment to generate
    #     the scatter plot
    means = []
    l_df = len( df )
    if size > l_df :
        raise ValueError('size must not exceed the number of rows in df')
    if pd.api.types.is_numeric_dtype( df[strat_col] ) :
        # copy() so that adding the bin column doesn't trigger SettingWithCopyWarning
        temp_df = df[[target_col]].copy()
        categories = strat_col + '_cat'   # used as a column name here
        temp_df[categories] = pd.cut( x=df[strat_col], bins=n_bins, labels=range(n_bins) )
        series = [ temp_df.loc[ temp_df[categories] == j, target_col] for j in range(n_bins)]
    else :
        categories = df[strat_col].unique()   # used as a list of categories here
        series = [df.loc[ df[strat_col] == uval, target_col] for uval in categories]
    # equal-width bins (or NaN categories) can be empty; sampling 1 row from an empty stratum would raise
    series = [ ser for ser in series if len(ser) > 0 ]
    ser_ls = [ len(ser) for ser in series]
    for i in range(iterations) :
        # proportional allocation with at least one row per stratum (so the total can slightly exceed size)
        sample = pd.concat( [ser.sample( n=max(1,round(size*ser_ls[j]/l_df)), random_state=i ) for j,ser in enumerate(series) ] )
        means.append( sample.mean() )
    plt.scatter( range(1,iterations+1), means )   # one point per repeated experiment
    plt.axhline( df[target_col].mean() )           # the true population mean
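The allocation rule described in the comments (size=10 with strata of 20 and 80 rows gives 2 and 8 samples) can be checked on its own. `proportional_allocation` below is a hypothetical helper, not part of the original post; it is a minimal sketch of the same `max(1, round(...))` rule, with empty strata given 0 rows instead of raising.

```python
# Hypothetical helper (not from the original post): the proportional
# allocation rule used inside qc_strat_sampling, pulled out so it can be checked.
def proportional_allocation(stratum_sizes, size):
    """Return how many rows to draw from each stratum.

    Each non-empty stratum gets round(size * n_h / N) rows with a floor of 1,
    so every stratum is represented; empty strata get 0 rows.
    """
    total = sum(stratum_sizes)
    if size > total:
        raise ValueError("size must not exceed the total number of rows")
    return [max(1, round(size * n_h / total)) if n_h > 0 else 0
            for n_h in stratum_sizes]


# The example from the function's comments: 20 and 80 rows, sample size 10.
print(proportional_allocation([20, 80], 10))   # [2, 8]
# The floor of 1 means the total can exceed `size` when a stratum is tiny.
print(proportional_allocation([1, 99], 10))    # [1, 10]
# Python's round() sends exact halves to the even neighbour: 2.5 -> 2, 7.5 -> 8.
print(proportional_allocation([25, 75], 10))   # [2, 8]
```

Note the second case: with a very small stratum the realized sample has 11 rows rather than 10, which is worth keeping in mind when comparing strategies at the same nominal size.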

WNBA 2017 dataset from Thomas de Jonghe

wnba = pd.read_csv('WNBA Stats.csv')
qc_strat_sampling( wnba, 'PTS', 'MIN' )


qc_strat_sampling( wnba, 'PTS', 'Pos' )  # using Player positions as the strata
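To compare strata like 'MIN' and 'Pos' with a number instead of eyeballing two scatter plots, one option is to average the absolute error of the repeated estimates. `strat_sampling_error` below is a hypothetical companion function and the DataFrame is made-up data; it is a sketch that assumes the same equal-width binning and proportional allocation as `qc_strat_sampling`.

```python
import numpy as np
import pandas as pd


def strat_sampling_error(df, target_col, strat_col, size=10, n_bins=5, iterations=100):
    """Mean absolute error of stratified-sample means against the population mean.

    Hypothetical companion to qc_strat_sampling: same strata and proportional
    allocation, but it returns a single number instead of drawing a plot.
    """
    if pd.api.types.is_numeric_dtype(df[strat_col]):
        # Equal-width bins, encoded as integers 0..n_bins-1
        strata = pd.cut(df[strat_col], bins=n_bins, labels=False)
    else:
        strata = df[strat_col]
    # One Series of target values per non-empty stratum (NaN strata are dropped)
    groups = [g for _, g in df[target_col].groupby(strata) if len(g) > 0]
    total = len(df)
    true_mean = df[target_col].mean()
    errors = []
    for i in range(iterations):
        # One generator per iteration, so equal-sized strata get different draws
        rng = np.random.default_rng(i)
        sample = pd.concat([g.sample(n=max(1, round(size * len(g) / total)), random_state=rng)
                            for g in groups])
        errors.append(abs(sample.mean() - true_mean))
    return float(np.mean(errors))


# Made-up data: the target is constant within 'group', unrelated to 'parity'
demo = pd.DataFrame({
    "target": [1.0] * 20 + [6.0] * 80,
    "group": ["A"] * 20 + ["B"] * 80,
    "parity": [k % 2 for k in range(100)],
})
print(strat_sampling_error(demo, "target", "group"))   # 0.0 (a 2/8 allocation is exact here)
print(strat_sampling_error(demo, "target", "parity"))  # positive: parity says nothing about target
```

With the WNBA data, calling it once with 'MIN' and once with 'Pos' as `strat_col` would give one number per strategy, lower being better. Passing a single `np.random.Generator` per iteration is a deliberate change from `random_state=i`: with the same integer seed, two strata of equal length would sample the same positions.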


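Compare this with the manual approach: first run value_counts on the strata column ('Games Played' here) to choose the strata and their sample sizes, and then hard-code them: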

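# strata hard-coded from value_counts on 'Games Played': fewer than 13, 13 to 22, and more than 22 games
ser1 = wnba.loc[ wnba['Games Played'] < 13, 'PTS' ]
ser2 = wnba.loc[ (wnba['Games Played'] < 23) & (wnba['Games Played'] > 12), 'PTS' ]
ser3 = wnba.loc[ wnba['Games Played'] > 22 , 'PTS' ]
means = []
for i in range( 100 ) :
    # 1 + 2 + 7 = 10 rows, sizes picked by hand from the value_counts proportions
    s1 = ser1.sample( n=1, random_state= i )
    s2 = ser2.sample( n=2, random_state= i )
    s3 = ser3.sample( n=7, random_state= i )
    sample = pd.concat( [s1,s2,s3] )
    means.append( sample.mean() )

plt.scatter( range(1,101), means )   # one point per repeated experiment
plt.axhline( wnba['PTS'].mean() )    # the true population mean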
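The hand-picked 1/2/7 split can also be derived from the counts instead of typed in. `manual_allocation` below is a hypothetical helper and the 'Games Played' values are made up; it is a sketch assuming right-closed `pd.cut` edges, where `[-np.inf, 12, 22, np.inf]` matches the three strata above for integer game counts.

```python
import numpy as np
import pandas as pd


def manual_allocation(values, edges, size=10):
    """Count rows per stratum with value_counts and allocate `size` proportionally.

    `edges` are right-closed bin edges for pd.cut, so [-inf, 12, 22, inf]
    gives the strata (-inf, 12], (12, 22] and (22, inf].
    """
    strata = pd.cut(values, bins=edges)
    counts = strata.value_counts(sort=False)   # keep bin order, not frequency order
    alloc = (size * counts / counts.sum()).round().astype(int)
    return pd.DataFrame({"rows": counts.to_numpy(), "sample_size": alloc.to_numpy()},
                        index=counts.index)


# Made-up 'Games Played' values: 10 low, 20 middle and 70 high
games = pd.Series([5] * 10 + [18] * 20 + [30] * 70, name="Games Played")
print(manual_allocation(games, [-np.inf, 12, 22, np.inf]))
```

Unlike `qc_strat_sampling`, this sketch applies no floor of 1, so a very small stratum could be allocated 0 rows; that mirrors the manual version, where every size is chosen explicitly.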


Do you mind providing a question link, as per these guidelines? Note that I recategorized your topic; the guidelines apply not only to asking questions but also to sharing alternative solutions.


It probably belongs in Social, Ryan, not Q&A…
I'm just sharing code.

This is awesome! Thank you for sharing this!

Which mission/screen in Dataquest does this function relate to most? It would be great if you added a link to that mission in your post, so that learners doing that mission can discover it more easily. 🙂


This is the mission, Nityesh: Choosing the right strata