"""Generate random localization-step samples and save them as compact JSON.

Each sample is a small dependency graph of steps ("Relative", "Azimuth",
"Between") whose inputs are either named locations ("LOC_1", ...) or the
integer IDs of earlier steps. Only samples in which every location is used
and every step contributes to the final step are kept.
"""

import json
import random
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional, Set, Tuple

# Function names a step may use; "Between" takes two references, the other
# two take one reference plus a bearing and a distance.
_FUNCTIONS = ("Relative", "Azimuth", "Between")

_DIRECTIONS = (
    "north", "south", "east", "west",
    "northeast", "northwest", "southeast", "southwest",
)


def _all_steps_reach_last(steps: List[Dict[str, Any]]) -> bool:
    """Return True when every step feeds (transitively) into the last step.

    Integer inputs are treated as references to earlier step IDs; string
    inputs (locations, directions, distances) are ignored. *steps* must be
    non-empty.
    """
    # Map each step ID to the step IDs it consumes.
    parents = defaultdict(list)
    for step in steps:
        for ref in step["inputs"]:
            if isinstance(ref, int):
                parents[step["id"]].append(ref)

    # Walk backwards from the final step; mark on enqueue so that a step
    # shared by several consumers is queued only once.
    last_id = steps[-1]["id"]
    reached = {last_id}
    pending = deque([last_id])
    while pending:
        current = pending.popleft()
        for parent in parents.get(current, ()):
            if parent not in reached:
                reached.add(parent)
                pending.append(parent)

    return all(step["id"] in reached for step in steps)


def _build_candidate_steps(
    rng: Any, locations: List[str]
) -> Tuple[List[Dict[str, Any]], Set[str]]:
    """Draw one random list of steps over *locations*.

    Returns ``(steps, used_locations)``. The step list may be empty, because
    a "Between" draw is skipped while fewer than two references exist.
    """
    refs = list(locations)  # step inputs can be locations or earlier step IDs
    used_locations = set()
    steps = []
    next_id = 1

    for _ in range(rng.randint(2, 5)):
        func = rng.choice(_FUNCTIONS)

        if func == "Between":
            if len(refs) < 2:
                # Not enough references to pick two distinct ones; this
                # draw produces no step.
                continue
            bases = rng.sample(refs, 2)
            inputs = list(bases)
        else:
            bases = [rng.choice(refs)]
            if func == "Relative":
                bearing = rng.choice(_DIRECTIONS)
            else:
                bearing = f"{rng.randint(0, 359)}°"
            distance = f"{rng.randint(1, 10)} km"
            inputs = [bases[0], bearing, distance]

        # Only string references are locations; integers are step IDs.
        used_locations.update(b for b in bases if isinstance(b, str))
        steps.append({"id": next_id, "function": func, "inputs": inputs})
        refs.append(next_id)
        next_id += 1

    return steps, used_locations


def generate_localization_samples(
    n: int, seed: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Generate *n* valid localization samples by rejection sampling.

    A candidate is accepted only if it has at least one step, every
    generated location appears as a step input, and every step is reachable
    backwards from the last step. Rejected candidates are discarded and a
    new one is drawn.

    Args:
        n: Number of samples to produce. Values ``<= 0`` yield ``[]``.
        seed: Optional seed for a private random generator. When ``None``
            (the default) the module-level ``random`` state is used.

    Returns:
        A list of dicts with keys ``"index"`` (1-based, consecutive),
        ``"instruction"`` (always ``""``) and ``"steps"``.
    """
    rng = random if seed is None else random.Random(seed)
    all_data = []

    while len(all_data) < n:
        num_locations = rng.randint(1, 3)
        locations = [f"LOC_{i + 1}" for i in range(num_locations)]
        steps, used_locations = _build_candidate_steps(rng, locations)

        if not steps:
            continue  # no valid steps were produced; draw again

        if used_locations.issuperset(locations) and _all_steps_reach_last(steps):
            all_data.append({
                "index": len(all_data) + 1,
                "instruction": "",
                "steps": steps,
            })
        # otherwise: draw again

    return all_data


def write_custom_json(data, filename):
    """Write *data* to *filename* as JSON with one step per line.

    Args:
        data: Iterable-with-length of sample dicts carrying ``"index"``,
            ``"steps"`` and optionally ``"instruction"`` (defaults to ``""``).
        filename: Path of the UTF-8 file to create or overwrite.
    """
    def format_step(step):
        # json.dumps handles quoting/escaping of both the function name and
        # the inputs; non-ASCII text such as "°" is kept readable.
        function = json.dumps(step["function"], ensure_ascii=False)
        inputs = json.dumps(step["inputs"], ensure_ascii=False)
        return f'{{"id": {step["id"]}, "function": {function}, "inputs": {inputs}}}'

    last = len(data) - 1
    chunks = ["[\n"]
    for i, item in enumerate(data):
        instruction = json.dumps(item.get("instruction", ""), ensure_ascii=False)
        step_lines = [f" {format_step(step)}" for step in item["steps"]]
        chunks.append(" {\n")
        chunks.append(f' "index": {item["index"]},\n')
        chunks.append(f' "instruction": {instruction},\n')
        chunks.append(' "steps": [\n')
        chunks.append(",\n".join(step_lines))
        chunks.append("\n ]\n")
        # Every object except the last is followed by a comma.
        chunks.append(" }" + (",\n" if i < last else "\n"))
    chunks.append("]\n")

    with open(filename, "w", encoding="utf-8") as f:
        f.write("".join(chunks))


# Run
if __name__ == "__main__":
    samples = generate_localization_samples(100)
    write_custom_json(samples, "localization_samples.json")
    print("✅ Saved to localization_samples.json with all steps contributing.")