derek-thomas
		
	commited on
		
		
					Commit 
							
							·
						
						3772eaf
	
1
								Parent(s):
							
							7d5ff0e
								
Move client instantiation
Browse files- src/utilities.py +5 -3
    	
        src/utilities.py
    CHANGED
    
    | @@ -12,7 +12,6 @@ USERNAME = os.environ["USERNAME"] | |
| 12 | 
             
            OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
         | 
| 13 | 
             
            PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
         | 
| 14 |  | 
| 15 | 
            -
            client = Client("derek-thomas/nomic-embeddings")
         | 
| 16 | 
             
            logger = setup_logger(__name__)
         | 
| 17 |  | 
| 18 |  | 
| @@ -29,6 +28,9 @@ async def load_datasets(): | |
| 29 |  | 
| 30 |  | 
| 31 | 
             
            def merge_and_update_datasets(dataset, original_dataset):
         | 
|  | |
|  | |
|  | |
| 32 | 
             
                # Merge and figure out which rows need to be updated with embeddings
         | 
| 33 | 
             
                odf = original_dataset['train'].to_pandas()
         | 
| 34 | 
             
                df = dataset['train'].to_pandas()
         | 
| @@ -50,13 +52,13 @@ def merge_and_update_datasets(dataset, original_dataset): | |
| 50 | 
             
                # Iterate over the DataFrame rows where 'embedding' is None
         | 
| 51 | 
             
                for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
         | 
| 52 | 
             
                    # Update 'embedding' for the current row using our function
         | 
| 53 | 
            -
                    merged_df.at[index, 'embedding'] = update_embeddings(row['content'])
         | 
| 54 |  | 
| 55 | 
             
                dataset['train'] = Dataset.from_pandas(merged_df)
         | 
| 56 | 
             
                logger.info(f"Updated {updated_rows} rows")
         | 
| 57 | 
             
                return dataset
         | 
| 58 |  | 
| 59 |  | 
| 60 | 
            -
            def update_embeddings(content):
         | 
| 61 | 
             
                embedding = client.predict(content, api_name="/embed")
         | 
| 62 | 
             
                return np.array(embedding)
         | 
|  | |
| 12 | 
             
            OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
         | 
| 13 | 
             
            PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
         | 
| 14 |  | 
|  | |
| 15 | 
             
            logger = setup_logger(__name__)
         | 
| 16 |  | 
| 17 |  | 
|  | |
| 28 |  | 
| 29 |  | 
| 30 | 
             
            def merge_and_update_datasets(dataset, original_dataset):
         | 
| 31 | 
            +
                # Get client
         | 
| 32 | 
            +
                client = Client("derek-thomas/nomic-embeddings")
         | 
| 33 | 
            +
             | 
| 34 | 
             
                # Merge and figure out which rows need to be updated with embeddings
         | 
| 35 | 
             
                odf = original_dataset['train'].to_pandas()
         | 
| 36 | 
             
                df = dataset['train'].to_pandas()
         | 
|  | |
| 52 | 
             
                # Iterate over the DataFrame rows where 'embedding' is None
         | 
| 53 | 
             
                for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
         | 
| 54 | 
             
                    # Update 'embedding' for the current row using our function
         | 
| 55 | 
            +
                    merged_df.at[index, 'embedding'] = update_embeddings(content=row['content'], client=client)
         | 
| 56 |  | 
| 57 | 
             
                dataset['train'] = Dataset.from_pandas(merged_df)
         | 
| 58 | 
             
                logger.info(f"Updated {updated_rows} rows")
         | 
| 59 | 
             
                return dataset
         | 
| 60 |  | 
| 61 |  | 
| 62 | 
            +
            def update_embeddings(content, client):
         | 
| 63 | 
             
                embedding = client.predict(content, api_name="/embed")
         | 
| 64 | 
             
                return np.array(embedding)
         | 
