# NOTE: stray page header captured along with this file ("Spaces: Runtime error");
# it is not part of the program.
import json
import random
from collections import defaultdict, deque
# Step kinds a candidate may draw from, and the compass directions available
# to "Relative" steps. The ordering is kept stable so that seeded runs stay
# reproducible.
_STEP_FUNCTIONS = ("Relative", "Azimuth", "Between")
_COMPASS_DIRECTIONS = (
    "north", "south", "east", "west",
    "northeast", "northwest", "southeast", "southwest",
)


def _is_all_steps_connected(steps):
    """Return True when every step feeds, directly or transitively, into the last step.

    An integer entry in a step's ``inputs`` is treated as a reference to the
    step with that ``id``; string inputs are ignored here.
    """
    if not steps:
        return False

    # Map each step id to the ids of the steps it consumes.
    parents = defaultdict(list)
    all_ids = set()
    for step in steps:
        step_id = step["id"]
        all_ids.add(step_id)
        for inp in step["inputs"]:
            if isinstance(inp, int):
                parents[step_id].append(inp)

    # Walk backwards from the final step; whatever is never reached does not
    # contribute to the final result.
    visited = set()
    queue = deque([steps[-1]["id"]])
    while queue:
        curr = queue.popleft()
        visited.add(curr)
        for parent in parents[curr]:
            if parent not in visited:
                queue.append(parent)

    return all_ids <= visited


def _build_candidate_steps(rng, locations):
    """Randomly build one candidate step list.

    Returns a ``(steps, used_locations)`` pair, where ``used_locations`` is the
    set of location names that at least one step takes as a base input.
    """
    used_locations = set()
    all_refs = list(locations)  # a step input may be a location or an earlier step id
    steps = []
    current_id = 1

    for _ in range(rng.randint(2, 5)):
        func = rng.choice(_STEP_FUNCTIONS)

        if func == "Between":
            if len(all_refs) < 2:
                # Not enough references for a two-input step; this slot is
                # dropped, so a candidate can end up with fewer steps.
                continue
            inputs = rng.sample(all_refs, 2)
            used_locations.update(ref for ref in inputs if isinstance(ref, str))
        else:
            base = rng.choice(all_refs)
            if isinstance(base, str):
                used_locations.add(base)
            if func == "Relative":
                modifier = rng.choice(_COMPASS_DIRECTIONS)
            else:  # "Azimuth"
                modifier = f"{rng.randint(0, 359)}°"
            inputs = [base, modifier, f"{rng.randint(1, 10)} km"]

        steps.append({"id": current_id, "function": func, "inputs": inputs})
        all_refs.append(current_id)  # later steps may reference this one
        current_id += 1

    return steps, used_locations


def generate_localization_samples(n, rng=None):
    """Generate ``n`` random localization samples by rejection sampling.

    Each sample is a dict with keys ``"index"`` (1-based, consecutive),
    ``"instruction"`` (always the empty string) and ``"steps"``. A candidate is
    accepted only when every one of its locations is used by some step and
    every step contributes to the last step.

    Args:
        n: Number of samples to return. Values of 0 or less yield ``[]``.
        rng: Optional object providing ``randint``, ``choice`` and ``sample``
            (for example a ``random.Random`` instance). Defaults to the
            module-level ``random`` generator.

    Returns:
        A list of ``n`` sample dicts.
    """
    rng = random if rng is None else rng
    all_data = []

    while len(all_data) < n:
        num_locations = rng.randint(1, 3)
        locations = [f"LOC_{i + 1}" for i in range(num_locations)]
        steps, used_locations = _build_candidate_steps(rng, locations)

        if not steps:
            continue  # no valid step was produced; draw a fresh candidate

        if used_locations.issuperset(locations) and _is_all_steps_connected(steps):
            all_data.append(
                {"index": len(all_data) + 1, "instruction": "", "steps": steps}
            )
        # otherwise the candidate is discarded and a new one is drawn

    return all_data
def write_custom_json(data, filename):
    """Write samples to ``filename`` as JSON with one step object per line.

    Args:
        data: Iterable-with-length of sample dicts, each providing ``"index"``
            and ``"steps"`` (a list of dicts with ``"id"``, ``"function"`` and
            ``"inputs"``). ``"instruction"`` is optional and defaults to the
            empty string.
        filename: Path of the file to create or overwrite (written as UTF-8).
    """

    def format_step(step):
        # Serialize the string-valued fields through json.dumps so quotes,
        # backslashes and non-ASCII characters are always emitted as valid JSON.
        function = json.dumps(step["function"], ensure_ascii=False)
        inputs = json.dumps(step["inputs"], ensure_ascii=False)
        return f'{{"id": {step["id"]}, "function": {function}, "inputs": {inputs}}}'

    # NOTE(review): the leading whitespace inside the literals below may have
    # been collapsed when this file was pasted; JSON validity does not depend
    # on it, so it is kept exactly as found.
    last = len(data) - 1
    with open(filename, "w", encoding="utf-8") as f:
        f.write("[\n")
        for i, item in enumerate(data):
            # Write the item's own instruction instead of a hard-coded "".
            instruction = json.dumps(item.get("instruction", ""), ensure_ascii=False)
            f.write(" {\n")
            f.write(f' "index": {item["index"]},\n')
            f.write(f' "instruction": {instruction},\n')
            f.write(' "steps": [\n')
            step_lines = [f" {format_step(step)}" for step in item["steps"]]
            f.write(",\n".join(step_lines))
            f.write("\n ]\n")
            f.write(" }" + (",\n" if i < last else "\n"))
        f.write("]\n")
# Script entry point: generate 100 samples and write them to disk.
if __name__ == "__main__":
    # Name the output path once so the file written and the message agree.
    output_path = "localization_samples.json"
    samples = generate_localization_samples(100)
    write_custom_json(samples, output_path)
    print(f"✅ Saved to {output_path} with all steps contributing.")