@@ -156,23 +156,22 @@ pred_df = pipeline.predict_df(
156156 return [ installSnippet , exampleSnippet ] ;
157157} ;
158158
159- export const contexttab = ( ) : string [ ] => {
160- const installSnippet = `pip install git+https://github.com/SAP-samples/contexttab ` ;
159+ export const sap_rpt_one_oss = ( ) : string [ ] => {
160+ const installSnippet = `pip install git+https://github.com/SAP-samples/sap-rpt-1-oss ` ;
161161
162162 const classificationSnippet = `# Run a classification task
163163from sklearn.datasets import load_breast_cancer
164164from sklearn.metrics import accuracy_score
165165from sklearn.model_selection import train_test_split
166166
167- from contexttab import ConTextTabClassifier
167+ from sap_rpt_oss import SAP_RPT_OSS_Classifier
168168
169169# Load sample data
170170X, y = load_breast_cancer(return_X_y=True)
171171X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
172172
173- # Initialize a classifier
174- # You can omit checkpoint and checkpoint_revision to use the default model
175- clf = ConTextTabClassifier(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048)
173+ # Initialize a classifier, 8k context and 8-fold bagging gives best performance, reduce if running out of memory
174+ clf = SAP_RPT_OSS_Classifier(max_context_size=8192, bagging=8)
176175
177176clf.fit(X_train, y_train)
178177
@@ -187,8 +186,7 @@ from sklearn.datasets import fetch_openml
187186from sklearn.metrics import r2_score
188187from sklearn.model_selection import train_test_split
189188
190- from contexttab import ConTextTabRegressor
191-
189+ from sap_rpt_oss import SAP_RPT_OSS_Regressor
192190
193191# Load sample data
194192df = fetch_openml(data_id=531, as_frame=True)
@@ -198,9 +196,8 @@ y = df.target.astype(float)
198196# Train-test split
199197X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
200198
201- # Initialize the regressor
202- # You can omit checkpoint and checkpoint_revision to use the default model
203- regressor = ConTextTabRegressor(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048)
199+ # Initialize the regressor, 8k context and 8-fold bagging gives best performance, reduce if running out of memory
200+ regressor = SAP_RPT_OSS_Regressor(max_context_size=8192, bagging=8)
204201
205202regressor.fit(X_train, y_train)
206203
0 commit comments