DRIFT 搜索
In [1]
已复制!
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License.
In [2]
已复制!
import os
from pathlib import Path
import pandas as pd
import tiktoken
from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_report_embeddings,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2
# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")
print(f"Entity df columns: {entity_df.columns}")
entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
full_content_embedding_store = LanceDBVectorStore(
collection_name="default-community-full_content",
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)
print(f"Entity count: {len(entity_df)}")
entity_df.head()
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()
import os from pathlib import Path import pandas as pd import tiktoken from graphrag.config.enums import ModelType from graphrag.config.models.drift_search_config import DRIFTSearchConfig from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.language_model.manager import ModelManager from graphrag.query.indexer_adapters import ( read_indexer_entities, read_indexer_relationships, read_indexer_report_embeddings, read_indexer_reports, read_indexer_text_units, ) from graphrag.query.structured_search.drift_search.drift_context import ( DRIFTSearchContextBuilder, ) from graphrag.query.structured_search.drift_search.search import DRIFTSearch from graphrag.vector_stores.lancedb import LanceDBVectorStore INPUT_DIR = "./inputs/operation dulce" LANCEDB_URI = f"{INPUT_DIR}/lancedb" COMMUNITY_REPORT_TABLE = "community_reports" COMMUNITY_TABLE = "communities" ENTITY_TABLE = "entities" RELATIONSHIP_TABLE = "relationships" COVARIATE_TABLE = "covariates" TEXT_UNIT_TABLE = "text_units" COMMUNITY_LEVEL = 2 # read nodes table to get community and degree data entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet") community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet") print(f"Entity df columns: {entity_df.columns}") entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL) # load description embeddings to an in-memory lancedb vectorstore # to connect to a remote db, specify url and port values. description_embedding_store = LanceDBVectorStore( collection_name="default-entity-description", ) description_embedding_store.connect(db_uri=LANCEDB_URI) full_content_embedding_store = LanceDBVectorStore( collection_name="default-community-full_content", ) full_content_embedding_store.connect(db_uri=LANCEDB_URI) print(f"Entity count: {len(entity_df)}") entity_df.head() relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet") relationships = read_indexer_relationships(relationship_df) print(f"Relationship count: {len(relationship_df)}") relationship_df.head() text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet") text_units = read_indexer_text_units(text_unit_df) print(f"Text unit records: {len(text_unit_df)}") text_unit_df.head()
Entity df columns: Index(['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'], dtype='object') Entity count: 18 Relationship count: 54 Text unit records: 5
Out[2]
id | human_readable_id | text | n_tokens | document_ids | entity_ids | relationship_ids | covariate_ids | |
---|---|---|---|---|---|---|---|---|
0 | 8e938693af886bfd081acbbe8384c3671446bff84a134a... | 1 | # Operation: Dulce\n\n## Chapter 1\n\nThe thru... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [745d28dd-be20-411b-85ff-1c69ca70e7b3, 9cba185... |
1 | fd1f46d32e1df6cd429542aeda3d64ddf3745ccb80f443... | 2 | , the hollow echo of the bay a stark reminder ... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [4f9b461f-5e8f-465d-9586-e2fc81787062, 0f74618... |
2 | 7296d9a1f046854d59079dc183de8a054c27c4843d2979... | 3 | differently than praise from others. This was... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [3ef1be9c-4080-4fac-99bd-c4a636248904, 8730b20... |
3 | ac72722a02ac71242a2a91fca323198d04197daf60515d... | 4 | contrast to the rigid silence enveloping the ... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [2c292047-b79a-4958-ab57-7bf7d7a22c92, 3cbd18a... |
4 | 4c277337d461a16aaf8f9760ddb8b44ef220e948a2341d... | 5 | a mask of duty.\n\nIn the midst of the descen... | 35 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [d084d615-3584-4ec8-9931-90aa6075c764, 4b84859... | [6efdc42e-69a2-47c0-97ec-4b296cd16d5e] | [db8da02f-f889-4bb5-8e81-ab2a72e380bb] |
In [3]
已复制!
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"]
chat_config = LanguageModelConfig(
api_key=api_key,
type=ModelType.OpenAIChat,
model=llm_model,
max_retries=20,
)
chat_model = ModelManager().get_or_create_chat_model(
name="local_search",
model_type=ModelType.OpenAIChat,
config=chat_config,
)
token_encoder = tiktoken.encoding_for_model(llm_model)
embedding_config = LanguageModelConfig(
api_key=api_key,
type=ModelType.OpenAIEmbedding,
model=embedding_model,
max_retries=20,
)
text_embedder = ModelManager().get_or_create_embedding_model(
name="local_search_embedding",
model_type=ModelType.OpenAIEmbedding,
config=embedding_config,
)
api_key = os.environ["GRAPHRAG_API_KEY"] llm_model = os.environ["GRAPHRAG_LLM_MODEL"] embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"] chat_config = LanguageModelConfig( api_key=api_key, type=ModelType.OpenAIChat, model=llm_model, max_retries=20, ) chat_model = ModelManager().get_or_create_chat_model( name="local_search", model_type=ModelType.OpenAIChat, config=chat_config, ) token_encoder = tiktoken.encoding_for_model(llm_model) embedding_config = LanguageModelConfig( api_key=api_key, type=ModelType.OpenAIEmbedding, model=embedding_model, max_retries=20, ) text_embedder = ModelManager().get_or_create_embedding_model( name="local_search_embedding", model_type=ModelType.OpenAIEmbedding, config=embedding_config, )
In [4]
已复制!
def read_community_reports(
input_dir: str,
community_report_table: str = COMMUNITY_REPORT_TABLE,
):
"""Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path."""
input_path = Path(input_dir) / f"{community_report_table}.parquet"
return pd.read_parquet(input_path)
report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
report_df,
community_df,
COMMUNITY_LEVEL,
content_embedding_col="full_content_embeddings",
)
read_indexer_report_embeddings(reports, full_content_embedding_store)
def read_community_reports( input_dir: str, community_report_table: str = COMMUNITY_REPORT_TABLE, ): """Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path.""" input_path = Path(input_dir) / f"{community_report_table}.parquet" return pd.read_parquet(input_path) report_df = read_community_reports(INPUT_DIR) reports = read_indexer_reports( report_df, community_df, COMMUNITY_LEVEL, content_embedding_col="full_content_embeddings", ) read_indexer_report_embeddings(reports, full_content_embedding_store)
In [5]
已复制!
drift_params = DRIFTSearchConfig(
temperature=0,
max_tokens=12_000,
primer_folds=1,
drift_k_followups=3,
n_depth=3,
n=1,
)
context_builder = DRIFTSearchContextBuilder(
model=chat_model,
text_embedder=text_embedder,
entities=entities,
relationships=relationships,
reports=reports,
entity_text_embeddings=description_embedding_store,
text_units=text_units,
token_encoder=token_encoder,
config=drift_params,
)
search = DRIFTSearch(
model=chat_model, context_builder=context_builder, token_encoder=token_encoder
)
drift_params = DRIFTSearchConfig( temperature=0, max_tokens=12_000, primer_folds=1, drift_k_followups=3, n_depth=3, n=1, ) context_builder = DRIFTSearchContextBuilder( model=chat_model, text_embedder=text_embedder, entities=entities, relationships=relationships, reports=reports, entity_text_embeddings=description_embedding_store, text_units=text_units, token_encoder=token_encoder, config=drift_params, ) search = DRIFTSearch( model=chat_model, context_builder=context_builder, token_encoder=token_encoder )
In [6]
已复制!
resp = await search.search("Who is agent Mercer?")
resp = await search.search("Who is agent Mercer?")
--------------------------------------------------------------------------- AuthenticationError Traceback (most recent call last) Cell In[6], line 1 ----> 1 resp = await search.search("Who is agent Mercer?") File ~/work/graphrag/graphrag/graphrag/query/structured_search/drift_search/search.py:213, in DRIFTSearch.search(self, query, conversation_history, reduce, **kwargs) 210 # Check if query state is empty 211 if not self.query_state.graph: 212 # Prime the search with the primer --> 213 primer_context, token_ct = await self.context_builder.build_context(query) 214 llm_calls["build_context"] = token_ct["llm_calls"] 215 prompt_tokens["build_context"] = token_ct["prompt_tokens"] File ~/work/graphrag/graphrag/graphrag/query/structured_search/drift_search/drift_context.py:199, in DRIFTSearchContextBuilder.build_context(self, query, **kwargs) 190 raise ValueError(missing_reports_error) 192 query_processor = PrimerQueryProcessor( 193 chat_model=self.model, 194 text_embedder=self.text_embedder, 195 token_encoder=self.token_encoder, 196 reports=self.reports, 197 ) --> 199 query_embedding, token_ct = await query_processor(query) 201 report_df = self.convert_reports_to_df(self.reports) 203 # Check compatibility between query embedding and document embeddings File ~/work/graphrag/graphrag/graphrag/query/structured_search/drift_search/primer.py:96, in PrimerQueryProcessor.__call__(self, query) 85 async def __call__(self, query: str) -> tuple[list[float], dict[str, int]]: 86 """ 87 Call method to process the query, expand it, and embed the result. 88 (...) 94 tuple[list[float], int]: List of embeddings for the expanded query and the token count. 95 """ ---> 96 hyde_query, token_ct = await self.expand_query(query) 97 log.info("Expanded query: %s", hyde_query) 98 return self.text_embedder.embed(hyde_query), token_ct File ~/work/graphrag/graphrag/graphrag/query/structured_search/drift_search/primer.py:70, in PrimerQueryProcessor.expand_query(self, query) 63 template = secrets.choice(self.reports).full_content # nosec S311 65 prompt = f"""Create a hypothetical answer to the following query: {query}\n\n 66 Format it to follow the structure of the template below:\n\n 67 {template}\n" 68 Ensure that the hypothetical answer does not reference new named entities that are not present in the original query.""" ---> 70 model_response = await self.chat_model.achat(prompt) 71 text = model_response.output.content 73 prompt_tokens = num_tokens(prompt, self.token_encoder) File ~/work/graphrag/graphrag/graphrag/language_model/providers/fnllm/models.py:82, in OpenAIChatFNLLM.achat(self, prompt, history, **kwargs) 70 """ 71 Chat with the Model using the given prompt. 72 (...) 79 The response from the Model. 80 """ 81 if history is None: ---> 82 response = await self.model(prompt, **kwargs) 83 else: 84 response = await self.model(prompt, history=history, **kwargs) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/fnllm/openai/llm/openai_chat_llm.py:94, in OpenAIChatLLMImpl.__call__(self, prompt, stream, **kwargs) 91 if stream: 92 return await self._streaming_chat_llm(prompt, **kwargs) ---> 94 return await self._text_chat_llm(prompt, **kwargs) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/fnllm/openai/services/openai_tools_parsing.py:130, in OpenAIParseToolsLLM.__call__(self, prompt, **kwargs) 127 tools = kwargs.get("tools", []) 129 if not tools: --> 130 return await self._delegate(prompt, **kwargs) 132 completion_parameters = self._add_tools_to_parameters(kwargs, tools) 134 result = await self._delegate(prompt, **completion_parameters) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/fnllm/base/base_llm.py:144, in BaseLLM.__call__(self, prompt, **kwargs) 142 try: 143 prompt, kwargs = self._rewrite_input(prompt, kwargs) --> 144 return await self._decorated_target(prompt, **kwargs) 145 except BaseException as e: 146 stack_trace = traceback.format_exc() File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/fnllm/base/services/json.py:78, in JsonReceiver.decorate.<locals>.invoke(prompt, **kwargs) 76 if kwargs.get("json_model") is not None or kwargs.get("json"): 77 return await this.invoke_json(delegate, prompt, kwargs) ---> 78 return await delegate(prompt, **kwargs) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/fnllm/base/services/rate_limiter.py:75, in RateLimiter.decorate.<locals>.invoke(prompt, **args) 73 async with self._limiter.use(manifest): 74 await self._events.on_limit_acquired(manifest) ---> 75 result = await delegate(prompt, **args) 76 finally: 77 await self._events.on_limit_released(manifest) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/fnllm/base/base_llm.py:126, in BaseLLM._decorator_target(self, prompt, **kwargs) 121 """Target for the decorator chain. 122 123 Leave signature alone as prompt, kwargs. 124 """ 125 await self._events.on_execute_llm() --> 126 output = await self._execute_llm(prompt, kwargs) 127 result = LLMOutput(output=output) 128 await self._inject_usage(result) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/fnllm/openai/llm/openai_text_chat_llm.py:157, in OpenAITextChatLLMImpl._execute_llm(self, prompt, kwargs) 154 local_model_parameters = kwargs.get("model_parameters") 155 parameters = self._build_completion_parameters(local_model_parameters) --> 157 completion = await self._client.chat.completions.create( 158 messages=cast(Iterator[ChatCompletionMessageParam], messages), 159 **parameters, 160 ) 162 result = completion.choices[0].message 163 usage: LLMUsageMetrics | None = None File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/openai/resources/chat/completions/completions.py:2000, in AsyncCompletions.create(self, messages, model, audio, frequency_penalty, function_call, functions, logit_bias, logprobs, max_completion_tokens, max_tokens, metadata, modalities, n, parallel_tool_calls, prediction, presence_penalty, reasoning_effort, response_format, seed, service_tier, stop, store, stream, stream_options, temperature, tool_choice, tools, top_logprobs, top_p, user, web_search_options, extra_headers, extra_query, extra_body, timeout) 1957 @required_args(["messages", "model"], ["messages", "model", "stream"]) 1958 async def create( 1959 self, (...) 1997 timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, 1998 ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: 1999 validate_response_format(response_format) -> 2000 return await self._post( 2001 "/chat/completions", 2002 body=await async_maybe_transform( 2003 { 2004 "messages": messages, 2005 "model": model, 2006 "audio": audio, 2007 "frequency_penalty": frequency_penalty, 2008 "function_call": function_call, 2009 "functions": functions, 2010 "logit_bias": logit_bias, 2011 "logprobs": logprobs, 2012 "max_completion_tokens": max_completion_tokens, 2013 "max_tokens": max_tokens, 2014 "metadata": metadata, 2015 "modalities": modalities, 2016 "n": n, 2017 "parallel_tool_calls": parallel_tool_calls, 2018 "prediction": prediction, 2019 "presence_penalty": presence_penalty, 2020 "reasoning_effort": reasoning_effort, 2021 "response_format": response_format, 2022 "seed": seed, 2023 "service_tier": service_tier, 2024 "stop": stop, 2025 "store": store, 2026 "stream": stream, 2027 "stream_options": stream_options, 2028 "temperature": temperature, 2029 "tool_choice": tool_choice, 2030 "tools": tools, 2031 "top_logprobs": top_logprobs, 2032 "top_p": top_p, 2033 "user": user, 2034 "web_search_options": web_search_options, 2035 }, 2036 completion_create_params.CompletionCreateParams, 2037 ), 2038 options=make_request_options( 2039 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout 2040 ), 2041 cast_to=ChatCompletion, 2042 stream=stream or False, 2043 stream_cls=AsyncStream[ChatCompletionChunk], 2044 ) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/openai/_base_client.py:1767, in AsyncAPIClient.post(self, path, cast_to, body, files, options, stream, stream_cls) 1753 async def post( 1754 self, 1755 path: str, (...) 1762 stream_cls: type[_AsyncStreamT] | None = None, 1763 ) -> ResponseT | _AsyncStreamT: 1764 opts = FinalRequestOptions.construct( 1765 method="post", url=path, json_data=body, files=await async_to_httpx_files(files), **options 1766 ) -> 1767 return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/openai/_base_client.py:1461, in AsyncAPIClient.request(self, cast_to, options, stream, stream_cls, remaining_retries) 1458 else: 1459 retries_taken = 0 -> 1461 return await self._request( 1462 cast_to=cast_to, 1463 options=options, 1464 stream=stream, 1465 stream_cls=stream_cls, 1466 retries_taken=retries_taken, 1467 ) File ~/.cache/pypoetry/virtualenvs/graphrag-F2jvqev7-py3.11/lib/python3.11/site-packages/openai/_base_client.py:1562, in AsyncAPIClient._request(self, cast_to, options, stream, stream_cls, retries_taken) 1559 await err.response.aread() 1561 log.debug("Re-raising status error") -> 1562 raise self._make_status_error_from_response(err.response) from None 1564 return await self._process_response( 1565 cast_to=cast_to, 1566 options=options, (...) 1570 retries_taken=retries_taken, 1571 ) AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: sk-proj-********************************************************************************************************************************************************zWYA. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}
In [7]
已复制!
resp.response
resp.response
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[7], line 1 ----> 1 resp.response NameError: name 'resp' is not defined
In [8]
已复制!
print(resp.context_data)
print(resp.context_data)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[8], line 1 ----> 1 print(resp.context_data) NameError: name 'resp' is not defined